diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index bf1fba3c6..b0927c0c4 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -389,7 +389,8 @@ def test_zerobubble_pipeline_base( ########################## # fwd & bwd output_base = model_base(input_base) - loss_base = output_base.mean() + # loss_base = output_base.mean() + loss_base = criterion(output_base) loss_base.backward() print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")