refactor code

pull/182/head
zhanglei 2023-08-22 10:53:21 +08:00
parent ac243e5b33
commit 8407c203a3
1 changed file with 16 additions and 6 deletions

View File

@@ -240,8 +240,16 @@ class PipelineScheduler(BaseScheduler):
"""
return step_id
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True,
accum_loss=None, accum_moe_loss=None, moe_loss_coeff=1.0):
def _forward_step(
self,
engine,
input_obj,
return_tensors,
return_output_label=True,
accum_loss=None,
accum_moe_loss=None,
moe_loss_coeff=1.0,
):
"""
Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used.
@@ -620,8 +628,9 @@ class PipelineScheduler(BaseScheduler):
return output, label, accum_loss, accum_moe_loss
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True,
return_output_label=True, moe_loss_coeff=1.0):
def forward_backward_step(
self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0
):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.
@@ -1286,8 +1295,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
# 3. Cooldown
self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True,
return_output_label=True, moe_loss_coeff=1.0):
def forward_backward_step(
self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0
):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.