@@ -169,21 +169,27 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        self.model.backward(loss)

    def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
        # This function is called at every pipeline stage except the last one.
        # It receives the grad already scaled by the previous rank,
        # so there is no need to scale the grad again,
        # but it still needs to be unscaled when optimizing.
        self.optim_state = OptimState.SCALED
        self.model.backward_by_grad(tensor, grad)
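The scaling contract behind these two entry points can be seen in isolation with a minimal, self-contained sketch. Everything below (`OptimState`, `ToyShardedOptim`, the loss-scale value) is a toy stand-in written in plain PyTorch, not the ColossalAI classes: the last pipeline stage scales the loss itself, earlier stages receive a gradient that already carries the scale and merely propagate it, and both paths leave the optimizer marked as SCALED so the later unscale happens exactly once.

```python
from enum import Enum

import torch


class OptimState(Enum):
    SCALED = 0
    UNSCALED = 1


class ToyShardedOptim:
    """Toy illustration of the scaling state machine; not ColossalAI's implementation."""

    def __init__(self, model: torch.nn.Module, loss_scale: float = 2.0**10):
        self.model = model
        self.loss_scale = loss_scale
        self.optim_state = OptimState.UNSCALED

    def backward(self, loss: torch.Tensor) -> None:
        # Last pipeline stage: scale the loss once, then run autograd.
        self.optim_state = OptimState.SCALED
        (self.loss_scale * loss).backward()

    def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor) -> None:
        # Earlier stages: `grad` arrives already scaled by the previous rank,
        # so it is only propagated; the state is still marked SCALED so that
        # step()/clip_grad_norm() know an unscale is pending.
        self.optim_state = OptimState.SCALED
        torch.autograd.backward(tensor, grad)


if __name__ == "__main__":
    model = torch.nn.Linear(4, 4)
    optim = ToyShardedOptim(model)
    out = model(torch.randn(2, 4))
    # Pretend the (already scaled) gradient of `out` arrived from the neighbouring rank.
    optim.backward_by_grad(out, torch.ones_like(out) * optim.loss_scale)
    assert optim.optim_state is OptimState.SCALED
```

Keeping the flag on the optimizer rather than on the gradients themselves is what lets `clip_grad_norm` and `step` below decide lazily whether an unscale is still pending.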
    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        if self.optim_state == OptimState.SCALED:
            self._prepare_grads()
            self._unscale_grads()
        return super().clip_grad_norm(model, max_norm)

    def step(self, *args, **kwargs):
        self._prepare_grads()
        self._maybe_move_fp32_shards()
        # unscale grads if scaled
        if self.optim_state == OptimState.SCALED:
            self._prepare_grads()
            self._unscale_grads()
        self._maybe_move_fp32_shards()
        found_inf = self._check_overflow()
        self.grad_scaler.update(found_inf)
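Both `clip_grad_norm` and `step` apply the same guard: unscale only while the state is still SCALED, and only then clip or check for overflow. That ordering can be mimicked without any ColossalAI machinery; the sketch below is a stand-in (`ToyScaledStep`, its `_unscale_grads`, and `step_would_skip` are invented for illustration), in which gradients are divided by the loss scale at most once and only the unscaled values feed the clipping and the inf/nan check that drives the grad scaler.

```python
import torch


class ToyScaledStep:
    """Illustrates the 'unscale only if still scaled' guard; not ColossalAI's implementation."""

    def __init__(self, params, loss_scale: float):
        self.params = list(params)
        self.loss_scale = loss_scale
        self.scaled = True  # gradients came from a scaled loss

    def _unscale_grads(self) -> None:
        for p in self.params:
            if p.grad is not None:
                p.grad.div_(self.loss_scale)
        self.scaled = False  # a second call would be a no-op

    def clip_grad_norm(self, max_norm: float) -> torch.Tensor:
        # Clipping must see real (unscaled) gradient magnitudes.
        if self.scaled:
            self._unscale_grads()
        return torch.nn.utils.clip_grad_norm_(self.params, max_norm)

    def step_would_skip(self) -> bool:
        # The overflow check also needs unscaled grads; inf/nan means the
        # scaler should shrink its scale and the update should be skipped.
        if self.scaled:
            self._unscale_grads()
        return any(
            p.grad is not None and not torch.isfinite(p.grad).all()
            for p in self.params
        )


# Usage: backward() on a scaled loss, then clip and check before updating.
model = torch.nn.Linear(8, 1)
scale = 2.0**12
(scale * model(torch.randn(4, 8)).mean()).backward()
helper = ToyScaledStep(model.parameters(), scale)
helper.clip_grad_norm(max_norm=1.0)      # performs the single unscale here
print("skip update:", helper.step_would_skip())
```

The `scaled` flag plays the same role as `OptimState.SCALED` above: whichever of clipping or stepping runs first performs the one unscale, and the other then sees gradients that are already in real units.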