# ColossalAI/colossalai/elixir/wrapper/optimizer.py

import math
from collections import defaultdict
from enum import Enum
from typing import Dict, Set, Tuple

import torch
import torch.distributed as dist
from torch.nn import Parameter

import colossalai.nn.optimizer as colo_optim
from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler, ConstantGradScaler, DynamicGradScaler
from colossalai.elixir.chunk import Chunk
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.hook.storage import BufferStore
from colossalai.logging import get_dist_logger

from .module import ElixirModule

_AVAIL_OPTIM_LIST = {colo_optim.FusedAdam, colo_optim.CPUAdam, colo_optim.HybridAdam}


class OptimState(Enum):
    """State of the gradients held by the optimizer with respect to the loss scale."""
    SCALED = 0
    UNSCALED = 1


class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
    """A wrapper for optimizers. Note that each ElixirOptimizer instance corresponds
    strictly to one ElixirModule. Currently only a small group of optimizers is supported,
    because ElixirOptimizer relies on element-wise optimizers for now.
    The group of supported optimizers may be enlarged later.

    Args:
        module: The nn.Module instance wrapped as an ElixirModule.
        optimizer: The torch optimizer instance to be wrapped.
    """

    def __init__(self,
                 module: ElixirModule,
                 optimizer: torch.optim.Optimizer,
                 initial_scale: float = 32768,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**24,
                 max_norm: float = 0.0,
                 norm_type: float = 2.0,
                 init_step: bool = False):
        super().__init__(optimizer)
        assert isinstance(module, ElixirModule)

        self.scaled_optimizer = False
        if type(optimizer) in _AVAIL_OPTIM_LIST:
            self.scaled_optimizer = True

        self.module = module
        self.param_chunk_group = module.param_chunk_group
        self.optim_chunk_group = module.optim_chunk_group

        self.optim_state = OptimState.UNSCALED
        self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
        self.param_to_optim_chunk: Dict[Parameter, Chunk] = dict()
        self.param_chunk_set: Set[Chunk] = self.param_chunk_group.fused_chunks.union(
            self.param_chunk_group.float_chunks)

        self.clipping_flag = max_norm > 0.0
        self.max_norm = max_norm
        if self.clipping_flag:
            assert norm_type == 2.0, 'ElixirOptimizer only supports L2 norm now'

        self.__init__optimizer()
        # Grad scaler
        self.grad_scaler: BaseGradScaler = None
        if module.use_amp:
            self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                                 min_scale=min_scale,
                                                 growth_factor=growth_factor,
                                                 backoff_factor=backoff_factor,
                                                 growth_interval=growth_interval,
                                                 hysteresis=hysteresis,
                                                 max_scale=max_scale)
        else:
            self.grad_scaler = ConstantGradScaler(1.0, verbose=False)

        self._comm_buffer: torch.Tensor = torch.zeros(1, dtype=torch.float, device=gpu_device())
        self._logger = get_dist_logger()

        if init_step:
            # allocate memory before training
            self.__zero_step()

    def __zero_step(self):
        """Run a dummy optimizer step with zeroed gradients so that optimizer states
        are allocated before training starts."""
        torch.cuda.empty_cache()
        # mirror the CUDA buffer with a CPU one, so CPU-sharded chunks are covered too
        cpu_buffer = BufferStore(self.module.buffer.buffer_size, self.module.buffer.buffer_dtype, 'cpu')
        buffer_dict = dict(cpu=cpu_buffer, cuda=self.module.buffer)
        for _, zero_buffer in buffer_dict.items():
            zero_buffer.zeros()

        for group in self.param_groups:
            for fake_param in group['params']:
                optim_chunk = self.param_to_optim_chunk[fake_param]
                begin, end = self.param_to_range[fake_param]
                # point the gradient at a zeroed buffer slice, the data at the optimizer shard
                fake_param.data = buffer_dict.get(optim_chunk.shard_device.type).empty_1d(end - begin)
                fake_param.grad = fake_param.data
                fake_param.data = optim_chunk.shard[begin:end]

        self.optim.step()
        self.zero_grad()
        self._update_fp16_params(update_flag=False)

    def _set_grad_ptr(self):
        for group in self.param_groups:
            for fake_param in group['params']:
                optim_chunk = self.param_to_optim_chunk[fake_param]
                begin, end = self.param_to_range[fake_param]
                param_chunk = optim_chunk.paired_chunk
                # after backward, gradients live in the paired parameter chunk's shard;
                # the parameter data itself points at the fp32 optimizer shard
                fake_param.data = param_chunk.shard[begin:end]
                fake_param.grad = fake_param.data
                fake_param.data = optim_chunk.shard[begin:end]

    def _update_fp16_params(self, update_flag: bool = True):
        none_tensor = torch.empty([0])
        for group in self.param_groups:
            for fake_param in group['params']:
                assert fake_param.grad is None
                fake_param.data = none_tensor.to(fake_param.device)

        if update_flag:
            for param_chunk in self.param_chunk_set:
                param_chunk.optim_update()

    def _check_overflow(self) -> bool:
        # calculate the overflow counter
        overflow_counter = 0
        for param_chunk in self.param_chunk_set:
            overflow_counter += int(param_chunk.overflow)
        return overflow_counter > 0

    def _clear_optim_states(self) -> None:
        for param_chunk in self.param_chunk_set:
            param_chunk.overflow = False
            param_chunk.l2_norm = None

    def _calc_global_norm(self) -> float:
        # accumulate the squared partial norms per process group
        group_to_norm = defaultdict(float)
        for param_chunk in self.param_chunk_set:
            assert param_chunk.l2_norm is not None
            assert not param_chunk.is_replica
            group_to_norm[param_chunk.torch_pg] += param_chunk.l2_norm

        # all-reduce each group's partial squared norm, then take the square root
        norm_sqr = 0.0
        for group, part_norm in group_to_norm.items():
            self._comm_buffer.fill_(part_norm)
            dist.all_reduce(self._comm_buffer, group=group)
            norm_sqr += self._comm_buffer.item()

        global_norm = math.sqrt(norm_sqr)
        return global_norm

    def _get_combined_scale(self):
        assert self.optim_state == OptimState.SCALED
        loss_scale = self.loss_scale
        self.optim_state = OptimState.UNSCALED

        combined_scale = loss_scale
        if self.clipping_flag:
            total_norm = self._calc_global_norm()
            clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
            if clip > 1:
                combined_scale = clip * loss_scale

        # -1 means the combined scale is 1, so no gradient division is needed
        if combined_scale == 1:
            return -1
        else:
            return combined_scale
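
    # Worked example for _get_combined_scale (assumed values): with loss_scale = 1024,
    # total_norm = 2048 and max_norm = 1.0, clip = (2048 / 1024 + 1e-6) / 1.0 ≈ 2 > 1,
    # so combined_scale ≈ 2048. Dividing gradients by 2048 both unscales them (by 1024)
    # and clips the global norm down to max_norm (a further factor of 2).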

    @property
    def loss_scale(self):
        return self.grad_scaler.scale.item()

    def zero_grad(self, *args, **kwargs):
        return self.optim.zero_grad(set_to_none=True)

    def step(self, *args, **kwargs):
        self._set_grad_ptr()

        found_inf = self._check_overflow()
        if found_inf:
            self.optim_state = OptimState.UNSCALED    # no need to unscale grad
            self.grad_scaler.update(found_inf)    # update gradient scaler
            self._logger.info('Found overflow. Skip step.')
            self._clear_optim_states()    # clear chunk states used for optimizer update
            self.zero_grad()    # reset all gradients
            self._update_fp16_params()
            return

        # get the combined scale. combined scale = loss scale * clipping norm,
        # so that gradient = gradient / combined scale
        combined_scale = self._get_combined_scale()
        self.grad_scaler.update(found_inf)
        self._clear_optim_states()

        if not self.scaled_optimizer:
            assert combined_scale == -1, 'You should use an optimizer in the available list:\n' \
                f'{_AVAIL_OPTIM_LIST}'
            ret = self.optim.step(*args, **kwargs)
        else:
            ret = self.optim.step(*args, div_scale=combined_scale, **kwargs)

        self.zero_grad()
        self._update_fp16_params()
        return ret
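
    # Typical training-loop order (a sketch; `model`, `criterion`, `inputs` and
    # `targets` are assumptions outside this file; step() zeroes gradients itself):
    #
    #     loss = criterion(model(inputs), targets)
    #     optimizer.backward(loss)    # sets OptimState.SCALED and scales the loss
    #     optimizer.step()            # on overflow, skips the update and shrinks the scale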

    def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
        # gradient clipping is handled internally via the max_norm argument of __init__
        raise NotImplementedError

    def backward(self, loss: torch.Tensor):
        loss = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
        self.module.backward(loss)

    def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
        # This function is called on every pipeline-parallel stage except the last.
        # It receives an already scaled grad from the previous rank, so there is no
        # need to scale it again; it only needs to be unscaled when optimizing.
        self.optim_state = OptimState.SCALED
        self.module.backward_by_grad(tensor, grad)

    def __init__optimizer(self):

        def get_range_pair(local_chunk: Chunk, local_param: Parameter):
            param_info = local_chunk.tensors_info[local_param]
            begin = max(0, param_info.offset - local_chunk.shard_begin)
            end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
            return begin, end

        for group in self.param_groups:
            fake_params_list = list()

            for param in group['params']:
                if not param.requires_grad:
                    continue
                param_chunk = self.module.fetcher.get_one_chunk(param)
                range_pair = get_range_pair(param_chunk, param)
                # skip parameters whose elements do not fall into this rank's shard
                if range_pair[0] >= range_pair[1]:
                    continue

                grad_device = param_chunk.shard.device
                fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
                self.param_to_optim_chunk[fake_param] = param_chunk.paired_chunk
                self.param_to_range[fake_param] = range_pair

                fake_params_list.append(fake_param)

            # replace the real parameters with the fake ones, so the inner optimizer
            # only touches this rank's shard
            group['params'] = fake_params_list