# ColossalAI/colossalai/zero/sharded_optim/sharded_optim_v2.py
# 144 lines, 5.5 KiB, Python
2022-03-03 07:06:18 +00:00
from enum import Enum
2022-03-03 07:42:53 +00:00
from typing import Dict, Optional, Union
2022-03-03 07:06:18 +00:00
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.zero.sharded_model import ShardedModelV2
from torch import Tensor
from torch.distributed import ProcessGroup
2022-03-03 07:42:53 +00:00
from torch.nn.parameter import Parameter
2022-03-03 07:06:18 +00:00
from torch.optim import Optimizer
from ._utils import has_inf_or_nan
class OptimState(Enum):
    """Tracks whether parameter gradients currently include the dynamic loss scale."""

    # Gradients were produced from a loss multiplied by the loss scale and must be
    # divided by that scale before they are used.
    SCALED = 1
    # Gradients are at their true magnitude.
    UNSCALED = 2
2022-03-03 07:55:27 +00:00
class ShardedOptimizerV2(ColossalaiOptimizer):
    """Wrap a torch optimizer with fp32 master parameters and dynamic loss scaling.

    Master tensors:
        For every parameter in ``optimizer.param_groups``, a master tensor is kept
        on ``self.device``. This is the current CUDA device, or CPU when
        ``cpu_offload`` is True. Floating-point masters are stored as
        ``torch.float``.

    Loss scaling:
        :meth:`backward` multiplies the loss by the current loss scale and marks
        the gradients as scaled. :meth:`step` and :meth:`clip_grad_norm` divide
        the scale back out before using the gradients.

    Overflow handling:
        If any gradient on any rank of the data- or model-parallel group contains
        inf/NaN, :meth:`step` zeroes the gradients and skips the update.

    Args:
        optimizer: Optimizer to wrap. Its ``param_groups`` define the managed
            parameters.
        sharded_model: Model being optimized. When it is a ``ShardedModelV2``,
            :meth:`backward` and :meth:`backward_by_grad` are delegated to it.
        cpu_offload: Keep master parameters on CPU instead of the current CUDA
            device.
        initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
            hysteresis, max_scale: Forwarded unchanged to ``DynamicGradScaler``.
        dp_process_group: First group the overflow flag is all-reduced over.
            Falls back to ``gpc.get_group(ParallelMode.DATA)`` when falsy.
        mp_process_group: Second group the overflow flag is all-reduced over.
            Falls back to ``gpc.get_group(ParallelMode.MODEL)`` when falsy.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 sharded_model: Union[nn.Module, ShardedModelV2],
                 cpu_offload: bool = False,
                 initial_scale: float = 2**32,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: float = 1000,
                 hysteresis: float = 2,
                 max_scale: int = 2**32,
                 dp_process_group: Optional[ProcessGroup] = None,
                 mp_process_group: Optional[ProcessGroup] = None) -> None:
        super().__init__(optimizer)
        self.model: Union[nn.Module, ShardedModelV2] = sharded_model
        self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
        # Device that holds the fp32 master parameters.
        self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
        self.optim_state: OptimState = OptimState.UNSCALED
        self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
        self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)
        # Grad scaler
        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                             min_scale=min_scale,
                                             growth_factor=growth_factor,
                                             backoff_factor=backoff_factor,
                                             growth_interval=growth_interval,
                                             hysteresis=hysteresis,
                                             max_scale=max_scale)
        # The overflow flag is all-reduced across process groups. NCCL only accepts
        # CUDA tensors, so the flag lives on the GPU whenever one is available, even
        # when master params are offloaded to CPU. Gloo accepts CPU and CUDA tensors.
        overflow_device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device('cpu')
        self._found_overflow: Tensor = torch.FloatTensor([0]).to(overflow_device)

        # Store fp32 params, keyed by the parameter object.
        # Note: ``Tensor.to`` returns the tensor itself when no device/dtype change is
        # needed. A param that is already fp32 on ``self.device`` therefore shares
        # storage with its master copy.
        self.master_params: Dict[Parameter, Tensor] = {}
        for group in optimizer.param_groups:
            for p in group['params']:
                if hasattr(p, 'ca_attr'):
                    assert p.ca_attr.is_sharded, 'ShardedOptimizerV2 can be only used with sharded model'
                    master = p.ca_attr.payload(self.device)
                else:
                    master = p.data.to(device=self.device)
                # Only floating-point tensors are promoted; integer/bool tensors are kept as-is.
                if torch.is_floating_point(master) and master.dtype != torch.float:
                    master = master.to(torch.float)
                self.master_params[p] = master

    def step(self, *args, **kwargs):
        """Unscale grads, skip on overflow, otherwise step the wrapped optimizer on master params.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns, or ``None`` when the
            step is skipped because an inf/NaN gradient was found on some rank.
        """
        # unscale grads if scaled
        if self.optim_state == OptimState.SCALED:
            self._unscale_grads()

        found_inf = self._check_overflow()
        # The scaler is informed of the overflow result on every step.
        self.grad_scaler.update(found_inf)
        if found_inf:
            self.zero_grad()
            return

        # Write master param to p.data so the wrapped optimizer updates the masters.
        for group in self.optim.param_groups:
            for p in group['params']:
                p.data = self.master_params[p]

        ret = self.optim.step(*args, **kwargs)

        # Write master param to payload.
        # NOTE(review): params without ``ca_attr`` keep pointing at their master
        # tensor (fp32, on ``self.device``) after this loop — confirm that is intended.
        for group in self.optim.param_groups:
            for p in group['params']:
                if hasattr(p, 'ca_attr'):
                    p.ca_attr.set_payload(p.data)
                    # NOTE(review): __init__ calls payload(self.device) with a device
                    # argument; confirm payload() has a default for this call.
                    p.data = p.ca_attr.payload()
        return ret

    def backward(self, loss: Tensor) -> None:
        """Backpropagate ``loss`` multiplied by the current loss scale and mark grads SCALED."""
        loss = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
        if self.model_is_sharded:
            self.model.backward(loss)
        else:
            super().backward(loss)

    def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
        """Backpropagate from ``tensor`` with an explicit upstream ``grad``.

        NOTE(review): unlike :meth:`backward`, this neither applies the loss scale
        nor sets ``optim_state`` to SCALED. Confirm that callers pass a grad that
        is consistent with the loss scale.
        """
        if self.model_is_sharded:
            self.model.backward_by_grad(tensor, grad)
        else:
            super().backward_by_grad(tensor, grad)

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        """Unscale gradients if needed, then delegate to the base class's clipping.

        Clipping must see true gradient magnitudes, so the loss scale is removed first.
        """
        if self.optim_state == OptimState.SCALED:
            self._unscale_grads()
        return super().clip_grad_norm(model, max_norm)

    @property
    def loss_scale(self):
        """Current loss scale reported by the dynamic grad scaler."""
        return self.grad_scaler.scale

    def _check_overflow(self) -> bool:
        """Return True if any gradient on any rank of the dp/mp groups contains inf or NaN.

        Parameters whose ``grad`` is ``None`` are skipped, matching :meth:`_unscale_grads`.
        The local scan stops at the first overflowing gradient across all param groups.
        """
        local_overflow = any(
            p.grad is not None and has_inf_or_nan(p.grad)
            for group in self.optim.param_groups
            for p in group['params'])
        self._found_overflow.fill_(1.0 if local_overflow else 0.0)
        # MAX-reduce so that a single overflowing rank makes every rank skip the step.
        # all-reduce across dp group
        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self.dp_process_group)
        # all-reduce over model parallel group
        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self.mp_process_group)
        return self._found_overflow.item() > 0

    def _unscale_grads(self):
        """Divide every existing gradient by the loss scale in place and mark grads UNSCALED.

        Must only be called while gradients are scaled.
        """
        assert self.optim_state == OptimState.SCALED
        scale = self.loss_scale
        for group in self.optim.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.grad.data.div_(scale)
        self.optim_state = OptimState.UNSCALED