mirror of https://github.com/hpcaitech/ColossalAI
finish basic inference memory estimation
parent
d95cfe2622
commit
12301dd2e9
|
@ -64,6 +64,8 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
|
|||
# if node is placeholder, just add the size of the node
|
||||
if node.op == 'placeholder':
|
||||
act_memory += _get_meta_node_size(node)
|
||||
act_memory_peak_log.append(act_memory)
|
||||
act_memory_after_node_log.append(act_memory)
|
||||
# skip output
|
||||
elif node.op == 'output':
|
||||
continue
|
||||
|
@ -81,6 +83,15 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
|
|||
act_memory_after_node_log.append(act_memory)
|
||||
|
||||
act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
|
||||
act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
|
||||
|
||||
# for i in act_memory_peak_log:
|
||||
# print("%.2f " % i, end='')
|
||||
# print("\n")
|
||||
# for i in act_memory_after_node_log:
|
||||
# print("%.2f " % i, end='')
|
||||
# print("\n")
|
||||
|
||||
param_memory = parameter_size(gm)
|
||||
return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
|
||||
|
||||
|
|
|
@ -32,9 +32,19 @@ def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
|
|||
|
||||
|
||||
def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
|
||||
# now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||
# max_mem = torch.cuda.max_memory_allocated() / 1024**2
|
||||
# print("now:%.2f max:%.2f" %(torch.cuda.memory_allocated() / 1024**2, torch.cuda.max_memory_allocated() / 1024**2))
|
||||
# with torch.no_grad():
|
||||
# fx_out = gm(node, pair)
|
||||
# new_now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
|
||||
# print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - max_mem))
|
||||
|
||||
# test forward
|
||||
non_fx_out = model(node.clone(), pair.clone())
|
||||
fx_out = gm(node.clone(), pair.clone())
|
||||
with torch.no_grad():
|
||||
non_fx_out = model(node, pair)
|
||||
fx_out = gm(node, pair)
|
||||
assert torch.equal(non_fx_out[0], fx_out[0]), "fx_out doesn't comply with original output"
|
||||
assert torch.equal(non_fx_out[1], fx_out[1]), "fx_out doesn't comply with original output"
|
||||
|
||||
|
|
Loading…
Reference in New Issue