mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix sharded optim step and clip_grad_norm (#1226)
commit a45ddf2d5f
parent f071b500b6
@@ -195,7 +195,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             # Make sure the grads are in fp32
             assert param.grad.dtype == torch.float, \
                 f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
-            if hasattr(param, 'zero_is_sharded'):
+            if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded:
                 has_zero_shared_param = True
             params.append(param)

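The first hunk swaps the old `zero_is_sharded` attribute check for the newer `colo_attr` metadata carried by sharded parameters. A minimal sketch of that predicate and how it slots into the grad-collection loop; `is_zero_sharded_param` and `collect_fp32_grads` are illustrative names, not functions from the codebase:

```python
import torch
import torch.nn as nn


def is_zero_sharded_param(param: nn.Parameter) -> bool:
    """Hypothetical helper mirroring the check this hotfix installs.

    A parameter counts as ZeRO-sharded only if it carries the `colo_attr`
    metadata object and its sharded data tensor reports `is_sharded=True`.
    """
    return hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded


def collect_fp32_grads(parameters):
    # Mirrors the loop in the hunk above: keep fp32-grad params and note
    # whether any of them is a ZeRO-sharded parameter.
    params, has_zero_shared_param = [], False
    for param in parameters:
        if param.grad is None:
            continue
        assert param.grad.dtype == torch.float, \
            f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
        if is_zero_sharded_param(param):
            has_zero_shared_param = True
        params.append(param)
    return params, has_zero_shared_param
```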
@@ -234,7 +234,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             if is_model_parallel_parameter(p):
                 reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
                 tensor_parallel_grads.append(p.grad.data / reductor)
-            elif hasattr(p, 'zero_is_sharded'):
+            elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded:
                 zero_sharded_grads.append(p.grad.data)
             else:
                 no_tensor_parallel_grads.append(p.grad.data)
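The second hunk applies the same `colo_attr` predicate in the branch that buckets gradients for norm computation. The surrounding `reductor` arithmetic is worth spelling out: a shard replicated on `tp_world_size / num_partitions` tensor-parallel ranks would be over-counted when per-rank norms are summed across the group, and dividing each grad by `(tp_world_size / num_partitions) ** (1 / norm_type)` cancels exactly that factor. A standalone sketch under that reading, with plain integers standing in for the values the real code pulls from `gpc` and the `NUM_PARTITIONS` attribute:

```python
import torch


def corrected_grad_for_norm(grad: torch.Tensor,
                            tp_world_size: int,
                            num_partitions: int,
                            norm_type: float = 2.0) -> torch.Tensor:
    """Sketch of the reductor arithmetic from the hunk above.

    Dividing the grad by (tp_world_size / num_partitions) ** (1 / norm_type)
    scales |g| ** norm_type by num_partitions / tp_world_size, which cancels
    the duplication factor when per-rank norms are summed.
    """
    reductor = (tp_world_size / num_partitions) ** (1 / norm_type)
    return grad / reductor


# Worked check: a grad replicated on all 4 tensor-parallel ranks
# (num_partitions = 1), using the 2-norm.
g = torch.ones(8)
tp_world_size, num_partitions, norm_type = 4, 1, 2.0
per_rank = corrected_grad_for_norm(g, tp_world_size, num_partitions, norm_type)
# Summing |g / reductor| ** 2 over the 4 replicas recovers the true |g| ** 2 once.
summed = tp_world_size * per_rank.norm(norm_type) ** norm_type
assert torch.isclose(summed, g.norm(norm_type) ** norm_type)
```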
@@ -169,21 +169,27 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self.model.backward(loss)

     def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
+        # This function is called except the last stage of pipeline parallel
+        # It receives the scaled grad from the previous rank
+        # No need to scale the grad again
+        # Need to unscale when optimizing
+        self.optim_state = OptimState.SCALED
         self.model.backward_by_grad(tensor, grad)

     def clip_grad_norm(self, model: nn.Module, max_norm: float):
         if self.optim_state == OptimState.SCALED:
+            self._prepare_grads()
             self._unscale_grads()
         return super().clip_grad_norm(model, max_norm)

     def step(self, *args, **kwargs):
-        self._prepare_grads()
-        self._maybe_move_fp32_shards()

         # unscale grads if scaled
         if self.optim_state == OptimState.SCALED:
+            self._prepare_grads()
             self._unscale_grads()

+        self._maybe_move_fp32_shards()
         found_inf = self._check_overflow()
         self.grad_scaler.update(found_inf)

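The third hunk marks grads produced by `backward_by_grad` as scaled and reorders `ShardedOptimizerV2.step()`: `_prepare_grads()` now runs only when the gradients are actually scaled, and `_maybe_move_fp32_shards()` runs after unscaling, just before the overflow check. A condensed, self-contained sketch of that control flow; the class, enum values, and stubbed helpers below are illustrative stand-ins for the real optimizer internals, not the library implementation:

```python
from enum import Enum


class OptimState(Enum):
    SCALED = 0
    UNSCALED = 1


class StepFlowSketch:
    """Illustrates the order of operations installed by this hotfix."""

    def __init__(self) -> None:
        self.optim_state = OptimState.SCALED

    # Stubs standing in for the sharded-optimizer internals.
    def _prepare_grads(self) -> None:
        print('gather/cast grads for the optimizer')

    def _unscale_grads(self) -> None:
        print('divide grads by the loss scale')
        self.optim_state = OptimState.UNSCALED

    def _maybe_move_fp32_shards(self) -> None:
        print('move fp32 master shards to the compute device if needed')

    def _check_overflow(self) -> bool:
        return False

    def step(self) -> None:
        # Prepare and unscale grads only if a scaled backward ran.
        if self.optim_state == OptimState.SCALED:
            self._prepare_grads()
            self._unscale_grads()

        # Shard movement now happens after unscaling, right before the
        # overflow check; the real optimizer then feeds found_inf into its
        # dynamic grad scaler via grad_scaler.update(found_inf).
        self._maybe_move_fp32_shards()
        found_inf = self._check_overflow()
        if found_inf:
            return  # the real optimizer skips the update on overflow
        print('run the wrapped optimizer step')


StepFlowSketch().step()
```

Presumably `_unscale_grads()` flips the state to `UNSCALED`, so a `clip_grad_norm()` call before `step()` does not trigger a second unscale; the `# unscale grads if scaled` guard in `step()` relies on exactly that.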