From a8dd77ce76a5be5e12c1106d488f7321e9986e19 Mon Sep 17 00:00:00 2001
From: zhanglei
Date: Tue, 22 Aug 2023 10:35:17 +0800
Subject: [PATCH] fix bug on logger

---
 internlm/core/scheduler/pipeline_scheduler.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index 1b749e7..8728c01 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -329,7 +329,6 @@ class PipelineScheduler(BaseScheduler):
         else:
             engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])
 
-
         # Collect the grad of the input_obj.
         input_obj_grad = None
         if input_obj is not None:
@@ -1312,11 +1311,11 @@ class InterleavedPipelineScheduler(PipelineScheduler):
 
         if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
             self._accum_loss = torch.zeros(1, device=get_current_device())
-        if return_loss:
-            self._accum_moe_loss = torch.zeros(1, device=get_current_device())
+        self._accum_moe_loss = torch.zeros(1, device=get_current_device())
+
         if return_output_label:
             self._return_tensors = []
-
+
         if forward_only:
             self._forward_only_step(engine, moe_loss_coeff)
         else:
@@ -1327,14 +1326,15 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         else:
             output, label = (None, None)
 
-        accum_loss = self._accum_loss
-        accum_loss += self._accum_moe_loss
-
         logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {self._accum_moe_loss.item()}")
 
         dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
         accum_moe_loss = self._accum_moe_loss
 
+        accum_loss = self._accum_loss
+        if return_loss:
+            accum_loss += self._accum_moe_loss
+
         self._clear_state()
 
         return output, label, accum_loss, accum_moe_loss
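
Context for the fix, inferred from the hunks above: before this patch,
self._accum_moe_loss was allocated only when return_loss was set, yet the
logger call and "accum_loss += self._accum_moe_loss" ran unconditionally, so a
forward-only pass dereferenced None; the moe loss was also folded into
accum_loss *before* the dist.all_reduce over the pipeline group, so accum_loss
saw only the local stage's moe loss. Below is a minimal standalone sketch of
that before/after control flow in plain PyTorch; finish_step_old and
finish_step_new are hypothetical stand-ins (not InternLM APIs), and summing a
list of per-stage tensors simulates dist.all_reduce.

import torch

def finish_step_old(return_loss, stage_moe_losses):
    # Old: both accumulators exist only when return_loss is True.
    accum_loss = torch.zeros(1) if return_loss else None
    accum_moe_loss = stage_moe_losses[0].clone() if return_loss else None

    # Old order: moe loss added to accum_loss *before* the reduction, and
    # unconditionally -- "None += None" raises TypeError on forward-only runs.
    accum_loss += accum_moe_loss                  # sees only the local moe loss
    print(f"moe_loss: {accum_moe_loss.item()}")   # .item() on None would also raise
    accum_moe_loss += sum(stage_moe_losses[1:])   # stand-in for dist.all_reduce
    return accum_loss, accum_moe_loss

def finish_step_new(return_loss, stage_moe_losses):
    # New: the moe-loss accumulator always exists, so logging and the
    # pipeline-group reduction are safe even when return_loss is False.
    accum_loss = torch.zeros(1) if return_loss else None
    accum_moe_loss = stage_moe_losses[0].clone()  # always allocated now

    accum_moe_loss += sum(stage_moe_losses[1:])   # stand-in for dist.all_reduce
    if return_loss:                               # addition now guarded, and
        accum_loss += accum_moe_loss              # applied after the reduction
    return accum_loss, accum_moe_loss

stage_moe_losses = [torch.tensor([0.1]), torch.tensor([0.2])]  # one per stage
print(finish_step_new(True, stage_moe_losses))    # loss now includes every stage's moe loss
print(finish_step_new(False, stage_moe_losses))   # forward-only: (None, reduced moe loss)
try:
    finish_step_old(False, stage_moe_losses)
except TypeError as exc:
    print(f"old path crashes forward-only: {exc}")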