mirror of https://github.com/InternLM/InternLM
avoid allreduce when num_expert=1
parent 4e6db4af0f
commit 184b5bff39
@@ -122,7 +122,7 @@ class NonPipelineScheduler(BaseScheduler):
         self._call_hooks("after_criterion", loss)
         moe_loss = (
             sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-            if hasattr(gpc.config.model, "num_experts")
+            if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
         moe_loss /= scale_loss
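This hunk (and the matching ones below) converge on the same guarded expression: the auxiliary MoE loss is only computed when the model actually routes across more than one expert, and a zero tensor stands in otherwise. A minimal, self-contained sketch of that pattern follows; `compute_moe_loss` and its arguments are illustrative stand-ins rather than names from the repository, and the CUDA device/dtype arguments are dropped so it runs on CPU.

import torch


def compute_moe_loss(moe_losses, model_cfg, moe_loss_coeff, scale_loss=1.0):
    # Accumulate the MoE auxiliary loss only when routing actually happens
    # (more than one expert); a single-expert model has no routing loss, so
    # a zero tensor stands in and callers still receive a tensor to log or
    # add to the main loss.
    if getattr(model_cfg, "num_experts", 1) > 1:
        moe_loss = sum(moe_losses) * moe_loss_coeff
    else:
        moe_loss = torch.tensor(0.0)
    return moe_loss / scale_loss

For example, calling it with a config whose num_experts is 1 returns a plain zero tensor without touching the per-layer losses, which mirrors the behaviour the guard introduces.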
@@ -308,7 +308,7 @@ class PipelineScheduler(BaseScheduler):
         moe_loss = (
             sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-            if hasattr(gpc.config.model, "num_experts")
+            if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
         moe_loss /= self.num_microbatches
@@ -446,7 +446,7 @@ class PipelineScheduler(BaseScheduler):
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)

-        if hasattr(gpc.config.model, "num_experts"):
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
             dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))

         if accum_loss is not None:
@@ -650,7 +650,7 @@ class PipelineScheduler(BaseScheduler):
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)

-        if hasattr(gpc.config.model, "num_experts"):
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
             dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))

         if accum_loss is not None:
@@ -859,7 +859,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         moe_loss = (
             sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-            if hasattr(gpc.config.model, "num_experts")
+            if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
         moe_loss /= self.num_microbatches
@@ -1391,7 +1391,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         else:
             output, label = (None, None)

-        if hasattr(gpc.config.model, "num_experts"):
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
             dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
         accum_moe_loss = self._accum_moe_loss
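The remaining hunks apply the same condition to the cross-pipeline reduction of the accumulated MoE loss, which is the "avoid allreduce" part of the commit: with a single expert the loss is identically zero on every rank, so the all_reduce over the pipeline group is pure communication overhead. A rough sketch of that control flow, using placeholder names instead of the repository's gpc/ParallelMode plumbing:

import torch
import torch.distributed as dist


def reduce_moe_loss(accum_moe_loss: torch.Tensor, num_experts: int, group=None) -> torch.Tensor:
    # With num_experts <= 1 the accumulated MoE loss is zero on every rank,
    # so the cross-pipeline all-reduce would only add communication cost;
    # skip it and return the local (zero) value unchanged.
    if num_experts > 1 and dist.is_initialized():
        dist.all_reduce(accum_moe_loss, group=group)
    return accum_moe_loss

The guard keeps the reduction semantics unchanged for real MoE runs while letting dense (num_experts=1) configurations skip one collective per step.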