avoid allreduce for dense model in pp

pull/570/head
Wenwen Qu 2024-01-04 13:35:56 +08:00
parent 5539f9db50
commit 4e6db4af0f
1 changed files with 8 additions and 3 deletions

View File

@ -445,6 +445,8 @@ class PipelineScheduler(BaseScheduler):
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
if hasattr(gpc.config.model, "num_experts"):
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
if accum_loss is not None:
@ -647,6 +649,8 @@ class PipelineScheduler(BaseScheduler):
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
if hasattr(gpc.config.model, "num_experts"):
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
if accum_loss is not None:
@ -1387,6 +1391,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
else:
output, label = (None, None)
if hasattr(gpc.config.model, "num_experts"):
dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
accum_moe_loss = self._accum_moe_loss