From bb0390c90d8645b2d58035e82335049c468d36ec Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 26 Sep 2024 09:45:44 +0000 Subject: [PATCH] [fix] remove duplicate arg; rm comments; --- tests/test_pipeline/test_schedule/test_zerobubble_pp.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 46bd4a581..0f2d6c49c 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -34,7 +34,6 @@ class MlpModel(nn.Module): def forward( self, - model=None, data: torch.Tensor = None, hidden_states: torch.Tensor = None, stage_index=None, @@ -622,10 +621,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config): stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) model_pp._forward = model_pp.forward - # model_pp.forward = MethodType( - # partial(model_pp._forward, stage_mgr=stage_manager), - # model_pp, - # ) + model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager) # init optimizer