refactor code

pull/182/head
zhanglei 2023-08-22 10:53:21 +08:00
parent ac243e5b33
commit 8407c203a3
1 changed file with 16 additions and 6 deletions

View File

@@ -240,8 +240,16 @@ class PipelineScheduler(BaseScheduler):
""" """
return step_id return step_id
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, def _forward_step(
accum_loss=None, accum_moe_loss=None, moe_loss_coeff=1.0): self,
engine,
input_obj,
return_tensors,
return_output_label=True,
accum_loss=None,
accum_moe_loss=None,
moe_loss_coeff=1.0,
):
""" """
Forward step for passed-in model. If it is the first stage, the input tensor Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used. is obtained from data_iterator, otherwise the passed-in input_obj is used.
@@ -620,8 +628,9 @@ class PipelineScheduler(BaseScheduler):
return output, label, accum_loss, accum_moe_loss return output, label, accum_loss, accum_moe_loss
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, def forward_backward_step(
return_output_label=True, moe_loss_coeff=1.0): self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0
):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages. """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise. Returns a tuple with losses if the last stage, an empty tuple otherwise.
@@ -1286,8 +1295,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
# 3. Cooldown # 3. Cooldown
self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs) self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, def forward_backward_step(
return_output_label=True, moe_loss_coeff=1.0): self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0
):
"""Run interleaved 1F1B schedule (model split into model chunks), with """Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed. communication between pipeline stages as needed.