fix bug on logger

pull/182/head
zhanglei 2023-08-22 10:35:17 +08:00
parent 05a3b2a3be
commit a8dd77ce76
1 changed file with 7 additions and 7 deletions

@@ -329,7 +329,6 @@ class PipelineScheduler(BaseScheduler):
         else:
             engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])
 
         # Collect the grad of the input_obj.
         input_obj_grad = None
         if input_obj is not None:
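Note: engine.backward_by_grad is presumably a thin wrapper over torch.autograd.backward; the None entry means the scalar moe_loss needs no incoming gradient, while output_obj receives the gradient sent back by the next pipeline stage. A minimal, self-contained sketch of that call pattern (tensor names are illustrative, not from this repo):

import torch

# Illustrative stand-ins for output_obj / moe_loss / output_obj_grad.
x = torch.randn(4, requires_grad=True)
output_obj = x * 2                              # activation sent downstream
moe_loss = (x ** 2).sum()                       # scalar auxiliary loss
output_obj_grad = torch.ones_like(output_obj)   # grad received from the next stage

# Backward through both in one call; the scalar loss takes None as its
# grad tensor, mirroring backward_by_grad([output_obj, moe_loss],
# [output_obj_grad, None]).
torch.autograd.backward([output_obj, moe_loss], [output_obj_grad, None])
print(x.grad)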
@@ -1312,11 +1311,11 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
             self._accum_loss = torch.zeros(1, device=get_current_device())
-        if return_loss:
-            self._accum_moe_loss = torch.zeros(1, device=get_current_device())
+        self._accum_moe_loss = torch.zeros(1, device=get_current_device())
 
         if return_output_label:
             self._return_tensors = []
 
         if forward_only:
             self._forward_only_step(engine, moe_loss_coeff)
         else:
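Note: the hunk above drops the return_loss guard so that self._accum_moe_loss exists on every pipeline rank. The all-reduce later in this method is a collective call, so a rank that skipped the initialization would crash or desync the group. A toy mirror of the intended init pattern (the class and device helper below are illustrative, not the real internlm code):

import torch

def get_current_device() -> torch.device:
    # Hypothetical stand-in for internlm's get_current_device().
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SchedulerState:
    """Toy mirror of the scheduler's loss accumulators (not the real class)."""

    def __init__(self) -> None:
        self._accum_loss = None
        self._accum_moe_loss = None

    def reset(self, return_loss: bool, is_last_stage: bool) -> None:
        # The main loss only materializes where it is actually computed ...
        if return_loss and is_last_stage:
            self._accum_loss = torch.zeros(1, device=get_current_device())
        # ... but the MoE accumulator must exist on *every* rank, because
        # the later all-reduce over the pipeline group is collective.
        self._accum_moe_loss = torch.zeros(1, device=get_current_device())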
@@ -1327,14 +1326,15 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         else:
             output, label = (None, None)
 
-        accum_loss = self._accum_loss
-        accum_loss += self._accum_moe_loss
-        logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {self._accum_moe_loss.item()}")
         dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
         accum_moe_loss = self._accum_moe_loss
 
+        accum_loss = self._accum_loss
+        if return_loss:
+            accum_loss += self._accum_moe_loss
+
         self._clear_state()
 
         return output, label, accum_loss, accum_moe_loss
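Note: this hunk removes the per-rank logger.info (which printed one moe_loss line per pipeline rank, using the not-yet-reduced local value) and reorders the epilogue: all-reduce the MoE loss across the pipeline group first, then fold it into the reported loss, and only when return_loss is set. A runnable single-process sketch of that reduce-then-fold order (the process-group setup is illustrative):

import os
import torch
import torch.distributed as dist

def fold_moe_loss(accum_loss, accum_moe_loss, return_loss, group=None):
    # 1) Sum the per-stage MoE auxiliary loss across the pipeline group
    #    first, so every rank ends up with the same global value.
    dist.all_reduce(accum_moe_loss, op=dist.ReduceOp.SUM, group=group)
    # 2) Only then fold it into the reported loss, and only if a loss was
    #    requested at all (accum_loss may be None otherwise).
    if return_loss:
        accum_loss += accum_moe_loss
    return accum_loss, accum_moe_loss

if __name__ == "__main__":
    # Single-process gloo group, just enough to make the sketch runnable.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    loss, moe = fold_moe_loss(torch.zeros(1), torch.tensor([0.25]), True)
    print(loss, moe)  # with one rank the sum is unchanged: 0.25 and 0.25
    dist.destroy_process_group()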