avoid allreduce when num_experts=1

pull/570/head
Wenwen Qu 2024-01-04 15:10:21 +08:00
parent 4e6db4af0f
commit 184b5bff39
2 changed files with 6 additions and 6 deletions

@@ -122,7 +122,7 @@ class NonPipelineScheduler(BaseScheduler):
         self._call_hooks("after_criterion", loss)
         moe_loss = (
             sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-            if hasattr(gpc.config.model, "num_experts")
+            if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
         moe_loss /= scale_loss

@@ -308,7 +308,7 @@ class PipelineScheduler(BaseScheduler):
         moe_loss = (
             sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-            if hasattr(gpc.config.model, "num_experts")
+            if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
         moe_loss /= self.num_microbatches
@@ -446,7 +446,7 @@ class PipelineScheduler(BaseScheduler):
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
-        if hasattr(gpc.config.model, "num_experts"):
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
             dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
         if accum_loss is not None:
@@ -650,7 +650,7 @@ class PipelineScheduler(BaseScheduler):
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
-        if hasattr(gpc.config.model, "num_experts"):
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
             dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
         if accum_loss is not None:
@@ -859,7 +859,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         moe_loss = (
             sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-            if hasattr(gpc.config.model, "num_experts")
+            if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
         moe_loss /= self.num_microbatches
@@ -1391,7 +1391,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         else:
             output, label = (None, None)
-        if hasattr(gpc.config.model, "num_experts"):
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
             dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
         accum_moe_loss = self._accum_moe_loss
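
For illustration only, a minimal standalone sketch of the guard this commit applies in the schedulers. The names ModelCfg, compute_moe_loss, and maybe_allreduce_moe_loss are hypothetical stand-ins for the gpc config and scheduler internals, and the hasattr-plus-comparison check is condensed into getattr with a default; the real code also places the zero tensor on the current CUDA device.

import torch
import torch.distributed as dist

def compute_moe_loss(model_cfg, moe_losses, moe_loss_coeff, dtype=torch.float32):
    # Scale the summed per-layer MoE auxiliary losses only when the model
    # actually routes across more than one expert; otherwise return a zero
    # tensor so downstream accumulation still works for dense models.
    if getattr(model_cfg, "num_experts", 1) > 1:
        return sum(moe_losses) * moe_loss_coeff
    return torch.tensor(0.0, dtype=dtype)

def maybe_allreduce_moe_loss(model_cfg, accum_moe_loss, group=None):
    # Skip the pipeline-group all-reduce for single-expert models: the
    # accumulated MoE loss is identically zero there, so the collective
    # would only add communication overhead.
    if getattr(model_cfg, "num_experts", 1) > 1:
        dist.all_reduce(accum_moe_loss, group=group)
    return accum_moe_loss

# Usage: a dense config (num_experts == 1) neither scales MoE losses nor
# triggers the all-reduce, so no process group is required here.
class ModelCfg:
    num_experts = 1

loss = compute_moe_loss(ModelCfg(), [torch.tensor(0.1), torch.tensor(0.2)], moe_loss_coeff=0.01)
print(loss)  # tensor(0.)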