From 795210dd9940e82d6bf1ac0da65ee58faf0783a1 Mon Sep 17 00:00:00 2001
From: ver217
Date: Thu, 3 Mar 2022 15:42:53 +0800
Subject: [PATCH] add fp32 master params in sharded adam

---
 colossalai/zero/sharded_optim/sharded_adam.py | 48 ++++++------------
 1 file changed, 14 insertions(+), 34 deletions(-)

diff --git a/colossalai/zero/sharded_optim/sharded_adam.py b/colossalai/zero/sharded_optim/sharded_adam.py
index 6980b2b71..1cb8c4a1d 100644
--- a/colossalai/zero/sharded_optim/sharded_adam.py
+++ b/colossalai/zero/sharded_optim/sharded_adam.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -11,6 +11,7 @@ from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.zero.sharded_model import ShardedModelV2
 from torch import Tensor
 from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
 
 from ._utils import has_inf_or_nan
@@ -39,7 +40,7 @@ class ShardedAdam(ColossalaiOptimizer):
         super().__init__(adam_optim)
         self.model: Union[nn.Module, ShardedModelV2] = sharded_model
         self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
-        self.state_device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
+        self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
         self.optim_state: OptimState = OptimState.UNSCALED
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
         self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)
@@ -51,35 +52,18 @@ class ShardedAdam(ColossalaiOptimizer):
                                              growth_interval=growth_interval,
                                              hysteresis=hysteresis,
                                              max_scale=max_scale)
-        self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.state_device)
+        self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.device)
+
+        # Store fp32 params
+        self.master_params: Dict[Parameter, Tensor] = {}
 
-        # Early state initialization
         for group in adam_optim.param_groups:
             for p in group['params']:
-                state_shape = p.shape
                 if hasattr(p, 'ca_attr'):
                     assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
-                    # TODO: use payload shape
-                    state_shape = p.ca_attr.payload(self.state_device)
-                state = adam_optim.state[p]
-                assert len(state) == 0, 'adam optimizer initialized'
-                state['step'] = 0
-                # Exponential moving average of gradient values
-                state['exp_avg'] = torch.zeros(state_shape,
-                                               memory_format=torch.preserve_format,
-                                               dtype=torch.float,
-                                               device=self.state_device)
-                # Exponential moving average of squared gradient values
-                state['exp_avg_sq'] = torch.zeros(state_shape,
-                                                  memory_format=torch.preserve_format,
-                                                  dtype=torch.float,
-                                                  device=self.state_device)
-                if group['amsgrad']:
-                    # Maintains max of all exp. moving avg. of sq. grad. values
-                    state['max_exp_avg_sq'] = torch.zeros(state_shape,
-                                                          memory_format=torch.preserve_format,
-                                                          dtype=torch.float,
-                                                          device=self.state_device)
+                    self.master_params[p] = p.ca_attr.payload(self.device).to(torch.float)
+                else:
+                    self.master_params[p] = p.data.to(torch.float)
 
     def step(self, *args, **kwargs):
         # unscale grads if scaled
@@ -93,19 +77,15 @@ class ShardedAdam(ColossalaiOptimizer):
             self.zero_grad()
             return
 
-        # Write payload back to p.data
+        # Write master param to p.data
         for group in self.optim.param_groups:
             for p in group['params']:
-                data = p.data
-                if hasattr(p, 'ca_attr'):
-                    data = p.ca_attr.payload(self.state_device)
-                if torch.is_floating_point(data) and data.dtype != torch.float:
-                    data = data.to(torch.float)
-                p.data = data
+                p.data = self.master_params[p]
         ret = self.optim.step(*args, **kwargs)
-        # Set p.data to None
+        # Write master param to payload and set p.data to None
         for group in self.optim.param_groups:
             for p in group['params']:
+                # TODO: update payload
                 p.data = None
         return ret
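
Reviewer note: the patch replaces eager pre-initialization of Adam state tensors with fp32 master copies of the (possibly fp16) parameters; `p.data` is pointed at the master copy around `optim.step()` so Adam reads, updates, and keeps its state in full precision. The sketch below is a minimal, self-contained illustration of that master-weight pattern on a plain `torch.optim.Adam`, assuming an fp16 model; it is not the ColossalAI implementation, and the name `MasterParamAdam` is made up for this example.

    # Minimal sketch of the fp32 master-weight pattern, NOT ColossalAI's API.
    # Assumes the wrapped model's parameters are fp16.
    from typing import Dict

    import torch
    from torch import Tensor
    from torch.nn.parameter import Parameter
    from torch.optim import Adam


    class MasterParamAdam:

        def __init__(self, params, lr: float = 1e-3):
            self.optim = Adam(params, lr=lr)
            # Keep an fp32 master copy of every parameter, mirroring the
            # patch's `self.master_params: Dict[Parameter, Tensor]`.
            self.master_params: Dict[Parameter, Tensor] = {}
            for group in self.optim.param_groups:
                for p in group['params']:
                    self.master_params[p] = p.data.clone().float()

        def step(self):
            for group in self.optim.param_groups:
                for p in group['params']:
                    # Point p.data at the fp32 master tensor so Adam updates
                    # full-precision weights and builds fp32 optimizer state.
                    p.data = self.master_params[p]
                    if p.grad is not None:
                        # Cast the low-precision grad to match the fp32 param.
                        p.grad = p.grad.float()
            ret = self.optim.step()
            for group in self.optim.param_groups:
                for p in group['params']:
                    # Hand a half-precision copy of the updated master weights
                    # back to the model and drop the stale grad.
                    p.data = self.master_params[p].half()
                    p.grad = None
            return ret

        def zero_grad(self):
            self.optim.zero_grad()

In the patch itself the half-precision weights live in `p.ca_attr`'s payload rather than in `p.data`, which is why the second loop only leaves a `# TODO: update payload` and resets `p.data` to `None` instead of writing the fp16 copy back as the sketch does.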