support zero_overlap_communication

pull/182/head
zhanglei 2023-08-15 16:18:20 +08:00
parent 1accc9f08d
commit 8c7d868f01
2 changed files with 13 additions and 8 deletions


@@ -75,7 +75,7 @@ grad_scaler = dict(
 hybrid_zero_optimizer = dict(
     # Enable low_level_optimizer overlap_communication
-    zero_overlap_communication=False,
+    zero_overlap_communication=True,
     # bucket size for nccl communication params
     reduce_bucket_size=512 * 1024 * 1024,
     # grad clipping
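
For reference, with this commit the surrounding block reads roughly as follows (a sketch; the clip_grad_norm entry is assumed, since the hunk is cut off after the comment):

hybrid_zero_optimizer = dict(
    # Enable low_level_optimizer overlap_communication
    zero_overlap_communication=True,
    # bucket size for nccl communication params
    reduce_bucket_size=512 * 1024 * 1024,
    # grad clipping
    clip_grad_norm=1.0,  # assumed value, not shown in this hunk
)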


@@ -272,11 +272,12 @@ class PipelineScheduler(BaseScheduler):
             self._call_hooks("after_criterion", loss)
             loss_reduced = loss / self.num_microbatches
-            accum_loss.add_(loss_reduced)
+            accum_loss.add_(loss_reduced.detach())
             output_obj = loss_reduced

         moe_loss = sum(moe_losses) * moe_loss_coeff
         moe_loss /= self.num_microbatches
         return output_obj, moe_loss

     def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss=None):
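
Note on the .detach() above: accumulating the raw loss_reduced would route the in-place add_ through autograd and keep every microbatch's graph reachable from accum_loss. A minimal stand-alone sketch of the pattern in plain PyTorch (hypothetical tensors, not the scheduler code):

import torch

accum_loss = torch.zeros(1)
num_microbatches = 4

for _ in range(num_microbatches):
    x = torch.randn(8, requires_grad=True)
    loss = (x * 2).sum() / num_microbatches

    # Accumulating the detached value keeps accum_loss out of the autograd
    # graph; without .detach(), the in-place add_ would give accum_loss a
    # grad_fn and keep each microbatch's graph alive until accum_loss dies.
    accum_loss.add_(loss.detach())

    loss.backward()  # this microbatch's graph can be freed right away

print(accum_loss)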
@@ -313,13 +314,17 @@ class PipelineScheduler(BaseScheduler):
         self._call_hooks("before_backward", output_obj, output_obj_grad)
         with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
-            if moe_loss is not None:
-                moe_loss.backward(retain_graph=True)
-            if output_obj_grad is None:
-                engine.backward(output_obj)
+            if moe_loss is None:
+                if output_obj_grad is None:
+                    engine.backward(output_obj)
+                else:
+                    engine.backward_by_grad(output_obj, output_obj_grad)
             else:
-                engine.backward_by_grad(output_obj, output_obj_grad)
+                if output_obj_grad is None:
+                    engine.backward(output_obj + moe_loss)
+                else:
+                    engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])

         # Collect the grad of the input_obj.
         input_obj_grad = None
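
Note on the rewritten branch: the old code ran moe_loss.backward(retain_graph=True) as a separate pass; the new code produces all gradients in a single backward inside the switch_optimizer_grad_sync_skip_mode block, which is what makes it safe to turn on overlapped gradient communication (zero_overlap_communication=True). A minimal stand-alone sketch of that single-pass backward in plain torch.autograd (hypothetical tensors; illustrative only, not the engine.backward_by_grad API itself):

import torch

x = torch.randn(4, requires_grad=True)

# Stand-ins for a pipeline stage's outputs: the activation sent onward and a
# locally computed auxiliary MoE loss.
output_obj = x * 3
moe_loss = (x ** 2).sum() * 0.01

# Gradient of output_obj received back from the next pipeline stage.
output_obj_grad = torch.ones_like(output_obj)

# One backward over both roots; None for the scalar moe_loss means an implicit
# gradient of 1, mirroring backward_by_grad([output_obj, moe_loss],
# [output_obj_grad, None]) above.
torch.autograd.backward([output_obj, moe_loss], [output_obj_grad, None])

print(x.grad)  # holds contributions from both output_obj and moe_loss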