[fix] remove duplicate arg; rm comments;

pull/6069/head
duanjunwen 2024-09-26 09:45:44 +00:00
parent c5503b0d80
commit bb0390c90d
1 changed files with 1 additions and 5 deletions

View File

@ -34,7 +34,6 @@ class MlpModel(nn.Module):
def forward(
self,
model=None,
data: torch.Tensor = None,
hidden_states: torch.Tensor = None,
stage_index=None,
@ -622,10 +621,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
model_pp._forward = model_pp.forward
# model_pp.forward = MethodType(
# partial(model_pp._forward, stage_mgr=stage_manager),
# model_pp,
# )
model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager)
# init optimizer