add fp32 master params in sharded adam

pull/394/head
authored by ver217 on 2022-03-03 15:42:53 +08:00, committed by Frank Lee
parent a109225bc2
commit 795210dd99
1 changed file with 14 additions and 34 deletions

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Optional, Union
+from typing import Dict, Optional, Union

 import torch
 import torch.distributed as dist
@@ -11,6 +11,7 @@ from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.zero.sharded_model import ShardedModelV2
 from torch import Tensor
 from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
 from ._utils import has_inf_or_nan
@@ -39,7 +40,7 @@ class ShardedAdam(ColossalaiOptimizer):
         super().__init__(adam_optim)
         self.model: Union[nn.Module, ShardedModelV2] = sharded_model
         self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
-        self.state_device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
+        self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
         self.optim_state: OptimState = OptimState.UNSCALED
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
         self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)
@@ -51,35 +52,18 @@ class ShardedAdam(ColossalaiOptimizer):
                                              growth_interval=growth_interval,
                                              hysteresis=hysteresis,
                                              max_scale=max_scale)
-        self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.state_device)
-        # Early state initialization
+        self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.device)
+        # Store fp32 params
+        self.master_params: Dict[Parameter, Tensor] = {}
         for group in adam_optim.param_groups:
             for p in group['params']:
-                state_shape = p.shape
                 if hasattr(p, 'ca_attr'):
                     assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
-                    # TODO: use payload shape
-                    state_shape = p.ca_attr.payload(self.state_device)
-                state = adam_optim.state[p]
-                assert len(state) == 0, 'adam optimizer initialized'
-                state['step'] = 0
-                # Exponential moving average of gradient values
-                state['exp_avg'] = torch.zeros(state_shape,
-                                               memory_format=torch.preserve_format,
-                                               dtype=torch.float,
-                                               device=self.state_device)
-                # Exponential moving average of squared gradient values
-                state['exp_avg_sq'] = torch.zeros(state_shape,
-                                                  memory_format=torch.preserve_format,
-                                                  dtype=torch.float,
-                                                  device=self.state_device)
-                if group['amsgrad']:
-                    # Maintains max of all exp. moving avg. of sq. grad. values
-                    state['max_exp_avg_sq'] = torch.zeros(state_shape,
-                                                          memory_format=torch.preserve_format,
-                                                          dtype=torch.float,
-                                                          device=self.state_device)
+                    self.master_params[p] = p.ca_attr.payload(self.device).to(torch.float)
+                else:
+                    self.master_params[p] = p.data.to(torch.float)

     def step(self, *args, **kwargs):
         # unscale grads if scaled
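The hunk above drops the eager allocation of Adam state and instead keeps one fp32 master copy per parameter, keyed by the Parameter object itself. A minimal sketch of that pattern in plain PyTorch, with a hypothetical helper name and no ColossalAI sharding (the explicit clone guarantees a separate master copy even for params that are already fp32, where Tensor.to would return the tensor unchanged):

from typing import Dict

import torch
from torch import Tensor
from torch.nn.parameter import Parameter


def build_master_params(optim: torch.optim.Optimizer,
                        device: torch.device) -> Dict[Parameter, Tensor]:
    # One fp32 copy per (possibly fp16) parameter, keyed by the Parameter object,
    # so the optimizer can accumulate updates in full precision.
    master_params: Dict[Parameter, Tensor] = {}
    for group in optim.param_groups:
        for p in group['params']:
            master_params[p] = p.data.detach().clone().to(device=device, dtype=torch.float)
    return master_params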
@@ -93,19 +77,15 @@ class ShardedAdam(ColossalaiOptimizer):
             self.zero_grad()
             return
-        # Write payload back to p.data
+        # Write master param to p.data
         for group in self.optim.param_groups:
             for p in group['params']:
-                data = p.data
-                if hasattr(p, 'ca_attr'):
-                    data = p.ca_attr.payload(self.state_device)
-                if torch.is_floating_point(data) and data.dtype != torch.float:
-                    data = data.to(torch.float)
-                p.data = data
+                p.data = self.master_params[p]
         ret = self.optim.step(*args, **kwargs)
-        # Set p.data to None
+        # Write master param to payload and set p.data to None
         for group in self.optim.param_groups:
             for p in group['params']:
+                # TODO: update payload
                 p.data = None
         return ret
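Together with the previous hunk, step() now follows the usual master-weight flow for mixed precision: point p.data at the fp32 master copy, run the wrapped optimizer in full precision, then drop p.data again (writing the updated master values back into the sharded payload is left as a TODO in this commit). A rough sketch of that flow, continuing the hypothetical helper above (same imports) and assuming plain fp16 working params rather than a sharded payload:

def master_param_step(optim: torch.optim.Optimizer,
                      master_params: Dict[Parameter, Tensor]):
    # Swap the fp32 master copies in so the update happens in full precision.
    # Assumes p.grad has already been unscaled and cast to fp32 elsewhere.
    for group in optim.param_groups:
        for p in group['params']:
            p.data = master_params[p]
    ret = optim.step()
    # Hypothetical write-back: cast the updated master values to fp16 working params.
    # (The commit above instead sets p.data = None and leaves the payload update as a TODO.)
    for group in optim.param_groups:
        for p in group['params']:
            p.data = master_params[p].to(torch.half)
    return ret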