@@ -142,6 +142,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
    def clip_grad_norm(self, model: torch.nn.Module, max_norm: float):
        if self.optim_state == OptimState.SCALED:
            self._unscale_grads()
        # TODO(ver217): fix zero clip grad norm
        return super().clip_grad_norm(model, max_norm)

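    # Usage sketch (illustrative, not taken from this diff; `optim`, `model`, `loss`
    # and max_norm=1.0 are assumed names/values): with loss scaling, gradients come out
    # of backward in scaled form, so clip_grad_norm unscales them before clipping and
    # the optimizer step then runs on unscaled values.
    #
    #     optim.backward(loss)                       # produces scaled gradients
    #     optim.clip_grad_norm(model, max_norm=1.0)  # unscales first, then clips
    #     optim.step()
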
    def backward(self, loss: torch.Tensor):
@@ -150,6 +151,11 @@ class ZeroOptimizer(ColossalaiOptimizer):
        self.module.backward(loss)

    def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
        # This function is called by every pipeline-parallel stage except the last one.
        # It receives an already-scaled grad from the previous rank, so the grad is not
        # scaled again here; it only needs to be unscaled when the optimizer steps.
        self.optim_state = OptimState.SCALED
        self.module.backward_by_grad(tensor, grad)

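    # Pipeline-parallel sketch of the two entry points above (illustrative;
    # `is_last_stage`, `output` and `grad_from_prev_rank` are assumed names): only the
    # last stage owns a scalar loss; every other stage back-propagates from the gradient
    # handed over by the previous rank in the backward pass, which is already scaled.
    #
    #     if is_last_stage:
    #         optim.backward(loss)                               # scales loss, then backward
    #     else:
    #         optim.backward_by_grad(output, grad_from_prev_rank)  # grad is already scaled
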
    def _maybe_move_fp32_params(self):
@@ -184,7 +190,18 @@ class ZeroOptimizer(ColossalaiOptimizer):
                    if isinstance(val, torch.Tensor):
                        self.chunk_manager.add_extern_static_tensor(val)

    def state_dict(self):
        optim_state_dict = super().state_dict()
        scaler_state_dict = self.grad_scaler.state_dict()
        optim_state_dict['scaler'] = scaler_state_dict
        return optim_state_dict

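    # Checkpointing sketch (illustrative; the file name and `optim` are assumptions):
    # state_dict now carries the grad scaler under the 'scaler' key, and load_state_dict
    # below restores it, so a save/load round trip keeps the loss-scale state.
    #
    #     ckpt = optim.state_dict()
    #     torch.save(ckpt, 'zero_optim.pt')
    #     optim.load_state_dict(torch.load('zero_optim.pt'))
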
    def load_state_dict(self, *args, **kwargs):
        if 'scaler' not in args[0]:
            self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
        else:
            scaler_state_dict = args[0].pop('scaler')
            self.grad_scaler.load_state_dict(scaler_state_dict)
        super().load_state_dict(*args, **kwargs)
        for group in self.optim.param_groups:
            for p in group['params']: