@@ -169,21 +169,27 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self.model.backward(loss)
 
     def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
+        # This function is called except the last stage of pipeline parallel
+        # It receives the scaled grad from the previous rank
+        # No need to scale the grad again
+        # Need to unscale when optimizing
         self.optim_state = OptimState.SCALED
         self.model.backward_by_grad(tensor, grad)
 
     def clip_grad_norm(self, model: nn.Module, max_norm: float):
         if self.optim_state == OptimState.SCALED:
+            self._prepare_grads()
             self._unscale_grads()
         return super().clip_grad_norm(model, max_norm)
 
     def step(self, *args, **kwargs):
-        self._prepare_grads()
-        self._maybe_move_fp32_shards()
+
         # unscale grads if scaled
         if self.optim_state == OptimState.SCALED:
+            self._prepare_grads()
             self._unscale_grads()
 
+        self._maybe_move_fp32_shards()
         found_inf = self._check_overflow()
         self.grad_scaler.update(found_inf)
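To make the intent of the SCALED/UNSCALED bookkeeping easier to follow, here is a minimal, self-contained sketch of the same lazy-unscale pattern. It is illustrative only: `ToyScaledOptimizer`, its fixed loss scale, and the helper names are hypothetical, and the sketch deliberately omits the sharding, overflow checking, and dynamic `grad_scaler` that `ShardedOptimizerV2` actually uses.

```python
# Illustrative sketch only (not ColossalAI code): a toy mixed-precision
# optimizer wrapper showing the SCALED/UNSCALED state tracking that the
# patch above relies on.
from enum import Enum

import torch
from torch import Tensor, nn


class OptimState(Enum):
    SCALED = 0
    UNSCALED = 1


class ToyScaledOptimizer:

    def __init__(self, optim: torch.optim.Optimizer, scale: float = 2.0**16):
        self.optim = optim
        self.scale = scale
        self.optim_state = OptimState.UNSCALED

    def backward(self, loss: Tensor) -> None:
        # Scale the loss so low-precision grads do not underflow, then
        # remember that the grads now carry the scale factor.
        (loss * self.scale).backward()
        self.optim_state = OptimState.SCALED

    def _unscale_grads(self) -> None:
        # Divide the scale back out of every gradient, exactly once.
        for group in self.optim.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.grad.div_(self.scale)
        self.optim_state = OptimState.UNSCALED

    def clip_grad_norm(self, model: nn.Module, max_norm: float) -> Tensor:
        # Unscale lazily before any operation that needs real grad magnitudes.
        if self.optim_state == OptimState.SCALED:
            self._unscale_grads()
        return torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    def step(self) -> None:
        # Whichever of clip_grad_norm or step runs first does the unscale;
        # the state flag keeps it from happening twice.
        if self.optim_state == OptimState.SCALED:
            self._unscale_grads()
        self.optim.step()
```

In the same spirit, the patch appears to make the unscale path self-sufficient: `_prepare_grads()` is called inside the SCALED branch of both `clip_grad_norm` and `step`, and `_maybe_move_fp32_shards()` is deferred until after the grads have been prepared and unscaled.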