update run

pull/2364/head
oahzxl 2022-12-23 17:32:11 +08:00
parent 51ef8384c1
commit 9b1b890347
1 changed file with 15 additions and 5 deletions

@@ -32,15 +32,25 @@ def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
 def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
+    # now_mem = torch.cuda.memory_allocated() / 1024**2
+    # with torch.no_grad():
+    #     node0 = node.clone()
+    #     pair0 = pair.clone()
+    #     model.graph(node0, pair0, now_mem)
+    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
+    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+    # print("\ncode now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem))
+    torch.cuda.reset_peak_memory_stats()
     now_mem = torch.cuda.memory_allocated() / 1024**2
     with torch.no_grad():
-        node0 = node.clone()
-        pair0 = pair.clone()
-        node1, pair1 = gm(node0, pair0)
+        node1 = node.clone()
+        pair1 = pair.clone()
+        gm(node1, pair1)
     new_now_mem = torch.cuda.memory_allocated() / 1024**2
     new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem))
+    print("gm now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem))
 
     # test forward
     with torch.no_grad():
         non_fx_out = model(node, pair)
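
The change adds a memory measurement around the traced ColoGraphModule: it resets CUDA's peak-allocation counter, records the baseline allocation, clones the inputs inside torch.no_grad() so the measured run cannot mutate the originals, runs gm, and prints the current and peak allocation deltas in MB. That pattern, pulled out into a standalone helper, might look like the sketch below (a minimal sketch; measure_peak_memory and its argument handling are hypothetical names for illustration, not part of this PR, and the inputs are assumed to be CUDA tensors):

    import torch

    def measure_peak_memory(fn, *tensors):
        # Hypothetical helper mirroring the pattern in this commit.
        # Reset the peak counter so max_memory_allocated() reflects
        # only this call, then record the current baseline in MB.
        torch.cuda.reset_peak_memory_stats()
        base = torch.cuda.memory_allocated() / 1024**2
        with torch.no_grad():
            # Clone the inputs so the measured run cannot mutate the
            # caller's tensors (the diff does the same with node/pair).
            out = fn(*(t.clone() for t in tensors))
        now = torch.cuda.memory_allocated() / 1024**2
        peak = torch.cuda.max_memory_allocated() / 1024**2
        print("now:%.2f max:%.2f" % (now - base, peak - base))
        return out

On a CUDA machine one could then call measure_peak_memory(gm, node, pair) and compare it against measure_peak_memory(model, node, pair), which is roughly what this test does to contrast the chunked and original modules.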