Browse Source

[fix] remove duplicate arg; rm comments;

pull/6069/head
duanjunwen 2 months ago
parent
commit
bb0390c90d
  1. 6
      tests/test_pipeline/test_schedule/test_zerobubble_pp.py

6
tests/test_pipeline/test_schedule/test_zerobubble_pp.py

@ -34,7 +34,6 @@ class MlpModel(nn.Module):
def forward(
self,
model=None,
data: torch.Tensor = None,
hidden_states: torch.Tensor = None,
stage_index=None,
@ -622,10 +621,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
model_pp._forward = model_pp.forward
# model_pp.forward = MethodType(
# partial(model_pp._forward, stage_mgr=stage_manager),
# model_pp,
# )
model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager)
# init optimizer

Loading…
Cancel
Save