mirror of https://github.com/hpcaitech/ColossalAI

commit 54a34a7e46 ("update active log"), parent fad3b6d1a6
@@ -407,18 +407,41 @@ class MemoryEstimator(object):
         x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
         return x

-    def _get_output_node_size(self, n):
+    def _get_output_node(self, n):
         fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
-        return activation_size(fwd_out)
+        out_size = activation_size(fwd_out)
+        out_node = [n.name] if out_size > 0 else []
+        return out_size, out_node
+
+    def _get_output_node_size(self, n):
+        return self._get_output_node(n)[0]
+
+    def _add_active_node(self, n, active_list):
+        new_active = self._get_output_node(n)[1]
+        for i in new_active:
+            if i not in active_list:
+                active_list.append(i)
+
+    def _get_delete_node(self, user, user_to_last_uses):
+        delete_size = 0
+        delete_node = []
+        if user.op not in ('placeholder', 'output'):
+            nodes_to_delete = user_to_last_uses.get(user, [])
+            if len(nodes_to_delete):
+                out_node = [self._get_output_node(i) for i in nodes_to_delete]
+                delete_size = sum([i[0] for i in out_node])
+                for i in range(len(out_node)):
+                    if out_node[i][0] > 0:
+                        delete_node.append(out_node[i][1][0])
+        return delete_size, delete_node

     def _get_delete_node_size(self, user, user_to_last_uses):
-        if user.op in ('placeholder', 'output'):
-            return 0
-        nodes_to_delete = user_to_last_uses.get(user, [])
-        if len(nodes_to_delete):
-            delete_size = sum([self._get_output_node_size(i) for i in nodes_to_delete])
-            return delete_size
-        return 0
+        return self._get_delete_node(user, user_to_last_uses)[0]
+
+    def _remove_active_node(self, user, user_to_last_uses, active_list):
+        delete_node = self._get_delete_node(user, user_to_last_uses)[1]
+        for i in delete_node:
+            active_list.remove(i)

     def _get_last_usr(self, nodes):
         node_to_last_use: Dict[Node, Node] = {}
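Note: the new helpers split size accounting from liveness tracking. _get_output_node returns both the output size and, when it is non-zero, the node's name; _add_active_node and _remove_active_node use the name list to maintain the set of currently live activations. A minimal standalone sketch of that bookkeeping, using toy node objects rather than the ColossalAI classes:

    # Toy stand-ins for fx nodes; only a name and an output size are needed here.
    class FakeNode:
        def __init__(self, name, out_bytes):
            self.name, self.out_bytes = name, out_bytes

    def get_output_node(n):
        # Mirrors _get_output_node: (size, [name]) when the node produces output.
        return n.out_bytes, ([n.name] if n.out_bytes > 0 else [])

    active = []
    for n in (FakeNode("matmul_1", 4096), FakeNode("size_1", 0)):
        size, names = get_output_node(n)
        for name in names:          # _add_active_node: record live outputs once
            if name not in active:
                active.append(name)
    print(active)  # ['matmul_1']: zero-size nodes never become "active"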
@@ -438,7 +461,7 @@ class MemoryEstimator(object):
         mem = 0
         not_contiguous_ops = ['transpose', 'permute']

-        if node.op == 'call_function' and 'matmul' in node.name:
+        if node.op == 'call_function' and any(n in node.name for n in ['matmul', 'reshape']):
             for n in node.args:
                 if n in not_contiguous_list:
                     # matmul won't change origin tensor, but create a tmp copy
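Note: the widened check now also flags reshape consumers of non-contiguous tensors. For background, a small PyTorch demo (independent of this patch) of why transpose/permute outputs are tracked: they are views, but a later op that needs contiguous storage materializes a temporary copy.

    import torch

    x = torch.randn(128, 256)
    t = x.transpose(0, 1)       # a view: shares x's storage, no new allocation
    print(t.is_contiguous())    # False
    c = t.contiguous()          # materializes a contiguous temporary copy
    print(c.data_ptr() == x.data_ptr())  # False: new storage was allocated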
@@ -463,6 +486,8 @@ class MemoryEstimator(object):
         act_memory_peak_log = []
         act_memory_after_node_log = []
         not_contiguous_list = []
+        active_node_list = []
+        active_node_list_log = []
         user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
         _delete_free_var_from_last_use(user_to_last_uses)
         for node in gm.graph.nodes:
@@ -470,7 +495,7 @@ class MemoryEstimator(object):
             if node.op == 'placeholder':
                 act_memory += self._get_meta_node_size(node) / (1024 ** 2)
                 act_memory_peak_log.append(act_memory)
-                act_memory_after_node_log.append(act_memory)
+                active_node_list.append(node.name)
             # skip output
             elif node.op == 'output':
                 continue
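Note: the / (1024 ** 2) factor converts the byte counts accumulated above into MiB for the logs. A quick check of the arithmetic, using the same numel-times-element-size rule as the size helpers:

    import torch

    x = torch.empty(1024, 1024, dtype=torch.float32)
    size_bytes = x.numel() * x.element_size()  # 1024 * 1024 * 4
    print(size_bytes / (1024 ** 2))            # 4.0 MiB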
@@ -484,8 +509,12 @@ class MemoryEstimator(object):
                 # delete useless memory
                 act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
                 act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2)
-                act_memory_after_node_log.append(act_memory)
+                # log active node
+                self._add_active_node(node, active_node_list)
+                self._remove_active_node(node, user_to_last_uses, active_node_list)

+            act_memory_after_node_log.append(act_memory)
+            active_node_list_log.append(copy.deepcopy(active_node_list))
         print("no chunk")
         self._print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
         self._print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")
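Note: active_node_list_log is snapshotted with copy.deepcopy because active_node_list keeps mutating as the loop advances; appending the list itself would leave every log entry aliased to the same object. A minimal illustration:

    import copy

    active, aliased, snapshots = [], [], []
    for name in ["a", "b"]:
        active.append(name)
        aliased.append(active)                   # alias: all entries mutate together
        snapshots.append(copy.deepcopy(active))  # independent per-step snapshot
    print(aliased[0])    # ['a', 'b']: first entry changed retroactively
    print(snapshots[0])  # ['a']: snapshot preserved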
@@ -551,7 +580,6 @@ class MemoryEstimator(object):
             # node is an operation, calculate tmp, output node and delete node memory
             else:
                 # forward memory
-                # TODO: permute will create a tmp copy if not contiguous
                 act_memory += self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
                 act_memory += self._get_output_node_size(node) * chunk_ratio / (1024 ** 2)
                 # record max act memory
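Note: for context on chunk_ratio, inside a chunked region only a fraction of each activation is alive at once, so per-node sizes are scaled down before the MiB conversion. A hedged sketch of that scaling; the function and names here are illustrative, not the patch's API:

    def chunked_act_mib(out_bytes_per_node, num_chunks):
        # chunk_ratio = 1 / num_chunks: each chunk holds that slice of every output
        ratio = 1.0 / num_chunks
        return sum(b * ratio for b in out_bytes_per_node) / (1024 ** 2)

    sizes = [4 * 1024 ** 2, 8 * 1024 ** 2]       # two ops: 4 MiB and 8 MiB outputs
    print(chunked_act_mib(sizes, num_chunks=1))  # 12.0 (unchunked)
    print(chunked_act_mib(sizes, num_chunks=4))  # 3.0  (4-way chunking)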
@@ -694,9 +722,11 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     within_chunk_region = False

     node_list = list(nodes)

     memory_estimator = MemoryEstimator()
     memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
     memory_estimator.estimate_inference_mem(meta_graph)

+    node_index_tracer = NodeIndexTracer(meta_graph)
+    node_index_tracer.trace_node_idx()

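Note: the meta_graph driven through the estimator here is a torch.fx graph. A minimal, hedged example of producing such a graph; only the tracing below is standard torch.fx, while the estimator and tracer calls above are this patch's own API:

    import torch
    import torch.fx

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return torch.matmul(x, x.transpose(0, 1))

    gm = torch.fx.symbolic_trace(Tiny())
    print(gm.graph)  # nodes: placeholder -> transpose -> matmul -> output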