update active log

pull/2364/head
oahzxl 2022-11-15 11:30:43 +08:00
parent fad3b6d1a6
commit 54a34a7e46
1 changed files with 43 additions and 13 deletions

View File

@ -407,18 +407,41 @@ class MemoryEstimator(object):
x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
return x
def _get_output_node_size(self, n):
def _get_output_node(self, n):
fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
return activation_size(fwd_out)
out_size = activation_size(fwd_out)
out_node = [n.name] if out_size > 0 else []
return out_size, out_node
def _get_output_node_size(self, n):
return self._get_output_node(n)[0]
def _add_active_node(self, n, active_list):
new_active = self._get_output_node(n)[1]
for i in new_active:
if i not in active_list:
active_list.append(i)
def _get_delete_node(self, user, user_to_last_uses):
delete_size = 0
delete_node = []
if user.op not in ('placeholder', 'output'):
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
out_node = [self._get_output_node(i) for i in nodes_to_delete]
delete_size = sum([i[0] for i in out_node])
for i in range(len(out_node)):
if out_node[i][0] > 0:
delete_node.append(out_node[i][1][0])
return delete_size, delete_node
def _get_delete_node_size(self, user, user_to_last_uses):
if user.op in ('placeholder', 'output'):
return 0
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
delete_size = sum([self._get_output_node_size(i) for i in nodes_to_delete])
return delete_size
return 0
return self._get_delete_node(user, user_to_last_uses)[0]
def _remove_active_node(self, user, user_to_last_uses, active_list):
delete_node = self._get_delete_node(user, user_to_last_uses)[1]
for i in delete_node:
active_list.remove(i)
def _get_last_usr(self, nodes):
node_to_last_use: Dict[Node, Node] = {}
@ -438,7 +461,7 @@ class MemoryEstimator(object):
mem = 0
not_contiguous_ops = ['transpose', 'permute']
if node.op == 'call_function' and 'matmul' in node.name:
if node.op == 'call_function' and any(n in node.name for n in ['matmul', 'reshape']):
for n in node.args:
if n in not_contiguous_list:
# matmul won't change origin tensor, but create a tmp copy
@ -463,6 +486,8 @@ class MemoryEstimator(object):
act_memory_peak_log = []
act_memory_after_node_log = []
not_contiguous_list = []
active_node_list = []
active_node_list_log = []
user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
_delete_free_var_from_last_use(user_to_last_uses)
for node in gm.graph.nodes:
@ -470,7 +495,7 @@ class MemoryEstimator(object):
if node.op == 'placeholder':
act_memory += self._get_meta_node_size(node) / (1024 ** 2)
act_memory_peak_log.append(act_memory)
act_memory_after_node_log.append(act_memory)
active_node_list.append(node.name)
# skip output
elif node.op == 'output':
continue
@ -484,8 +509,12 @@ class MemoryEstimator(object):
# delete useless memory
act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2)
act_memory_after_node_log.append(act_memory)
# log active node
self._add_active_node(node, active_node_list)
self._remove_active_node(node, user_to_last_uses, active_node_list)
act_memory_after_node_log.append(act_memory)
active_node_list_log.append(copy.deepcopy(active_node_list))
print("no chunk")
self._print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
self._print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")
@ -551,7 +580,6 @@ class MemoryEstimator(object):
# node is an operation, calculate tmp, output node and delete node memory
else:
# forward memory
# TODO: permute will create a tmp copy if not contiguous
act_memory += self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
act_memory += self._get_output_node_size(node) * chunk_ratio / (1024 ** 2)
# record max act memory
@ -694,9 +722,11 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
within_chunk_region = False
node_list = list(nodes)
memory_estimator = MemoryEstimator()
memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
memory_estimator.estimate_inference_mem(meta_graph)
node_index_tracer = NodeIndexTracer(meta_graph)
node_index_tracer.trace_node_idx()