mirror of https://github.com/InternLM/InternLM
support zero_overlap_communication
parent 1accc9f08d
commit 8c7d868f01
@@ -75,7 +75,7 @@ grad_scaler = dict(
 
 hybrid_zero_optimizer = dict(
     # Enable low_level_optimzer overlap_communication
-    zero_overlap_communication=False,
+    zero_overlap_communication=True,
     # bucket size for nccl communication params
     reduce_bucket_size=512 * 1024 * 1024,
     # grad clipping
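The toggle above turns on communication/computation overlap in the hybrid ZeRO (low-level) optimizer: per the comments in the config, gradients are reduced over NCCL in buckets, and `reduce_bucket_size` caps how many gradient bytes a bucket collects before it is sent. A minimal sketch of that general pattern, assuming plain `torch.distributed`; the `OverlappedReducer` class, its hook wiring, and `finalize()` are illustrative stand-ins, not InternLM's optimizer:

```python
# Sketch only: overlap gradient all-reduce with the backward pass by
# launching async NCCL collectives per bucket. Names are illustrative.
import torch
import torch.distributed as dist


class OverlappedReducer:
    def __init__(self, params, bucket_size=512 * 1024 * 1024):
        self.bucket_size = bucket_size        # bytes, mirrors reduce_bucket_size
        self.bucket, self.bucket_bytes = [], 0
        self.handles = []                     # (work, flat_buffer, grads)
        for p in params:
            if p.requires_grad:
                # Fires as soon as this parameter's gradient is produced,
                # i.e. while the backward pass is still running.
                p.register_hook(self._on_grad)

    def _on_grad(self, grad):
        self.bucket.append(grad)
        self.bucket_bytes += grad.numel() * grad.element_size()
        if self.bucket_bytes >= self.bucket_size:
            self._flush()
        return grad

    def _flush(self):
        if not self.bucket:
            return
        # Pack the bucket into one flat tensor so NCCL sees a single large
        # message, then launch the all-reduce asynchronously so it overlaps
        # with the remaining backward computation.
        flat = torch.cat([g.reshape(-1) for g in self.bucket])
        work = dist.all_reduce(flat, async_op=True)
        self.handles.append((work, flat, self.bucket))
        self.bucket, self.bucket_bytes = [], 0

    def finalize(self):
        # Call before optimizer.step(): flush the last partial bucket, wait
        # for all outstanding collectives, and copy the reduced values back
        # into the hooked gradient tensors.
        self._flush()
        for work, flat, grads in self.handles:
            work.wait()
            offset = 0
            for g in grads:
                g.copy_(flat[offset:offset + g.numel()].view_as(g))
                offset += g.numel()
        self.handles = []
```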
@@ -272,11 +272,12 @@ class PipelineScheduler(BaseScheduler):
                 self._call_hooks("after_criterion", loss)
+
                 loss_reduced = loss / self.num_microbatches
-                accum_loss.add_(loss_reduced)
+                accum_loss.add_(loss_reduced.detach())
                 output_obj = loss_reduced
 
         moe_loss = sum(moe_losses) * moe_loss_coeff
         moe_loss /= self.num_microbatches
 
         return output_obj, moe_loss
 
     def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss=None):
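The `.detach()` added in this hunk matters because `accum_loss` is only a running total for reporting. An in-place `add_` of a tensor that still carries a `grad_fn` records an autograd op on the accumulator, so the accumulator starts to require grad and keeps a reference to every microbatch's graph for the whole accumulation window. A self-contained illustration with a toy model (the model and loop below are placeholders, not scheduler code):

```python
# Sketch of why the loss accumulator must be detached.
import torch

model = torch.nn.Linear(4, 1)
accum_bad = torch.zeros(1)     # accumulated without detach
accum_good = torch.zeros(1)    # accumulated with detach

for _ in range(3):             # stand-ins for microbatches
    loss = model(torch.randn(2, 4)).mean()

    # Without detach: the in-place add records an autograd op, so the
    # accumulator gains a grad_fn and hangs onto this microbatch's graph.
    accum_bad.add_(loss)
    # With detach: the accumulator stays a plain buffer outside the graph.
    accum_good.add_(loss.detach())

    loss.backward()            # normal per-microbatch backward

print(accum_bad.grad_fn)       # AddBackward0 -> accumulator is in the graph
print(accum_good.grad_fn)      # None         -> just a running number
```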
@@ -313,13 +314,17 @@ class PipelineScheduler(BaseScheduler):
 
         self._call_hooks("before_backward", output_obj, output_obj_grad)
         with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
-            if moe_loss is not None:
-                moe_loss.backward(retain_graph=True)
-
-            if output_obj_grad is None:
-                engine.backward(output_obj)
-            else:
-                engine.backward_by_grad(output_obj, output_obj_grad)
+            if moe_loss is None:
+                if output_obj_grad is None:
+                    engine.backward(output_obj)
+                else:
+                    engine.backward_by_grad(output_obj, output_obj_grad)
+            else:
+                if output_obj_grad is None:
+                    engine.backward(output_obj + moe_loss)
+                else:
+                    engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])
+
 
         # Collect the grad of the input_obj.
         input_obj_grad = None
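In the new branch, the stage output and the auxiliary MoE loss are backpropagated in a single pass: `output_obj` takes the gradient handed back by the next pipeline stage, while `moe_loss` is a local scalar and therefore gets `None` as its incoming gradient. A hedged sketch of what `engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])` plausibly reduces to, written directly against `torch.autograd.backward` with made-up shapes (not the actual engine implementation):

```python
# Sketch: backpropagate a stage output together with a local scalar
# auxiliary loss in one autograd call. Shapes and tensors are made up.
import torch

w = torch.randn(8, 8, requires_grad=True)
hidden = torch.randn(4, 8)

output_obj = hidden @ w                 # tensor handed to the next stage
moe_loss = w.pow(2).mean()              # local scalar auxiliary loss

# Gradient of the global loss w.r.t. output_obj, as received from the
# next pipeline stage during the backward schedule.
output_obj_grad = torch.randn_like(output_obj)

# Rough equivalent of engine.backward_by_grad([output_obj, moe_loss],
# [output_obj_grad, None]): scalar tensors may take None as their
# incoming gradient (treated as 1.0).
torch.autograd.backward(
    tensors=[output_obj, moe_loss],
    grad_tensors=[output_obj_grad, None],
)

print(w.grad.shape)  # gradients from both terms accumulated into w.grad
```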