diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py
index ae4653d65..3a3b3c599 100644
--- a/chunk_codegen_run.py
+++ b/chunk_codegen_run.py
@@ -32,15 +32,25 @@ def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
 
 
 def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
+    # now_mem = torch.cuda.memory_allocated() / 1024**2
+    # with torch.no_grad():
+    #     node0 = node.clone()
+    #     pair0 = pair.clone()
+    #     model.graph(node0, pair0, now_mem)
+    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
+    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+    # print("\ncode now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem))
+
+    torch.cuda.reset_peak_memory_stats()
     now_mem = torch.cuda.memory_allocated() / 1024**2
     with torch.no_grad():
-        node0 = node.clone()
-        pair0 = pair.clone()
-        node1, pair1 = gm(node0, pair0)
+        node1 = node.clone()
+        pair1 = pair.clone()
+        gm(node1, pair1)
     new_now_mem = torch.cuda.memory_allocated() / 1024**2
     new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem))
-
+    print("gm now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem))
+
     # test forward
     with torch.no_grad():
         non_fx_out = model(node, pair)
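
For context (not part of the patch): the change above measures the steady-state and peak CUDA memory consumed by one forward pass of the traced `ColoGraphModule`. It calls `torch.cuda.reset_peak_memory_stats()` first so that `max_memory_allocated()` reflects only this call rather than the whole process lifetime, and it clones `node` and `pair` so the originals stay intact for the subsequent `model(node, pair)` comparison. Below is a minimal standalone sketch of the same measurement idiom; the `measure_cuda_mem` helper and the `Linear` model with its tensor shapes are illustrative assumptions, not taken from the patch:

```python
import torch

def measure_cuda_mem(module: torch.nn.Module, *inputs: torch.Tensor) -> None:
    """Print the allocated and peak CUDA memory (in MB) attributable to one forward pass."""
    # Clear the peak counter so max_memory_allocated() covers only this call.
    torch.cuda.reset_peak_memory_stats()
    base_mem = torch.cuda.memory_allocated() / 1024**2
    with torch.no_grad():
        module(*inputs)
    now_mem = torch.cuda.memory_allocated() / 1024**2
    max_mem = torch.cuda.max_memory_allocated() / 1024**2
    print("now:%.2f max:%.2f" % (now_mem - base_mem, max_mem - base_mem))

# Illustrative usage (hypothetical model and shapes):
if torch.cuda.is_available():
    model = torch.nn.Linear(256, 256).cuda()
    x = torch.randn(64, 256, device="cuda")
    measure_cuda_mem(model, x)
```

The reported "now" delta captures memory still held after the pass (e.g. retained outputs), while the "max" delta captures the transient peak during execution, which is the quantity the chunked codegen in this PR is presumably trying to reduce.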