finish memory estimation

pull/2364/head
oahzxl 2022-11-11 15:43:03 +08:00
parent 22f9c60b6b
commit d7634af5c0
2 changed files with 80 additions and 47 deletions

View File

@ -55,15 +55,49 @@ def _get_last_usr(nodes):
return user_to_last_uses
def _delete_free_var_from_last_use(user_to_last_uses):
for key, value in user_to_last_uses.items():
for n in value:
if n.op == 'placeholder':
user_to_last_uses[key].remove(n)
def _get_contiguous_memory(node, not_contiguous_list, delete=False):
mem = 0
not_contiguous_ops = ['transpose', 'permute']
if node.op == 'call_function' and 'matmul' in node.name:
for n in node.args:
if n in not_contiguous_list:
# matmul won't change origin tensor, but create a tmp copy
mem += _get_output_node_size(n)
elif node.op == 'call_module':
for n in node.args:
if n in not_contiguous_list:
# module will just make origin tensor to contiguous
if delete:
not_contiguous_list.remove(n)
elif node.op == 'call_method' and any(i in node.name for i in not_contiguous_ops):
if node not in not_contiguous_list:
not_contiguous_list.append(node)
elif any(i in node.args for i in not_contiguous_list):
if node not in not_contiguous_list:
not_contiguous_list.append(node)
return mem
def _estimate_inference_mem(gm: torch.fx.GraphModule):
act_memory = 0
act_memory = 0.0
act_memory_peak_log = []
act_memory_after_node_log = []
not_contiguous_list = []
user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
_delete_free_var_from_last_use(user_to_last_uses)
for node in gm.graph.nodes:
# if node is placeholder, just add the size of the node
if node.op == 'placeholder':
act_memory += _get_meta_node_size(node)
act_memory += _get_meta_node_size(node) / (1024 ** 2)
act_memory_peak_log.append(act_memory)
act_memory_after_node_log.append(act_memory)
# skip output
@ -72,25 +106,21 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
# node is an operation, calculate tmp, output node and delete node memory
else:
# forward memory
act_memory += calculate_fwd_tmp(node)
# act_memory += calculate_fwd_out(node)
act_memory += _get_output_node_size(node)
act_memory += _get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2)
act_memory += _get_output_node_size(node) / (1024 ** 2)
# record max act memory
act_memory_peak_log.append(act_memory)
# delete useless memory
act_memory -= calculate_fwd_tmp(node)
act_memory -= _get_delete_node_size(node, user_to_last_uses)
act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2)
act_memory_after_node_log.append(act_memory)
act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
print("no chunk")
_print_mem_log(act_memory_peak_log, "peak")
_print_mem_log(act_memory_after_node_log, "after")
_print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
_print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")
param_memory = parameter_size(gm)
return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
return act_memory + param_memory, param_memory
def _get_chunk_ratio(node, chunk_dim, chunk_size):
@ -111,19 +141,23 @@ def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list,
return delete_size
def _print_mem_log(log, title=None):
def _print_mem_log(log, nodes, title=None):
if title:
print("%-8s" % title, end=' ')
for i in log:
print("%.2f " % i, end='')
print("")
print(title)
for idx, (l, n) in enumerate(zip(log, nodes)):
print("%s:%.2f \t" % (n.name, l), end='')
if (idx + 1) % 3 == 0:
print("")
print("\n")
def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
act_memory = 0
act_memory = 0.0
act_memory_peak_log = []
act_memory_after_node_log = []
not_contiguous_list = []
user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
_delete_free_var_from_last_use(user_to_last_uses)
within_chunk = False
region_idx = 0
chunk_ratio = 1 # use it to estimate chunk mem
@ -134,11 +168,11 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
if idx in start_nodes:
within_chunk = True
chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
act_memory += _get_output_node_size(node_list[end_nodes[region_idx]])
act_memory += _get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2)
# if node is placeholder, just add the size of the node
if node.op == 'placeholder':
act_memory += _get_meta_node_size(node) * chunk_ratio
act_memory += _get_meta_node_size(node) * chunk_ratio / (1024 ** 2)
act_memory_peak_log.append(act_memory)
# skip output
elif node.op == 'output':
@ -146,36 +180,33 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
# node is an operation, calculate tmp, output node and delete node memory
else:
# forward memory
act_memory += calculate_fwd_tmp(node) * chunk_ratio
# act_memory += calculate_fwd_out(node)
act_memory += _get_output_node_size(node) * chunk_ratio
act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
act_memory += _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
# record max act memory
act_memory_peak_log.append(act_memory)
# delete useless memory
act_memory -= calculate_fwd_tmp(node) * chunk_ratio
act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
if within_chunk:
act_memory -= _get_chunk_delete_node_size(
node, user_to_last_uses, chunk_ratio, node_list, start_nodes[region_idx], end_nodes[region_idx])
node, user_to_last_uses, chunk_ratio, node_list,
start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2)
else:
act_memory -= _get_delete_node_size(node, user_to_last_uses)
act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
if idx in end_nodes:
act_memory -= _get_output_node_size(node) * chunk_ratio
act_memory -= _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
within_chunk = False
chunk_ratio = 1
region_idx += 1
act_memory_after_node_log.append(act_memory)
act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
print("chunk")
_print_mem_log(act_memory_peak_log, "peak")
_print_mem_log(act_memory_after_node_log, "after")
_print_mem_log(act_memory_peak_log, node_list, "peak")
_print_mem_log(act_memory_after_node_log, node_list, "after")
param_memory = parameter_size(gm)
return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
return act_memory + param_memory, param_memory
def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
@ -516,7 +547,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
"""
# find the offload regions
chunk_regions = [(2, 6)]
chunk_regions = [(58, 62)]
chunk_starts = [item[0] for item in chunk_regions]
chunk_ends = [item[1] for item in chunk_regions]
chunk_inputs = []
@ -683,7 +714,9 @@ if CODEGEN_AVAILABLE:
for node in reversed(nodes):
map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
_delete_free_var_from_last_use(user_to_last_uses)
# NOTE: we add a variable to distinguish body and ckpt_func
def delete_unused_values(user: Node, body, to_keep=[]):
"""

View File

@ -32,14 +32,14 @@ def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
# now_mem = torch.cuda.memory_allocated() / 1024**2
# max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print("now:%.2f max:%.2f" %(torch.cuda.memory_allocated() / 1024**2, torch.cuda.max_memory_allocated() / 1024**2))
# with torch.no_grad():
# fx_out = gm(node, pair)
# new_now_mem = torch.cuda.memory_allocated() / 1024**2
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - max_mem))
now_mem = torch.cuda.memory_allocated() / 1024**2
with torch.no_grad():
node0 = node.clone()
pair0 = pair.clone()
node1, pair1 = gm(node0, pair0)
new_now_mem = torch.cuda.memory_allocated() / 1024**2
new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem))
# test forward
with torch.no_grad():
@ -63,8 +63,8 @@ def _run_offload_codegen(rank):
# build model and input
model = evoformer_base().cuda()
node = torch.randn(1, 16, 32, 256).cuda()
pair = torch.randn(1, 32, 32, 128).cuda()
node = torch.randn(1, 100, 300, 256).cuda()
pair = torch.randn(1, 300, 300, 128).cuda()
# trace the module and replace codegen
graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))})