mirror of https://github.com/hpcaitech/ColossalAI
polish code
parent 54a34a7e46
commit d9ca2f898d
@@ -438,7 +438,7 @@ class MemoryEstimator(object):
     def _get_delete_node_size(self, user, user_to_last_uses):
         return self._get_delete_node(user, user_to_last_uses)[0]
 
-    def _remove_active_node(self, user, user_to_last_uses, active_list):
+    def _remove_deactive_node(self, user, user_to_last_uses, active_list):
         delete_node = self._get_delete_node(user, user_to_last_uses)[1]
         for i in delete_node:
             active_list.remove(i)
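The rename from `_remove_active_node` to `_remove_deactive_node` matches what the helper actually does: it drops from `active_list` the nodes whose last user has just run. A minimal standalone sketch of the assumed `_get_delete_node` contract; the `(size, nodes)` return pair is inferred from the `[0]`/`[1]` indexing above, and the free-function names here are hypothetical:

```python
# Hypothetical sketch: _get_delete_node is assumed to return a
# (freed_bytes, freed_nodes) pair, since callers index [0] and [1].
def get_delete_node(user, user_to_last_uses, node_size):
    delete_size = 0
    delete_node = []
    for n in user_to_last_uses.get(user, []):
        delete_size += node_size[n]  # bytes freed once n's last user has run
        delete_node.append(n)
    return delete_size, delete_node

# Toy usage: after node "c" runs, "a" is dead and leaves the active list.
size, freed = get_delete_node("c", {"c": ["a"]}, {"a": 4096})
assert size == 4096 and freed == ["a"]
```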
@@ -481,48 +481,6 @@ class MemoryEstimator(object):
         return mem
 
-    def estimate_inference_mem(self, gm: torch.fx.GraphModule):
-        act_memory = 0.0
-        act_memory_peak_log = []
-        act_memory_after_node_log = []
-        not_contiguous_list = []
-        active_node_list = []
-        active_node_list_log = []
-        user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
-        _delete_free_var_from_last_use(user_to_last_uses)
-        for node in gm.graph.nodes:
-            # if node is placeholder, just add the size of the node
-            if node.op == 'placeholder':
-                act_memory += self._get_meta_node_size(node) / (1024 ** 2)
-                act_memory_peak_log.append(act_memory)
-                active_node_list.append(node.name)
-            # skip output
-            elif node.op == 'output':
-                continue
-            # node is an operation, calculate tmp, output node and delete node memory
-            else:
-                # forward memory
-                act_memory += self._get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2)
-                act_memory += self._get_output_node_size(node) / (1024 ** 2)
-                # record max act memory
-                act_memory_peak_log.append(act_memory)
-                # delete useless memory
-                act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
-                act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2)
-                # log active node
-                self._add_active_node(node, active_node_list)
-                self._remove_active_node(node, user_to_last_uses, active_node_list)
-
-            act_memory_after_node_log.append(act_memory)
-            active_node_list_log.append(copy.deepcopy(active_node_list))
-        print("no chunk")
-        self._print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
-        self._print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")
-
-        param_memory = parameter_size(gm)
-        return act_memory + param_memory, param_memory
-
-
     def _get_chunk_ratio(self, node, chunk_dim, chunk_size):
         shape = node.meta['tensor_meta'].shape
         chunk_ratio = float(chunk_size) / shape[chunk_dim]
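`estimate_inference_mem` is deleted because the non-chunk case is now just `estimate_chunk_inference_mem` with every chunk argument left at `None` (see the next hunk). The per-node bookkeeping it performed survives in the unified method: allocate each node's output, record the running peak, then release tensors whose last use has passed. A minimal runnable sketch of that pattern, with illustrative names and sizes:

```python
# Illustrative sketch of the peak-memory bookkeeping, not the estimator's
# real API. For each node: add its output, log the high-water mark, then
# subtract the bytes whose last user was this node.
def estimate_peak(nodes, out_size, freed_size):
    act, peak_log, after_log = 0.0, [], []
    for n in nodes:
        act += out_size[n] / (1024 ** 2)    # MB newly allocated by n
        peak_log.append(act)                # peak while n executes
        act -= freed_size[n] / (1024 ** 2)  # MB dead after n
        after_log.append(act)
    return peak_log, after_log

peak, after = estimate_peak(
    ["a", "b"],
    {"a": 2 * 1024 ** 2, "b": 1 * 1024 ** 2},  # bytes allocated per node
    {"a": 0, "b": 2 * 1024 ** 2},              # bytes freed after each node
)
assert peak == [2.0, 3.0] and after == [2.0, 1.0]
```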
@@ -550,25 +508,28 @@ class MemoryEstimator(object):
         print("")
         print("\n")
 
-    def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
+    def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes=None, end_nodes=None, chunk_dims=None, chunk_sizes=None):
         act_memory = 0.0
         act_memory_peak_log = []
         act_memory_after_node_log = []
+        active_node_list = []
+        active_node_list_log = []
         not_contiguous_list = []
+        node_list = list(gm.graph.nodes)
         user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
         _delete_free_var_from_last_use(user_to_last_uses)
-        within_chunk = False
-        region_idx = 0
+        use_chunk = all(i is not None for i in [start_nodes, end_nodes, chunk_dims, chunk_sizes])
+        chunk_within = False
+        chunk_region_idx = 0
         chunk_ratio = 1    # use it to estimate chunk mem
-        node_list = list(gm.graph.nodes)
 
         for idx, node in enumerate(node_list):
             # if node in chunk start nodes, change chunk ratio and add chunk_tensor
-            if idx in start_nodes:
-                within_chunk = True
-                chunk_ratio = self._get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
-                act_memory += self._get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2)
+            if use_chunk and idx in start_nodes:
+                chunk_within = True
+                chunk_ratio = self._get_chunk_ratio(node, chunk_dims[chunk_region_idx], chunk_sizes[chunk_region_idx])
+                act_memory += self._get_output_node_size(node_list[end_nodes[chunk_region_idx]]) / (1024 ** 2)
 
             # if node is placeholder, just add the size of the node
             if node.op == 'placeholder':
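Making the four chunk arguments optional and gating every chunk branch on `use_chunk` lets one entry point serve both estimates. A hedged usage sketch, assuming `gm` is a traced `torch.fx.GraphModule`; the node indices and chunk settings are illustrative:

```python
# assuming gm = torch.fx.symbolic_trace(model) or an equivalent meta graph
estimator = MemoryEstimator()

# Chunked estimate: nodes 3..7 form one chunk region, split along
# dim 1 in slices of size 2 (all values illustrative).
estimator.estimate_chunk_inference_mem(gm, [3], [7], [1], [2])

# Plain estimate: with every chunk argument left at None, use_chunk is
# False and the method reduces to the old estimate_inference_mem loop.
estimator.estimate_chunk_inference_mem(gm)
```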
@@ -586,22 +547,28 @@ class MemoryEstimator(object):
                 act_memory_peak_log.append(act_memory)
                 # delete useless memory
                 act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
-                if within_chunk:
+                if chunk_within:
                     act_memory -= self._get_chunk_delete_node_size(
                         node, user_to_last_uses, chunk_ratio, node_list,
-                        start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2)
+                        start_nodes[chunk_region_idx], end_nodes[chunk_region_idx]) / (1024 ** 2)
                 else:
                     act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
 
-            if idx in end_nodes:
+            # log active node
+            self._add_active_node(node, active_node_list)
+            self._remove_deactive_node(node, user_to_last_uses, active_node_list)
+
+            # if node in chunk end nodes, restore chunk settings
+            if use_chunk and idx in end_nodes:
                 act_memory -= self._get_output_node_size(node) * chunk_ratio / (1024 ** 2)
-                within_chunk = False
+                chunk_within = False
                 chunk_ratio = 1
-                region_idx += 1
+                chunk_region_idx += 1
 
             act_memory_after_node_log.append(act_memory)
+            active_node_list_log.append(copy.deepcopy(active_node_list))
 
-        print("chunk")
+        print("with chunk" if use_chunk else "without chunk")
         self._print_mem_log(act_memory_peak_log, node_list, "peak")
         self._print_mem_log(act_memory_after_node_log, node_list, "after")
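Inside a chunk region every intermediate is scaled by `chunk_ratio = chunk_size / shape[chunk_dim]`, since only one slice of each tensor is alive at a time; the region's full-size output is charged up front at the start node and the per-slice charge is released again at the end node. A worked example of the scaling (shape, dtype, and sizes are illustrative):

```python
# _get_chunk_ratio computes chunk_size / shape[chunk_dim]. For an
# illustrative (1, 512, 256) activation chunked along dim 1 in slices
# of 2, only 2/512 of each intermediate is resident at once.
shape = (1, 512, 256)
chunk_dim, chunk_size = 1, 2
chunk_ratio = float(chunk_size) / shape[chunk_dim]   # 2 / 512

full_bytes = shape[1] * shape[2] * 4                 # fp32 intermediate
chunked_mb = full_bytes * chunk_ratio / (1024 ** 2)  # MB actually resident
assert abs(chunked_mb - 2 * 256 * 4 / (1024 ** 2)) < 1e-12
```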
@@ -725,7 +692,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
 
     memory_estimator = MemoryEstimator()
     memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
-    memory_estimator.estimate_inference_mem(meta_graph)
+    memory_estimator.estimate_chunk_inference_mem(meta_graph)
 
     node_index_tracer = NodeIndexTracer(meta_graph)
     node_index_tracer.trace_node_idx()
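With `estimate_inference_mem` gone, both call sites go through the unified method: the first call estimates memory under the candidate chunk regions (reading off the signature, `[1]` and `[2]` are the hard-coded chunk dims and chunk sizes in this revision), and the second, argument-free call produces the unchunked baseline for comparison.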