mirror of https://github.com/hpcaitech/ColossalAI
[fix] remove duplicate arg; rm comments;
parent
c5503b0d80
commit
bb0390c90d
|
@ -34,7 +34,6 @@ class MlpModel(nn.Module):
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
model=None,
|
|
||||||
data: torch.Tensor = None,
|
data: torch.Tensor = None,
|
||||||
hidden_states: torch.Tensor = None,
|
hidden_states: torch.Tensor = None,
|
||||||
stage_index=None,
|
stage_index=None,
|
||||||
|
@ -622,10 +621,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
||||||
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||||
|
|
||||||
model_pp._forward = model_pp.forward
|
model_pp._forward = model_pp.forward
|
||||||
# model_pp.forward = MethodType(
|
|
||||||
# partial(model_pp._forward, stage_mgr=stage_manager),
|
|
||||||
# model_pp,
|
|
||||||
# )
|
|
||||||
model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager)
|
model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager)
|
||||||
|
|
||||||
# init optimizer
|
# init optimizer
|
||||||
|
|
Loading…
Reference in New Issue