From a45ddf2d5f21cd24b5f9e79ef2cbf30dc57fa10a Mon Sep 17 00:00:00 2001
From: ver217
Date: Fri, 8 Jul 2022 13:34:48 +0800
Subject: [PATCH] [hotfix] fix sharded optim step and clip_grad_norm (#1226)

---
 colossalai/utils/common.py                        |  4 ++--
 colossalai/zero/sharded_optim/sharded_optim_v2.py | 10 ++++++++--
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index 6114ab11a..19748770d 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -195,7 +195,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             # Make sure the grads are in fp32
             assert param.grad.dtype == torch.float, \
                 f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
-            if hasattr(param, 'zero_is_sharded'):
+            if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded:
                 has_zero_shared_param = True
             params.append(param)
 
@@ -234,7 +234,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             if is_model_parallel_parameter(p):
                 reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
                 tensor_parallel_grads.append(p.grad.data / reductor)
-            elif hasattr(p, 'zero_is_sharded'):
+            elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded:
                 zero_sharded_grads.append(p.grad.data)
             else:
                 no_tensor_parallel_grads.append(p.grad.data)
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 63545f11e..194cc165e 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -169,21 +169,27 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self.model.backward(loss)
 
     def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
+        # This function is called on all pipeline stages except the last one
+        # It receives the scaled grad from the previous rank
+        # No need to scale the grad again
+        # Need to unscale when optimizing
+        self.optim_state = OptimState.SCALED
         self.model.backward_by_grad(tensor, grad)
 
     def clip_grad_norm(self, model: nn.Module, max_norm: float):
         if self.optim_state == OptimState.SCALED:
+            self._prepare_grads()
             self._unscale_grads()
         return super().clip_grad_norm(model, max_norm)
 
     def step(self, *args, **kwargs):
-        self._prepare_grads()
-        self._maybe_move_fp32_shards()
 
         # unscale grads if scaled
         if self.optim_state == OptimState.SCALED:
+            self._prepare_grads()
             self._unscale_grads()
 
+        self._maybe_move_fp32_shards()
         found_inf = self._check_overflow()
         self.grad_scaler.update(found_inf)
 
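
Note (not part of the patch): the sketch below mirrors the control flow this diff establishes. Only the names taken from the diff (backward_by_grad, step, _prepare_grads, _unscale_grads, _maybe_move_fp32_shards, _check_overflow, OptimState.SCALED) come from ColossalAI; everything else is a hypothetical stub written solely to make the ordering runnable in isolation, not the real ShardedOptimizerV2 implementation.

from enum import Enum


class OptimState(Enum):
    SCALED = 0
    UNSCALED = 1


class ToyShardedOptim:
    """Hypothetical stand-in for ShardedOptimizerV2; only the patched control flow is kept."""

    def __init__(self):
        self.optim_state = OptimState.UNSCALED
        self.calls = []          # records the order of internal calls

    # --- stubs standing in for the real sharded-optimizer internals ---
    def _prepare_grads(self):
        self.calls.append('prepare_grads')

    def _unscale_grads(self):
        self.calls.append('unscale_grads')
        self.optim_state = OptimState.UNSCALED

    def _maybe_move_fp32_shards(self):
        self.calls.append('move_fp32_shards')

    def _check_overflow(self):
        return False             # pretend no inf/nan was found

    # --- control flow mirroring the patch ---
    def backward_by_grad(self, tensor, grad):
        # Non-last pipeline stages receive an already-scaled grad,
        # so mark the state as SCALED and unscale later, at step time.
        self.optim_state = OptimState.SCALED

    def step(self):
        # Prepare and unscale grads only if they are still scaled ...
        if self.optim_state == OptimState.SCALED:
            self._prepare_grads()
            self._unscale_grads()
        # ... then move fp32 shards and check for overflow.
        self._maybe_move_fp32_shards()
        return self._check_overflow()


opt = ToyShardedOptim()
opt.backward_by_grad(tensor=None, grad=None)
opt.step()
print(opt.calls)   # ['prepare_grads', 'unscale_grads', 'move_fp32_shards']

Running the sketch prints ['prepare_grads', 'unscale_grads', 'move_fp32_shards']: grads are prepared and unscaled only while the optimizer is still in the SCALED state, and fp32 shards are moved only after that check, which is the ordering the patch introduces in step().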