From 35a7b636b3d6252ef0bfc8160fcd69c2d1ddea27 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 9 Sep 2024 05:41:39 +0000 Subject: [PATCH] [fix] fix mem assertation --- tests/kit/model_zoo/transformers/__init__.py | 3 +-- .../test_schedule/test_zerobubble_pp.py | 18 +++++++++++------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 029968231..4adc38619 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -2,8 +2,7 @@ from .albert import * from .bert import * from .blip2 import * from .bloom import * - -# from .chatglm2 import * +from .chatglm2 import * from .command import * from .deepseek import * from .falcon import * diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 9348e4deb..f3093fef0 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -611,20 +611,24 @@ def run_fwd_bwd_vschedule_with_optim(test_config): optimizer_pp.step() - torch.cuda.memory_allocated() / 1024**3 + after_pp_step_memory = torch.cuda.memory_allocated() / 1024**3 # assert memory if rank != 0: # w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3 # output hid_dim * hid_dim * 4(fp32) / 1024**3 - # assert (after_pp_step_memory - after_init_memory) == (in_dim * in_dim * 4 * 3 / 1024**3) - pass + print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} == {(in_dim * in_dim * 4 * 3 / 1024**3)}") + assert (after_pp_step_memory - after_init_memory) == (in_dim * in_dim * 4 * 3 / 1024**3) + # pass else: # rank0 will also hold output; - # assert round((after_pp_step_memory - after_init_memory), 5) == round( - # (in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5 - # ) - pass + print( + f"rank {rank}: {(after_pp_step_memory - after_init_memory)} == {(in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3)}" + ) + assert round((after_pp_step_memory - after_init_memory), 5) == round( + (in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5 + ) + # pass ########################## # Fwd bwd for base ##########################