mirror of https://github.com/hpcaitech/ColossalAI
finish memory estimation
parent
12301dd2e9
commit
8cca684c56
103
chunk_codegen.py
103
chunk_codegen.py
|
@ -85,25 +85,97 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
|
||||||
act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
|
act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
|
||||||
act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
|
act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
|
||||||
|
|
||||||
# for i in act_memory_peak_log:
|
print("no chunk")
|
||||||
# print("%.2f " % i, end='')
|
_print_mem_log(act_memory_peak_log, "peak")
|
||||||
# print("\n")
|
_print_mem_log(act_memory_after_node_log, "after")
|
||||||
# for i in act_memory_after_node_log:
|
|
||||||
# print("%.2f " % i, end='')
|
|
||||||
# print("\n")
|
|
||||||
|
|
||||||
param_memory = parameter_size(gm)
|
param_memory = parameter_size(gm)
|
||||||
return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
|
return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
|
||||||
|
|
||||||
|
|
||||||
def _estimate_chunk_forward_mem(gm: torch.fx.GraphModule, start_node, end_node, chunk_size):
|
def _get_chunk_ratio(node, chunk_dim, chunk_size):
|
||||||
node_size = 0
|
shape = node.meta['tensor_meta'].shape
|
||||||
param_size = 0
|
chunk_ratio = float(chunk_size) / shape[chunk_dim]
|
||||||
for node in gm.graph.nodes:
|
return chunk_ratio
|
||||||
node_size += calculate_fwd_tmp(node)
|
|
||||||
node_size += calculate_fwd_out(node)
|
|
||||||
param_size = parameter_size(gm)
|
def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node):
|
||||||
return (node_size + param_size) / 1024**2, param_size / 1024**2
|
if user.op in ('placeholder', 'output'):
|
||||||
|
return 0
|
||||||
|
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||||
|
delete_size = 0
|
||||||
|
for n in nodes_to_delete:
|
||||||
|
node_idx = _find_idx_by_name(n.name, node_list)
|
||||||
|
if start_node <= node_idx < end_node:
|
||||||
|
delete_size += _get_output_node_size(n) * chunk_ratio
|
||||||
|
return delete_size
|
||||||
|
|
||||||
|
|
||||||
|
def _print_mem_log(log, title=None):
|
||||||
|
if title:
|
||||||
|
print("%-8s" % title, end=' ')
|
||||||
|
for i in log:
|
||||||
|
print("%.2f " % i, end='')
|
||||||
|
print("")
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
|
||||||
|
act_memory = 0
|
||||||
|
act_memory_peak_log = []
|
||||||
|
act_memory_after_node_log = []
|
||||||
|
user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
|
||||||
|
within_chunk = False
|
||||||
|
region_idx = 0
|
||||||
|
chunk_ratio = 1 # use it to estimate chunk mem
|
||||||
|
node_list = list(gm.graph.nodes)
|
||||||
|
|
||||||
|
for idx, node in enumerate(node_list):
|
||||||
|
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
|
||||||
|
if idx in start_nodes:
|
||||||
|
within_chunk = True
|
||||||
|
chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
|
||||||
|
act_memory += _get_output_node_size(node_list[end_nodes[region_idx]])
|
||||||
|
|
||||||
|
# if node is placeholder, just add the size of the node
|
||||||
|
if node.op == 'placeholder':
|
||||||
|
act_memory += _get_meta_node_size(node) * chunk_ratio
|
||||||
|
act_memory_peak_log.append(act_memory)
|
||||||
|
# skip output
|
||||||
|
elif node.op == 'output':
|
||||||
|
continue
|
||||||
|
# node is an operation, calculate tmp, output node and delete node memory
|
||||||
|
else:
|
||||||
|
# forward memory
|
||||||
|
act_memory += calculate_fwd_tmp(node) * chunk_ratio
|
||||||
|
# act_memory += calculate_fwd_out(node)
|
||||||
|
act_memory += _get_output_node_size(node) * chunk_ratio
|
||||||
|
# record max act memory
|
||||||
|
act_memory_peak_log.append(act_memory)
|
||||||
|
# delete useless memory
|
||||||
|
act_memory -= calculate_fwd_tmp(node) * chunk_ratio
|
||||||
|
if within_chunk:
|
||||||
|
act_memory -= _get_chunk_delete_node_size(
|
||||||
|
node, user_to_last_uses, chunk_ratio, node_list, start_nodes[region_idx], end_nodes[region_idx])
|
||||||
|
else:
|
||||||
|
act_memory -= _get_delete_node_size(node, user_to_last_uses)
|
||||||
|
|
||||||
|
if idx in end_nodes:
|
||||||
|
act_memory -= _get_output_node_size(node) * chunk_ratio
|
||||||
|
within_chunk = False
|
||||||
|
chunk_ratio = 1
|
||||||
|
region_idx += 1
|
||||||
|
|
||||||
|
act_memory_after_node_log.append(act_memory)
|
||||||
|
|
||||||
|
act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
|
||||||
|
act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
|
||||||
|
|
||||||
|
print("chunk")
|
||||||
|
_print_mem_log(act_memory_peak_log, "peak")
|
||||||
|
_print_mem_log(act_memory_after_node_log, "after")
|
||||||
|
|
||||||
|
param_memory = parameter_size(gm)
|
||||||
|
return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
|
||||||
|
|
||||||
|
|
||||||
def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
|
def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
|
||||||
|
@ -444,7 +516,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# find the offload regions
|
# find the offload regions
|
||||||
chunk_regions = [(2, 5)]
|
chunk_regions = [(2, 6)]
|
||||||
chunk_starts = [item[0] for item in chunk_regions]
|
chunk_starts = [item[0] for item in chunk_regions]
|
||||||
chunk_ends = [item[1] for item in chunk_regions]
|
chunk_ends = [item[1] for item in chunk_regions]
|
||||||
chunk_inputs = []
|
chunk_inputs = []
|
||||||
|
@ -452,6 +524,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
||||||
within_chunk_region = False
|
within_chunk_region = False
|
||||||
|
|
||||||
node_list = list(nodes)
|
node_list = list(nodes)
|
||||||
|
_estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
|
||||||
_estimate_inference_mem(meta_graph)
|
_estimate_inference_mem(meta_graph)
|
||||||
|
|
||||||
# find the input and output var names for each offload region
|
# find the input and output var names for each offload region
|
||||||
|
|
Loading…
Reference in New Issue