diff --git a/configs/moe_cfg.py b/configs/moe_cfg.py
index 6fc41f6..44d0cb6 100644
--- a/configs/moe_cfg.py
+++ b/configs/moe_cfg.py
@@ -75,7 +75,7 @@ grad_scaler = dict(
 hybrid_zero_optimizer = dict(
     # Enable low_level_optimzer overlap_communication
-    zero_overlap_communication=False,
+    zero_overlap_communication=True,
     # bucket size for nccl communication params
     reduce_bucket_size=512 * 1024 * 1024,
     # grad clipping
diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index 5d7dc04..7ba14fc 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -272,11 +272,12 @@ class PipelineScheduler(BaseScheduler):
                 self._call_hooks("after_criterion", loss)
                 loss_reduced = loss / self.num_microbatches
-                accum_loss.add_(loss_reduced)
+                accum_loss.add_(loss_reduced.detach())
                 output_obj = loss_reduced
 
         moe_loss = sum(moe_losses) * moe_loss_coeff
         moe_loss /= self.num_microbatches
+
         return output_obj, moe_loss
 
     def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss=None):
@@ -313,13 +314,17 @@ class PipelineScheduler(BaseScheduler):
         self._call_hooks("before_backward", output_obj, output_obj_grad)
         with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
-            if moe_loss is not None:
-                moe_loss.backward(retain_graph=True)
-
-            if output_obj_grad is None:
-                engine.backward(output_obj)
+            if moe_loss is None:
+                if output_obj_grad is None:
+                    engine.backward(output_obj)
+                else:
+                    engine.backward_by_grad(output_obj, output_obj_grad)
             else:
-                engine.backward_by_grad(output_obj, output_obj_grad)
+                if output_obj_grad is None:
+                    engine.backward(output_obj + moe_loss)
+                else:
+                    engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])
+
         # Collect the grad of the input_obj.
         input_obj_grad = None
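
The `_backward_step` change replaces the separate `moe_loss.backward(retain_graph=True)` call with a single backward over the combined loss on the last pipeline stage (and a joint `backward_by_grad` elsewhere), so the graph is traversed only once. A minimal sketch of that idea in plain PyTorch, with hypothetical tensor names standing in for the engine's `output_obj` and `moe_loss` (this is not the engine API itself):

```python
import torch

# Hypothetical stand-ins: `main_loss` plays the role of output_obj on the
# last stage, `aux_loss` plays the role of the MoE balancing loss.
w = torch.randn(4, 4, requires_grad=True)
x = torch.randn(8, 4)

logits = x @ w
main_loss = logits.pow(2).mean()        # stand-in for the criterion output
aux_loss = 0.01 * logits.abs().mean()   # stand-in for the MoE auxiliary loss

# Backward once through the combined scalar, as the new last-stage branch
# does with engine.backward(output_obj + moe_loss); no retain_graph needed,
# and gradients from both terms accumulate in a single pass.
(main_loss + aux_loss).backward()

# Logging-side accumulation detaches first, mirroring
# accum_loss.add_(loss_reduced.detach()) in the diff, so the running total
# never keeps the autograd graph alive.
accum_loss = torch.zeros(1)
accum_loss.add_(main_loss.detach())
```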