mirror of https://github.com/InternLM/InternLM
avoid allreduce for dense model in pp
parent 5539f9db50
commit 4e6db4af0f
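In short: the accumulated MoE auxiliary loss is all-reduced across the pipeline parallel group only when the model config actually declares experts; for dense models the collective is skipped entirely, since their MoE loss is always zero. A minimal sketch of the pattern (the function and argument names here are illustrative, not taken from the repository):

import torch
import torch.distributed as dist


def reduce_moe_loss_if_needed(model_config, accum_moe_loss, pipeline_group):
    # Dense models carry no auxiliary expert-balancing loss, so the
    # all-reduce would only sum zeros across pipeline stages; skip it.
    if hasattr(model_config, "num_experts"):
        dist.all_reduce(accum_moe_loss, group=pipeline_group)
    return accum_moe_loss

The diff hunks below apply this guard in PipelineScheduler and InterleavedPipelineScheduler.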
@@ -445,7 +445,9 @@ class PipelineScheduler(BaseScheduler):
             comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
 
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
 
-        dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+        if hasattr(gpc.config.model, "num_experts"):
+            dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
 
         if accum_loss is not None:
             accum_loss += accum_moe_loss
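The same guard is repeated in the path below, which also sends backward gradients. It keys purely off the presence of a num_experts field on the model config; a small illustration of that check with hypothetical config objects (not InternLM's real config classes):

from types import SimpleNamespace

dense_cfg = SimpleNamespace(hidden_size=4096)               # dense model: no experts configured
moe_cfg = SimpleNamespace(hidden_size=4096, num_experts=8)  # MoE model

assert not hasattr(dense_cfg, "num_experts")  # guard is False, all-reduce skipped
assert hasattr(moe_cfg, "num_experts")        # guard is True, all-reduce runs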
@@ -647,7 +649,9 @@ class PipelineScheduler(BaseScheduler):
             comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
 
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
 
-        dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+        if hasattr(gpc.config.model, "num_experts"):
+            dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
 
         if accum_loss is not None:
             accum_loss += accum_moe_loss
@@ -1387,7 +1391,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         else:
             output, label = (None, None)
 
-        dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+        if hasattr(gpc.config.model, "num_experts"):
+            dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
         accum_moe_loss = self._accum_moe_loss
 
         accum_loss = self._accum_loss
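For a dense model, skipping the collective changes nothing numerically, because the accumulated MoE loss it would have reduced stays zero. A tiny sanity sketch of that assumed behavior:

import torch

accum_loss = torch.tensor(2.5)
accum_moe_loss = torch.zeros_like(accum_loss)  # dense model: no expert aux loss accumulates

accum_loss += accum_moe_loss  # folding in the MoE loss is a no-op
assert accum_loss.item() == 2.5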