mirror of https://github.com/InternLM/InternLM

support zero_overlap_communication

parent 1accc9f08d
commit 8c7d868f01
@@ -75,7 +75,7 @@ grad_scaler = dict(
 hybrid_zero_optimizer = dict(
     # Enable low_level_optimzer overlap_communication
-    zero_overlap_communication=False,
+    zero_overlap_communication=True,
     # bucket size for nccl communication params
     reduce_bucket_size=512 * 1024 * 1024,
     # grad clipping
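Flipping zero_overlap_communication to True lets the hybrid ZeRO optimizer overlap gradient reduction with the backward pass: gradients are packed into buckets (reduce_bucket_size, 512 MB here), and each bucket's NCCL reduction is launched asynchronously as soon as it fills instead of waiting for backward to finish. Below is a minimal sketch of that pattern in plain PyTorch; OverlappedReducer and its methods are hypothetical names, not the InternLM HybridZeroOptimizer internals, and it assumes torch.distributed is initialized and PyTorch >= 2.1 (for register_post_accumulate_grad_hook).

# Sketch: bucketed all-reduce overlapped with backward. Illustrative only;
# OverlappedReducer is a hypothetical name, not the InternLM
# HybridZeroOptimizer API.
import torch
import torch.distributed as dist

BUCKET_SIZE = 512 * 1024 * 1024  # bytes, mirroring reduce_bucket_size above


class OverlappedReducer:
    def __init__(self, model: torch.nn.Module):
        self._bucket, self._bucket_bytes, self._pending = [], 0, []
        for p in model.parameters():
            if p.requires_grad:
                # Fires the moment p.grad finishes accumulating in backward.
                p.register_post_accumulate_grad_hook(self._on_grad_ready)

    def _on_grad_ready(self, param):
        self._bucket.append(param)
        self._bucket_bytes += param.grad.numel() * param.grad.element_size()
        if self._bucket_bytes >= BUCKET_SIZE:
            self._flush()  # NCCL reduces this bucket while backward continues

    def _flush(self):
        if not self._bucket:
            return
        flat = torch.cat([p.grad.view(-1) for p in self._bucket])
        work = dist.all_reduce(flat, async_op=True)  # non-blocking launch
        self._pending.append((work, flat, self._bucket))
        self._bucket, self._bucket_bytes = [], 0

    def finalize(self):
        # Call once after backward, before optimizer.step(): reduce the tail
        # bucket, wait for in-flight work, average, and scatter back to .grad.
        self._flush()
        world = dist.get_world_size()
        for work, flat, params in self._pending:
            work.wait()
            flat.div_(world)
            offset = 0
            for p in params:
                n = p.grad.numel()
                p.grad.copy_(flat[offset:offset + n].view_as(p.grad))
                offset += n
        self._pending.clear()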
@@ -272,11 +272,12 @@ class PipelineScheduler(BaseScheduler):
             self._call_hooks("after_criterion", loss)

             loss_reduced = loss / self.num_microbatches
-            accum_loss.add_(loss_reduced)
+            accum_loss.add_(loss_reduced.detach())
             output_obj = loss_reduced

         moe_loss = sum(moe_losses) * moe_loss_coeff
         moe_loss /= self.num_microbatches
+
         return output_obj, moe_loss

     def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss=None):
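The added .detach() matters because loss_reduced is also returned as output_obj and stays attached to the autograd graph; accumulating it into accum_loss undetached would splice every microbatch's graph (and its activations) into the accumulator, which only needs the numeric value. A self-contained toy repro of the hazard and the fix (not InternLM code):

# Toy repro of why the detach matters. Without it,
# accum_loss.add_(loss_reduced) chains every microbatch's autograd graph
# onto accum_loss, keeping activations alive for the whole step.
import torch

model = torch.nn.Linear(4, 1)
accum_loss = torch.zeros(1)
num_microbatches = 8

for _ in range(num_microbatches):
    loss = model(torch.randn(2, 4)).pow(2).mean()
    loss_reduced = loss / num_microbatches
    accum_loss.add_(loss_reduced.detach())  # bookkeeping only, no graph kept
    loss_reduced.backward()                 # gradients flow exactly as before

print(accum_loss.requires_grad)  # False: safe to log or all-reduce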
@@ -313,13 +314,17 @@ class PipelineScheduler(BaseScheduler):

         self._call_hooks("before_backward", output_obj, output_obj_grad)
         with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
-            if moe_loss is not None:
-                moe_loss.backward(retain_graph=True)
-            if output_obj_grad is None:
-                engine.backward(output_obj)
-            else:
-                engine.backward_by_grad(output_obj, output_obj_grad)
+            if moe_loss is None:
+                if output_obj_grad is None:
+                    engine.backward(output_obj)
+                else:
+                    engine.backward_by_grad(output_obj, output_obj_grad)
+            else:
+                if output_obj_grad is None:
+                    engine.backward(output_obj + moe_loss)
+                else:
+                    engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])

         # Collect the grad of the input_obj.
         input_obj_grad = None
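The rewritten branch folds the MoE gate loss into a single backward pass, replacing the separate moe_loss.backward(retain_graph=True), so the graph is traversed once and never retained. On the last pipeline stage (output_obj_grad is None) output_obj is the scalar loss, so backing up through output_obj + moe_loss is a plain sum; on earlier stages, the multi-root call seeds output_obj with the gradient received from the next stage while the scalar moe_loss gets the implicit seed of 1. Assuming backward_by_grad ultimately forwards its arguments to torch.autograd.backward (a reading of the call, not the literal engine source), it behaves like this sketch:

# What the multi-root call computes, expressed with plain torch.autograd.
# backward_by_grad here is a hypothetical stand-in for the engine method.
import torch

def backward_by_grad(tensors, grad_tensors):
    # A None entry seeds a scalar root with gradient 1, as in loss.backward().
    torch.autograd.backward(tensors=tensors, grad_tensors=grad_tensors)

# Toy pipeline stage: output_obj is sent downstream, moe_loss stays local.
x = torch.randn(3, requires_grad=True)
output_obj = x * 2                  # activation handed to the next stage
moe_loss = x.sum() * 0.01           # auxiliary gate loss, scalar
output_obj_grad = torch.ones(3)     # gradient received back from next stage

# One graph traversal accumulates both contributions into x.grad:
# d(output_obj)/dx * output_obj_grad + d(moe_loss)/dx * 1
backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])
print(x.grad)  # tensor([2.0100, 2.0100, 2.0100])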