mirror of https://github.com/hpcaitech/ColossalAI
finish memory estimation
parent
22f9c60b6b
commit
d7634af5c0
107
chunk_codegen.py
107
chunk_codegen.py
|
@ -55,15 +55,49 @@ def _get_last_usr(nodes):
|
|||
return user_to_last_uses
|
||||
|
||||
|
||||
def _delete_free_var_from_last_use(user_to_last_uses):
|
||||
for key, value in user_to_last_uses.items():
|
||||
for n in value:
|
||||
if n.op == 'placeholder':
|
||||
user_to_last_uses[key].remove(n)
|
||||
|
||||
|
||||
def _get_contiguous_memory(node, not_contiguous_list, delete=False):
|
||||
mem = 0
|
||||
not_contiguous_ops = ['transpose', 'permute']
|
||||
|
||||
if node.op == 'call_function' and 'matmul' in node.name:
|
||||
for n in node.args:
|
||||
if n in not_contiguous_list:
|
||||
# matmul won't change origin tensor, but create a tmp copy
|
||||
mem += _get_output_node_size(n)
|
||||
elif node.op == 'call_module':
|
||||
for n in node.args:
|
||||
if n in not_contiguous_list:
|
||||
# module will just make origin tensor to contiguous
|
||||
if delete:
|
||||
not_contiguous_list.remove(n)
|
||||
elif node.op == 'call_method' and any(i in node.name for i in not_contiguous_ops):
|
||||
if node not in not_contiguous_list:
|
||||
not_contiguous_list.append(node)
|
||||
elif any(i in node.args for i in not_contiguous_list):
|
||||
if node not in not_contiguous_list:
|
||||
not_contiguous_list.append(node)
|
||||
|
||||
return mem
|
||||
|
||||
|
||||
def _estimate_inference_mem(gm: torch.fx.GraphModule):
|
||||
act_memory = 0
|
||||
act_memory = 0.0
|
||||
act_memory_peak_log = []
|
||||
act_memory_after_node_log = []
|
||||
not_contiguous_list = []
|
||||
user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
|
||||
_delete_free_var_from_last_use(user_to_last_uses)
|
||||
for node in gm.graph.nodes:
|
||||
# if node is placeholder, just add the size of the node
|
||||
if node.op == 'placeholder':
|
||||
act_memory += _get_meta_node_size(node)
|
||||
act_memory += _get_meta_node_size(node) / (1024 ** 2)
|
||||
act_memory_peak_log.append(act_memory)
|
||||
act_memory_after_node_log.append(act_memory)
|
||||
# skip output
|
||||
|
@ -72,25 +106,21 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
|
|||
# node is an operation, calculate tmp, output node and delete node memory
|
||||
else:
|
||||
# forward memory
|
||||
act_memory += calculate_fwd_tmp(node)
|
||||
# act_memory += calculate_fwd_out(node)
|
||||
act_memory += _get_output_node_size(node)
|
||||
act_memory += _get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2)
|
||||
act_memory += _get_output_node_size(node) / (1024 ** 2)
|
||||
# record max act memory
|
||||
act_memory_peak_log.append(act_memory)
|
||||
# delete useless memory
|
||||
act_memory -= calculate_fwd_tmp(node)
|
||||
act_memory -= _get_delete_node_size(node, user_to_last_uses)
|
||||
act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
|
||||
act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2)
|
||||
act_memory_after_node_log.append(act_memory)
|
||||
|
||||
act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
|
||||
act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
|
||||
|
||||
print("no chunk")
|
||||
_print_mem_log(act_memory_peak_log, "peak")
|
||||
_print_mem_log(act_memory_after_node_log, "after")
|
||||
_print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
|
||||
_print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")
|
||||
|
||||
param_memory = parameter_size(gm)
|
||||
return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
|
||||
return act_memory + param_memory, param_memory
|
||||
|
||||
|
||||
def _get_chunk_ratio(node, chunk_dim, chunk_size):
|
||||
|
@ -111,19 +141,23 @@ def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list,
|
|||
return delete_size
|
||||
|
||||
|
||||
def _print_mem_log(log, title=None):
|
||||
def _print_mem_log(log, nodes, title=None):
|
||||
if title:
|
||||
print("%-8s" % title, end=' ')
|
||||
for i in log:
|
||||
print("%.2f " % i, end='')
|
||||
print("")
|
||||
print(title)
|
||||
for idx, (l, n) in enumerate(zip(log, nodes)):
|
||||
print("%s:%.2f \t" % (n.name, l), end='')
|
||||
if (idx + 1) % 3 == 0:
|
||||
print("")
|
||||
print("\n")
|
||||
|
||||
|
||||
def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
|
||||
act_memory = 0
|
||||
act_memory = 0.0
|
||||
act_memory_peak_log = []
|
||||
act_memory_after_node_log = []
|
||||
not_contiguous_list = []
|
||||
user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
|
||||
_delete_free_var_from_last_use(user_to_last_uses)
|
||||
within_chunk = False
|
||||
region_idx = 0
|
||||
chunk_ratio = 1 # use it to estimate chunk mem
|
||||
|
@ -134,11 +168,11 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
|
|||
if idx in start_nodes:
|
||||
within_chunk = True
|
||||
chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
|
||||
act_memory += _get_output_node_size(node_list[end_nodes[region_idx]])
|
||||
act_memory += _get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2)
|
||||
|
||||
# if node is placeholder, just add the size of the node
|
||||
if node.op == 'placeholder':
|
||||
act_memory += _get_meta_node_size(node) * chunk_ratio
|
||||
act_memory += _get_meta_node_size(node) * chunk_ratio / (1024 ** 2)
|
||||
act_memory_peak_log.append(act_memory)
|
||||
# skip output
|
||||
elif node.op == 'output':
|
||||
|
@ -146,36 +180,33 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
|
|||
# node is an operation, calculate tmp, output node and delete node memory
|
||||
else:
|
||||
# forward memory
|
||||
act_memory += calculate_fwd_tmp(node) * chunk_ratio
|
||||
# act_memory += calculate_fwd_out(node)
|
||||
act_memory += _get_output_node_size(node) * chunk_ratio
|
||||
act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
|
||||
act_memory += _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
|
||||
# record max act memory
|
||||
act_memory_peak_log.append(act_memory)
|
||||
# delete useless memory
|
||||
act_memory -= calculate_fwd_tmp(node) * chunk_ratio
|
||||
act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
|
||||
if within_chunk:
|
||||
act_memory -= _get_chunk_delete_node_size(
|
||||
node, user_to_last_uses, chunk_ratio, node_list, start_nodes[region_idx], end_nodes[region_idx])
|
||||
node, user_to_last_uses, chunk_ratio, node_list,
|
||||
start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2)
|
||||
else:
|
||||
act_memory -= _get_delete_node_size(node, user_to_last_uses)
|
||||
act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
|
||||
|
||||
if idx in end_nodes:
|
||||
act_memory -= _get_output_node_size(node) * chunk_ratio
|
||||
act_memory -= _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
|
||||
within_chunk = False
|
||||
chunk_ratio = 1
|
||||
region_idx += 1
|
||||
|
||||
act_memory_after_node_log.append(act_memory)
|
||||
|
||||
act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
|
||||
act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
|
||||
|
||||
print("chunk")
|
||||
_print_mem_log(act_memory_peak_log, "peak")
|
||||
_print_mem_log(act_memory_after_node_log, "after")
|
||||
|
||||
_print_mem_log(act_memory_peak_log, node_list, "peak")
|
||||
_print_mem_log(act_memory_after_node_log, node_list, "after")
|
||||
|
||||
param_memory = parameter_size(gm)
|
||||
return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
|
||||
return act_memory + param_memory, param_memory
|
||||
|
||||
|
||||
def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
|
||||
|
@ -516,7 +547,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
|||
"""
|
||||
|
||||
# find the offload regions
|
||||
chunk_regions = [(2, 6)]
|
||||
chunk_regions = [(58, 62)]
|
||||
chunk_starts = [item[0] for item in chunk_regions]
|
||||
chunk_ends = [item[1] for item in chunk_regions]
|
||||
chunk_inputs = []
|
||||
|
@ -683,7 +714,9 @@ if CODEGEN_AVAILABLE:
|
|||
for node in reversed(nodes):
|
||||
map_arg(node.args, lambda n: register_last_uses(n, node))
|
||||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||
|
||||
|
||||
_delete_free_var_from_last_use(user_to_last_uses)
|
||||
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def delete_unused_values(user: Node, body, to_keep=[]):
|
||||
"""
|
||||
|
|
|
@ -32,14 +32,14 @@ def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
|
|||
|
||||
|
||||
def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
|
||||
# now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||
# max_mem = torch.cuda.max_memory_allocated() / 1024**2
|
||||
# print("now:%.2f max:%.2f" %(torch.cuda.memory_allocated() / 1024**2, torch.cuda.max_memory_allocated() / 1024**2))
|
||||
# with torch.no_grad():
|
||||
# fx_out = gm(node, pair)
|
||||
# new_now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
|
||||
# print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - max_mem))
|
||||
now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||
with torch.no_grad():
|
||||
node0 = node.clone()
|
||||
pair0 = pair.clone()
|
||||
node1, pair1 = gm(node0, pair0)
|
||||
new_now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||
new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
|
||||
print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem))
|
||||
|
||||
# test forward
|
||||
with torch.no_grad():
|
||||
|
@ -63,8 +63,8 @@ def _run_offload_codegen(rank):
|
|||
|
||||
# build model and input
|
||||
model = evoformer_base().cuda()
|
||||
node = torch.randn(1, 16, 32, 256).cuda()
|
||||
pair = torch.randn(1, 32, 32, 128).cuda()
|
||||
node = torch.randn(1, 100, 300, 256).cuda()
|
||||
pair = torch.randn(1, 300, 300, 128).cuda()
|
||||
|
||||
# trace the module and replace codegen
|
||||
graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))})
|
||||
|
|
Loading…
Reference in New Issue