# mirror of https://github.com/hpcaitech/ColossalAI
import math
from collections import defaultdict
from enum import Enum
from typing import Dict, Set, Tuple

import torch
import torch.distributed as dist
from torch.nn import Parameter

import colossalai.nn.optimizer as colo_optim
from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler, ConstantGradScaler, DynamicGradScaler
from colossalai.elixir.chunk import Chunk
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.hook.storage import BufferStore
from colossalai.logging import get_dist_logger

from .module import ElixirModule

# Optimizers whose `step` accepts a `div_scale` argument, so gradient unscaling
# can be fused into the parameter update.
_AVAIL_OPTIM_LIST = {colo_optim.FusedAdam, colo_optim.CPUAdam, colo_optim.HybridAdam}


class OptimState(Enum):
    SCALED = 0
    UNSCALED = 1

class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
    """A wrapper for optimizers. Each ElixirOptimizer instance corresponds strictly to
    one ElixirModule. Currently only a limited group of optimizers is supported, because
    ElixirOptimizer only works with element-wise optimizers for now. The group of
    supported optimizers may be enlarged later. A commented usage sketch can be found at
    the end of this file.

    Args:
        module: The ElixirModule instance that wraps the model.
        optimizer: The torch optimizer instance to be wrapped.
        initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale:
            Arguments forwarded to the dynamic gradient scaler when AMP is enabled.
        max_norm: The maximum gradient norm used for clipping; 0.0 disables clipping.
        norm_type: The norm type used for clipping; only 2.0 (L2 norm) is supported.
        init_step: If True, run a zero-valued optimizer step at construction time to
            allocate optimizer states before training.
    """

    def __init__(self,
                 module: ElixirModule,
                 optimizer: torch.optim.Optimizer,
                 initial_scale: float = 32768,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**24,
                 max_norm: float = 0.0,
                 norm_type: float = 2.0,
                 init_step: bool = False):

        super().__init__(optimizer)
        assert isinstance(module, ElixirModule)
        # only optimizers in _AVAIL_OPTIM_LIST can consume the combined scale directly
        self.scaled_optimizer = False
        if type(optimizer) in _AVAIL_OPTIM_LIST:
            self.scaled_optimizer = True

        self.module = module
        self.param_chunk_group = module.param_chunk_group
        self.optim_chunk_group = module.optim_chunk_group

        self.optim_state = OptimState.UNSCALED
        self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
        self.param_to_optim_chunk: Dict[Parameter, Chunk] = dict()
        self.param_chunk_set: Set[Chunk] = self.param_chunk_group.fused_chunks.union(
            self.param_chunk_group.float_chunks)

        self.clipping_flag = max_norm > 0.0
        self.max_norm = max_norm
        if self.clipping_flag:
            assert norm_type == 2.0, 'ElixirOptimizer only supports L2 norm now'

        self.__init__optimizer()

        # grad scaler: dynamic when AMP is enabled, otherwise a constant scale of 1.0
        self.grad_scaler: BaseGradScaler = None
        if module.use_amp:
            self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                                 min_scale=min_scale,
                                                 growth_factor=growth_factor,
                                                 backoff_factor=backoff_factor,
                                                 growth_interval=growth_interval,
                                                 hysteresis=hysteresis,
                                                 max_scale=max_scale)
        else:
            self.grad_scaler = ConstantGradScaler(1.0, verbose=False)
        self._comm_buffer: torch.Tensor = torch.zeros(1, dtype=torch.float, device=gpu_device())
        self._logger = get_dist_logger()

        if init_step:
            # run a dummy step to allocate optimizer states before training starts
            self.__zero_step()

    def __zero_step(self):
        # Run one optimizer step on zero-filled buffers so that all optimizer
        # states are allocated before real training begins.
        torch.cuda.empty_cache()

        cpu_buffer = BufferStore(self.module.buffer.buffer_size, self.module.buffer.buffer_dtype, 'cpu')
        buffer_dict = dict(cpu=cpu_buffer, cuda=self.module.buffer)
        for _, zero_buffer in buffer_dict.items():
            zero_buffer.zeros()

        for group in self.param_groups:
            for fake_param in group['params']:
                optim_chunk = self.param_to_optim_chunk[fake_param]
                begin, end = self.param_to_range[fake_param]

                # grads are taken from the zero-filled buffer on the same device as the shard
                fake_param.data = buffer_dict.get(optim_chunk.shard_device.type).empty_1d(end - begin)
                fake_param.grad = fake_param.data
                fake_param.data = optim_chunk.shard[begin:end]

        self.optim.step()
        self.zero_grad()
        self._update_fp16_params(update_flag=False)

    def _set_grad_ptr(self):
        # After backward, the param chunk shard holds the gradients and its paired
        # optim chunk shard holds the fp32 master weights. Point each fake param's
        # .grad at the gradient slice and its .data at the master-weight slice.
        for group in self.param_groups:
            for fake_param in group['params']:
                optim_chunk = self.param_to_optim_chunk[fake_param]
                begin, end = self.param_to_range[fake_param]
                param_chunk = optim_chunk.paired_chunk

                fake_param.data = param_chunk.shard[begin:end]
                fake_param.grad = fake_param.data
                fake_param.data = optim_chunk.shard[begin:end]

    def _update_fp16_params(self, update_flag: bool = True):
        none_tensor = torch.empty([0])
        for group in self.param_groups:
            for fake_param in group['params']:
                assert fake_param.grad is None
                # drop the reference to the optim chunk shard
                fake_param.data = none_tensor.to(fake_param.device)

        if update_flag:
            # let each param chunk refresh its data from the paired optim chunk
            for param_chunk in self.param_chunk_set:
                param_chunk.optim_update()

    def _check_overflow(self) -> bool:
        # calculate the overflow counter
        overflow_counter = 0
        for param_chunk in self.param_chunk_set:
            overflow_counter += int(param_chunk.overflow)
        return overflow_counter > 0

    def _clear_optim_states(self) -> None:
        for param_chunk in self.param_chunk_set:
            param_chunk.overflow = False
            param_chunk.l2_norm = None

    def _calc_global_norm(self) -> float:
        # sum each chunk's recorded norm contribution per process group, all-reduce
        # within each group, then take the square root of the total
        group_to_norm = defaultdict(float)
        for param_chunk in self.param_chunk_set:
            assert param_chunk.l2_norm is not None
            assert not param_chunk.is_replica

            group_to_norm[param_chunk.torch_pg] += param_chunk.l2_norm

        norm_sqr = 0.0
        for group, part_norm in group_to_norm.items():
            self._comm_buffer.fill_(part_norm)
            dist.all_reduce(self._comm_buffer, group=group)
            norm_sqr += self._comm_buffer.item()

        global_norm = math.sqrt(norm_sqr)
        return global_norm

    def _get_combined_scale(self):
        loss_scale = 1

        assert self.optim_state == OptimState.SCALED
        loss_scale = self.loss_scale
        self.optim_state = OptimState.UNSCALED

        # combined_scale = loss_scale * clip, so dividing gradients by it both
        # unscales them and clips the global norm to max_norm
        combined_scale = loss_scale
        if self.clipping_flag:
            total_norm = self._calc_global_norm()
            clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
            if clip > 1:
                combined_scale = clip * loss_scale

        if combined_scale == 1:
            # -1 signals that no division is needed
            return -1
        else:
            return combined_scale
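    # Worked example for the combined scale above (illustrative numbers, not from a real run):
    # with loss_scale = 1024, max_norm = 1.0 and a measured total_norm of 8192,
    # clip = (8192 / 1024 + 1e-6) / 1.0 ≈ 8, so combined_scale = 8 * 1024 = 8192.
    # Dividing the raw (scaled) gradients by 8192 removes the loss scale and clips
    # the global gradient norm down to roughly max_norm.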

    @property
    def loss_scale(self):
        return self.grad_scaler.scale.item()

    def zero_grad(self, *args, **kwargs):
        return self.optim.zero_grad(set_to_none=True)

    def step(self, *args, **kwargs):
        self._set_grad_ptr()
        found_inf = self._check_overflow()

        if found_inf:
            self.optim_state = OptimState.UNSCALED    # no need to unscale grad
            self.grad_scaler.update(found_inf)    # update gradient scaler
            self._logger.info('Found overflow. Skip step')
            self._clear_optim_states()    # clear chunk states used for optimizer update
            self.zero_grad()    # reset all gradients
            self._update_fp16_params()
            return

        # get the combined scale: combined scale = loss scale * clipping coefficient,
        # so that gradient = gradient / combined scale
        combined_scale = self._get_combined_scale()
        self.grad_scaler.update(found_inf)
        self._clear_optim_states()

        if not self.scaled_optimizer:
            assert combined_scale == -1, 'You should use an optimizer in the available list:\n' \
                f'{_AVAIL_OPTIM_LIST}'
            ret = self.optim.step(*args, **kwargs)
        else:
            ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)

        self.zero_grad()
        self._update_fp16_params()
        return ret

    def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
        # gradient clipping is handled through the combined scale in `step`
        raise NotImplementedError

    def backward(self, loss: torch.Tensor):
        loss = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
        self.module.backward(loss)

    def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
        # This function is called on every pipeline-parallel stage except the last.
        # It receives an already-scaled gradient from the previous rank, so the
        # gradient is not scaled again here, but it still has to be unscaled
        # when the optimizer steps.
        self.optim_state = OptimState.SCALED
        self.module.backward_by_grad(tensor, grad)

    def __init__optimizer(self):

        def get_range_pair(local_chunk: Chunk, local_param: Parameter):
            # intersect the parameter's range inside the chunk with this rank's shard,
            # expressed in shard-local offsets
            param_info = local_chunk.tensors_info[local_param]
            begin = max(0, param_info.offset - local_chunk.shard_begin)
            end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
            return begin, end
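        # Illustrative example (made-up offsets): a parameter that occupies [96, 160) in a
        # chunk whose local shard covers [128, 256) maps to the shard-local range
        # [max(0, 96 - 128), min(128, 160 - 128)) = [0, 32).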

        # replace each real parameter with a zero-sized "fake" parameter that maps to a
        # slice of the optim chunk shard, so the wrapped optimizer only allocates states
        # for the shard owned by this rank
        for group in self.param_groups:
            fake_params_list = list()

            for param in group['params']:
                if not param.requires_grad:
                    continue

                param_chunk = self.module.fetcher.get_one_chunk(param)
                range_pair = get_range_pair(param_chunk, param)
                if range_pair[0] >= range_pair[1]:
                    # this parameter does not overlap the local shard
                    continue

                grad_device = param_chunk.shard.device
                fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
                self.param_to_optim_chunk[fake_param] = param_chunk.paired_chunk
                self.param_to_range[fake_param] = range_pair

                fake_params_list.append(fake_param)

            group['params'] = fake_params_list
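
# Minimal usage sketch (illustrative only; the ElixirModule construction is omitted, and the
# names `raw_model` and `dataloader` are assumptions, not part of this file):
#
#     optimizer = colo_optim.HybridAdam(raw_model.parameters(), lr=1e-3)
#     model = ElixirModule(raw_model, ...)    # see module.py for the full signature
#     optimizer = ElixirOptimizer(model, optimizer, init_step=True)
#
#     for data, label in dataloader:
#         loss = model(data, label)
#         optimizer.backward(loss)    # scales the loss when AMP is enabled
#         optimizer.step()            # unscale, clip, update, then refresh the fp16 params;
#                                     # gradients are zeroed inside step()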