diff --git a/tests/test_autochunk/test_autochunk_codegen.py b/tests/test_autochunk/test_autochunk_codegen.py index 1c5dd939d..8246275eb 100644 --- a/tests/test_autochunk/test_autochunk_codegen.py +++ b/tests/test_autochunk/test_autochunk_codegen.py @@ -15,16 +15,19 @@ from tests.test_autochunk.evoformer.evoformer import evoformer_base def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): - torch.cuda.reset_peak_memory_stats() - now_mem = torch.cuda.memory_allocated() / 1024**2 - with torch.no_grad(): - gm(node.clone(), pair.clone()) - new_now_mem = torch.cuda.memory_allocated() / 1024**2 - new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 - print( - "autochunk now mem:%.2f max mem:%.2f" - % (new_now_mem - now_mem, new_max_mem - now_mem) - ) + # for memory test + # torch.cuda.reset_peak_memory_stats() + # now_mem = torch.cuda.memory_allocated() / 1024**2 + # with torch.no_grad(): + # node1 = node.clone() + # pair1 = pair.clone() + # gm(node1, pair1) + # new_now_mem = torch.cuda.memory_allocated() / 1024**2 + # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 + # print( + # "autochunk now mem:%.2f max mem:%.2f" + # % (new_now_mem - now_mem, new_max_mem - now_mem) + # ) # test forward with torch.no_grad(): @@ -43,7 +46,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): ) -def _run_offload_codegen(rank): +def _test_autochunk_codegen(rank): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly colossalai.launch( config={}, @@ -56,8 +59,10 @@ def _run_offload_codegen(rank): # build model and input model = evoformer_base().cuda() - node = torch.randn(1, 100, 300, 256).cuda() - pair = torch.randn(1, 300, 300, 128).cuda() + msa_len = 32 + pair_len = 64 + node = torch.randn(1, msa_len, pair_len, 256).cuda() + pair = torch.randn(1, pair_len, pair_len, 128).cuda() # trace the module and replace codegen graph = ColoTracer().trace( @@ -85,17 +90,18 @@ def _run_offload_codegen(rank): gm = ColoGraphModule(model, graph) gm.recompile() - # assert we have all the components - # code = graph.python_code("self").src + # assert we have inserted chunk + code = graph.python_code("self").src + assert "chunk_size" in code # print(code) _test_fwd(model, gm, node, pair) gpc.destroy() -def test_autochunk(): - mp.spawn(_run_offload_codegen, nprocs=1) +def test_autochunk_codegen(): + mp.spawn(_test_autochunk_codegen, nprocs=1) if __name__ == "__main__": - _run_offload_codegen(0) + _test_autochunk_codegen(0)