[hotfix] fix sharded optim step and clip_grad_norm (#1226)

pull/1237/head
ver217 2 years ago committed by GitHub
parent f071b500b6
commit a45ddf2d5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -195,7 +195,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# Make sure the grads are in fp32
assert param.grad.dtype == torch.float, \
f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
if hasattr(param, 'zero_is_sharded'):
if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded:
has_zero_shared_param = True
params.append(param)
@ -234,7 +234,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
if is_model_parallel_parameter(p):
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
tensor_parallel_grads.append(p.grad.data / reductor)
elif hasattr(p, 'zero_is_sharded'):
elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded:
zero_sharded_grads.append(p.grad.data)
else:
no_tensor_parallel_grads.append(p.grad.data)

@ -169,21 +169,27 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self.model.backward(loss)
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
# This function is called except the last stage of pipeline parallel
# It receives the scaled grad from the previous rank
# No need to scale the grad again
# Need to unscale when optimizing
self.optim_state = OptimState.SCALED
self.model.backward_by_grad(tensor, grad)
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if self.optim_state == OptimState.SCALED:
self._prepare_grads()
self._unscale_grads()
return super().clip_grad_norm(model, max_norm)
def step(self, *args, **kwargs):
self._prepare_grads()
self._maybe_move_fp32_shards()
# unscale grads if scaled
if self.optim_state == OptimState.SCALED:
self._prepare_grads()
self._unscale_grads()
self._maybe_move_fp32_shards()
found_inf = self._check_overflow()
self.grad_scaler.update(found_inf)

Loading…
Cancel
Save