fix(pipeline): avoid allreduce for dense model (#570)

* avoid allreduce for dense model in pp

* avoid allreduce when num_experts=1
Wenwen Qu 2024-01-09 10:34:22 +08:00 committed by GitHub
parent 5539f9db50
commit 91480c5b63
2 changed files with 11 additions and 6 deletions
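
Note: every touched site uses the same predicate. A minimal sketch of the pattern, assuming an initialized torch.distributed process group; reduce_moe_loss, model_cfg and pipeline_group are hypothetical stand-ins for the schedulers' gpc.config.model and gpc.get_group(ParallelMode.PIPELINE):

    import torch
    import torch.distributed as dist

    def reduce_moe_loss(accum_moe_loss: torch.Tensor, model_cfg, pipeline_group) -> torch.Tensor:
        # Only a MoE model (more than one expert) carries a real auxiliary loss that
        # must be summed across pipeline stages; for a dense model the tensor stays a
        # local zero, so the collective and its synchronization point are skipped.
        if hasattr(model_cfg, "num_experts") and model_cfg.num_experts > 1:
            dist.all_reduce(accum_moe_loss, group=pipeline_group)
        return accum_moe_loss

The same predicate also gates the moe_loss computation, so a one-expert (effectively dense) configuration now gets the same zero tensor as a model without the num_experts field.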


@@ -122,7 +122,7 @@ class NonPipelineScheduler(BaseScheduler):
        self._call_hooks("after_criterion", loss)
        moe_loss = (
            sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-           if hasattr(gpc.config.model, "num_experts")
+           if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
            else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
        )
        moe_loss /= scale_loss


@@ -308,7 +308,7 @@ class PipelineScheduler(BaseScheduler):
        moe_loss = (
            sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-           if hasattr(gpc.config.model, "num_experts")
+           if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
            else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
        )
        moe_loss /= self.num_microbatches

@@ -445,7 +445,9 @@ class PipelineScheduler(BaseScheduler):
            comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)

        output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
-       dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+
+       if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
+           dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))

        if accum_loss is not None:
            accum_loss += accum_moe_loss

@@ -647,7 +649,9 @@ class PipelineScheduler(BaseScheduler):
            comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)

        output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
-       dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+
+       if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
+           dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))

        if accum_loss is not None:
            accum_loss += accum_moe_loss

@@ -855,7 +859,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
        moe_loss = (
            sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-           if hasattr(gpc.config.model, "num_experts")
+           if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
            else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
        )
        moe_loss /= self.num_microbatches

@@ -1387,7 +1391,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
        else:
            output, label = (None, None)

-       dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+       if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
+           dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))

        accum_moe_loss = self._accum_moe_loss
        accum_loss = self._accum_loss
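
For reference, a small sketch (hypothetical SimpleNamespace configs, no process group needed) of which configurations trigger the collective under the new predicate:

    from types import SimpleNamespace

    # Hypothetical stand-ins for gpc.config.model in three configurations.
    dense = SimpleNamespace()                    # dense model, no num_experts field
    one_expert = SimpleNamespace(num_experts=1)  # single-expert config, treated as dense
    moe = SimpleNamespace(num_experts=8)         # genuine MoE model

    for cfg in (dense, one_expert, moe):
        runs_allreduce = hasattr(cfg, "num_experts") and cfg.num_experts > 1
        print(getattr(cfg, "num_experts", None), runs_allreduce)
    # Prints: None False / 1 False / 8 True -- only the true MoE config pays for the all-reduce.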