From 4e6db4af0f0aad6ea2039c61aeaf7af2399ed25e Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Thu, 4 Jan 2024 13:35:56 +0800
Subject: [PATCH] avoid allreduce for dense model in pp

---
 internlm/core/scheduler/pipeline_scheduler.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index 5b864ff..0398783 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -445,7 +445,9 @@ class PipelineScheduler(BaseScheduler):
                 comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)

         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
-        dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+
+        if hasattr(gpc.config.model, "num_experts"):
+            dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))

         if accum_loss is not None:
             accum_loss += accum_moe_loss
@@ -647,7 +649,9 @@ class PipelineScheduler(BaseScheduler):
                 comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)

         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
-        dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+
+        if hasattr(gpc.config.model, "num_experts"):
+            dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))

         if accum_loss is not None:
             accum_loss += accum_moe_loss
@@ -1387,7 +1391,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         else:
             output, label = (None, None)

-        dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+        if hasattr(gpc.config.model, "num_experts"):
+            dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))

         accum_moe_loss = self._accum_moe_loss
         accum_loss = self._accum_loss
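
Note (not part of the patch): the change above gates the pipeline-parallel all-reduce of the accumulated MoE loss on whether the model config defines num_experts, so dense models skip one collective per step. Below is a minimal, self-contained sketch of that pattern; it is not the InternLM scheduler, and reduce_moe_loss_if_needed plus the SimpleNamespace configs are hypothetical illustration-only stand-ins for gpc.config.model and the real pipeline process group.

    # Sketch of the "skip the MoE-loss all-reduce for dense models" pattern.
    import os
    from types import SimpleNamespace

    import torch
    import torch.distributed as dist


    def reduce_moe_loss_if_needed(model_cfg, accum_moe_loss, group):
        # Only MoE configs (those that define `num_experts`) need the pipeline-wide
        # sum of the auxiliary MoE loss; dense models skip the collective entirely.
        if hasattr(model_cfg, "num_experts"):
            dist.all_reduce(accum_moe_loss, group=group)
        return accum_moe_loss


    if __name__ == "__main__":
        # Single-process gloo group so the example runs without a distributed launcher.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)

        dense_cfg = SimpleNamespace(hidden_size=4096)               # no num_experts -> no all-reduce
        moe_cfg = SimpleNamespace(hidden_size=4096, num_experts=8)  # MoE -> all-reduce runs

        loss = torch.tensor(0.25)
        print(reduce_moe_loss_if_needed(dense_cfg, loss.clone(), dist.group.WORLD))
        print(reduce_moe_loss_if_needed(moe_cfg, loss.clone(), dist.group.WORLD))

        dist.destroy_process_group()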