ColossalAI/colossalai/amp/naive_amp/_fp16_optimizer.py

357 lines
13 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.distributed as dist
# The fused CUDA kernels are optional at import time: when the extension has
# not been built we only warn here, and code that needs the kernels fails (or
# falls back) later. Only import failures are swallowed so that unrelated
# errors such as KeyboardInterrupt/SystemExit still propagate.
try:
    import colossalai._C.fused_optim
except ImportError:
    print('Colossalai should be built with cuda extension to use the FP16 optimizer')
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier
from ._utils import has_inf_or_nan, zero_gard_by_list
from .grad_scaler import BaseGradScaler
__all__ = ['FP16Optimizer']
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
    """
    adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM)

    Copy every tensor in ``this`` into the tensor at the same position in
    ``that``.

    When ``overflow_buf`` is given and the fused CUDA extension has been
    imported, the copy is issued through ``multi_tensor_applier`` (scaling by
    ``1.0`` is equivalent to a copy). We don't have a bfloat16 implementation,
    so when ``overflow_buf`` is not provided we fall back to a simple loop copy
    to stay compatible with bfloat16. The same fallback is used when the fused
    extension is not available.

    Args:
        this (list): source tensors.
        that (list): destination tensors, written in place.
        overflow_buf (torch.Tensor, optional): buffer handed to the fused
            kernel; it is reset to 0 before the kernel is applied.
    """
    import sys

    # Nothing to copy; also avoids handing empty tensor lists to the fused applier.
    if not this:
        return

    # `overflow_buf` is a tensor, so its presence must be tested with
    # `is not None`: the truth value of a buffer holding 0 is False, which
    # would silently disable the fused path for a freshly zeroed buffer.
    fused_optim = None
    if overflow_buf is not None:
        # Present only if `import colossalai._C.fused_optim` succeeded earlier.
        fused_optim = sys.modules.get('colossalai._C.fused_optim')

    if fused_optim is not None:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
        return

    for this_, that_ in zip(this, that):
        that_.copy_(this_)
class FP16Optimizer(Optimizer):
    """Mixed-precision wrapper that keeps fp32 master copies of fp16 parameters.

    Every trainable ``torch.cuda.HalfTensor`` parameter found in the wrapped
    optimizer's param groups is replaced there by a detached fp32 clone, so the
    wrapped optimizer only ever updates fp32 tensors. Trainable
    ``torch.cuda.FloatTensor`` parameters are left in place. Any other tensor
    type on a trainable parameter raises ``TypeError``.

    Args:
        optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD
        grad_scaler (BaseGradScaler): grad scaler for gradient chose in
            ``constant_grad_scaler`` or ``dynamic_grad_scaler``.
        verbose (bool, optional): if set to `True`, will print debug info. Default False.
        clip_grad_norm (float, optional): clip gradients with this global L2 norm. Default 0.
            Note that clipping is skipped unless ``clip_grad_norm > 0``.
        dp_process_group (ProcessGroup, optional): group over which the overflow flag is
            all-reduced first. When ``None``, the ``ParallelMode.DATA`` group of the global
            context is used if that mode is initialized.
        mp_process_group (ProcessGroup, optional): group over which the overflow flag is
            all-reduced second. When ``None``, the ``ParallelMode.MODEL`` group of the global
            context is used if that mode is initialized.

    Raises:
        TypeError: if a trainable parameter is neither ``torch.cuda.HalfTensor`` nor
            ``torch.cuda.FloatTensor``.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 grad_scaler: BaseGradScaler,
                 verbose: bool = False,
                 clip_grad_norm=0,
                 dp_process_group: ProcessGroup = None,
                 mp_process_group: ProcessGroup = None):
        # have a defaults for compatibility with pytorch optim
        self._optimizer = optimizer
        self._defaults = optimizer.defaults

        # fp16-related params
        assert isinstance(grad_scaler, BaseGradScaler)
        self._grad_scaler = grad_scaler
        self._found_overflow = torch.cuda.FloatTensor([0.0])
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # misc params
        self._clip_grad_max_norm = clip_grad_norm

        # get process group
        def _get_process_group(parallel_mode):
            # Look up the group of the *requested* mode. Querying
            # ParallelMode.DATA unconditionally would make the model-parallel
            # group an alias of the data-parallel one.
            if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode):
                return gpc.get_group(parallel_mode)
            return None

        if dp_process_group is None:
            dp_process_group = _get_process_group(ParallelMode.DATA)
        if mp_process_group is None:
            mp_process_group = _get_process_group(ParallelMode.MODEL)

        self._dp_process_group = dp_process_group
        self._mp_process_group = mp_process_group

        # we maintain three groups of parameters
        # so that the model can have a mixture
        # of fp16 and fp32 params
        # fp16_param_groups: the fp16 params of the model
        # fp32_master_param_groups: the fp32 params cast from the fp16 param of the model
        # fp32_param_groups: the fp32 params of the model
        # NOTE:
        # 1. fp16_param_groups and fp32_master_param_groups have one-to-one correspondence
        # 2. fp32_param_groups and fp16_param_groups are exclusive of each other
        self._fp16_param_groups = []
        self._fp32_master_param_groups = []
        self._fp32_param_groups = []

        # For all the groups in the original optimizer:
        for param_group in self._optimizer.param_groups:
            fp16_params = []
            fp32_master_params = []
            fp32_params = []

            # For all the parameters in this group:
            for i, param in enumerate(param_group['params']):
                # Frozen parameters are left untouched in the optimizer group.
                if not param.requires_grad:
                    continue

                param_type = param.type()
                if param_type == 'torch.cuda.HalfTensor':
                    # float16 params:
                    fp16_params.append(param)

                    # Create a fp32 copy
                    fp32_param = param.detach().clone().float()
                    # Copy tensor model parallel attributes.
                    copy_tensor_parallel_attributes(param, fp32_param)

                    # Replace the optimizer params with the new fp32 copy.
                    param_group['params'][i] = fp32_param
                    fp32_master_params.append(fp32_param)

                    # Reset existing state dict key to the new main param.
                    if param in self._optimizer.state:
                        self._optimizer.state[fp32_param] = self._optimizer.state.pop(param)
                elif param_type == 'torch.cuda.FloatTensor':
                    # fp32 params.
                    fp32_params.append(param)
                else:
                    raise TypeError('Expected parameter of type torch.cuda.FloatTensor '
                                    f'or torch.cuda.HalfTensor, but got {param.type()}')

            self._fp16_param_groups.append(fp16_params)
            self._fp32_master_param_groups.append(fp32_master_params)
            self._fp32_param_groups.append(fp32_params)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        self._optimizer.load_state_dict(self._optimizer.state_dict())

        # log config
        self._logger = get_dist_logger()
        if verbose:
            self._logger.info(
                f"\n========= FP16 Optimizer Config =========\n"
                f"Optimizer: {optimizer.__class__.__name__}\n"
                f"clip_grad_norm = {clip_grad_norm}\n"
                f"grad_scaler = {self._grad_scaler.__class__.__name__}\n"
                f"==========================================",
                ranks=[0])

    @property
    def grad_scaler(self):
        """Returns the gradient scaler.

        Returns:
            :class:`BaseGradScaler`: gradient scaler.
        """
        return self._grad_scaler

    @property
    def loss_scale(self):
        """Returns the loss scale.

        Returns:
            The ``scale`` attribute of the gradient scaler.
        """
        return self._grad_scaler.scale

    @property
    def optimizer(self):
        """Returns the optimizer.

        Returns:
            :class:`torch.optim.Optimizer`: the optimizer object wrapped.
        """
        return self._optimizer

    @property
    def defaults(self):
        """Returns the default arguments of optimizer.

        Returns:
            dict: optimizer arguments saved in defaults of the optimizer wrapped.
        """
        return self._defaults

    def _check_overflow(self):
        """Return ``True`` if any gradient held by the wrapped optimizer has inf/nan.

        The local flag is max-reduced over the data-parallel group and then the
        model-parallel group (when set) before being read back.
        """
        # clear previous overflow record
        self._found_overflow.fill_(0.0)

        # check for overflow; `any` stops at the first offending gradient
        # instead of scanning the remaining param groups.
        grads = (p.grad for group in self._optimizer.param_groups for p in group['params'] if p.grad is not None)
        if any(has_inf_or_nan(grad) for grad in grads):
            self._found_overflow.fill_(1.0)

        # all-reduce across dp group
        if self._dp_process_group:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_process_group)

        # all-reduce over model parallel group
        if self._mp_process_group:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_process_group)

        return self._found_overflow.item() > 0

    def zero_grad(self, set_to_none=True):
        """Set gradient to zero.

        Only the parameters currently held in the wrapped optimizer's param
        groups are visited.

        Args:
            set_to_none (bool): Whether set the gradient to None.
        """
        # set_to_none = True can save some memory space
        for param_group in self._optimizer.param_groups:
            zero_gard_by_list(param_group['params'], set_to_none=set_to_none)

    def _get_fp32_param_groups_to_update(self):
        """Return the fp32 master groups followed by the native fp32 groups."""
        return self._fp32_master_param_groups + self._fp32_param_groups

    def _unscale_grads(self):
        """Divide every existing fp32 gradient in place by the current loss scale."""
        for group in self._get_fp32_param_groups_to_update():
            for p in group:
                if p.grad is not None:
                    p.grad.data.div_(self.loss_scale)

    def _assign_grad_to_fp32_master_param(self):
        """Move fp16 gradients onto their fp32 master parameters (cast to fp32)."""
        # This only needs to be done for the float16 group.
        for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
            for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group):
                if fp16_param.grad is not None:
                    fp32_param.grad = fp16_param.grad.float()
                    # clear unneeded grad on fp16 param
                    fp16_param.grad = None

    def _update_fp16_param_from_fp32_param(self):
        """Copy the fp32 master values back into the fp16 model parameters."""
        fp16_param_data = []
        fp32_master_param_data = []
        for fp16_group, fp32_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
            for fp16_param, fp32_param in zip(fp16_group, fp32_group):
                fp16_param_data.append(fp16_param.data)
                fp32_master_param_data.append(fp32_param.data)
        _multi_tensor_copy_this_to_that(this=fp32_master_param_data,
                                        that=fp16_param_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def step(self):
        """Update the model parameters.

        Returns:
            tuple: ``(True, grad_norm)`` after a successful update, where ``grad_norm`` is
            the value returned by :meth:`clip_grad_norm` (``None`` when clipping is
            disabled); ``(False, None)`` when an overflow was detected and the update
            was skipped.
        """
        # Copy gradients from model params to main params.
        self._assign_grad_to_fp32_master_param()
        self._unscale_grads()

        overflow = self._check_overflow()
        self._grad_scaler.update(overflow)
        if overflow:
            self.zero_grad()

        # Clip the main gradients.
        # NOTE: this is still invoked after an overflow (on the cleared grads);
        # its result is discarded in that case.
        grad_norm = None
        if self._clip_grad_max_norm > 0.0:
            grad_norm = self.clip_grad_norm(self._clip_grad_max_norm)

        if overflow:
            return False, None

        # Step the optimizer.
        self._optimizer.step()
        # Update params from main params.
        self._update_fp16_param_from_fp32_param()
        # Successful update.
        return True, grad_norm

    def backward(self, loss):
        """Execute backward pass.

        The loss is multiplied by the grad scaler's current scale before
        ``backward()`` is called on it.

        Args:
            loss (:class:`torch.Tensor`): the loss value.
        """
        scaled_loss = loss * self.grad_scaler.scale
        scaled_loss.backward()

    def state_dict(self):
        """Returns the states of the fp16 optimizer as a dict object.

        Returns:
            dict: with keys ``'optimizer'``, ``'grad_scaler'`` (when a grad scaler is set)
            and ``'fp32_master_param_groups'`` (the live list of fp32 master parameters,
            not a copy).
        """
        state_dict = {}
        state_dict['optimizer'] = self._optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """Load the states of the fp16 optimizer from a dict object.

        The ``'grad_scaler'`` and ``'fp32_master_param_groups'`` entries are
        optional and skipped when absent.

        Args:
            state_dict (dict): the states of the fp16 optimizer
        """
        # Optimizer.
        self._optimizer.load_state_dict(state_dict['optimizer'])

        # Grad scaler.
        if 'grad_scaler' in state_dict:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])

        # Copy data for the main params.
        if 'fp32_master_param_groups' in state_dict:
            for current_group, ckpt_group in zip(self._fp32_master_param_groups,
                                                 state_dict['fp32_master_param_groups']):
                for current_param, ckpt_param in zip(current_group, ckpt_group):
                    current_param.data.copy_(ckpt_param.data)

    def clip_grad_norm(self, clip_grad):
        """Clip gradients by norm.

        Args:
            clip_grad (float): the max norm for clipping

        Returns:
            The value returned by ``clip_grad_norm_fp32`` for all parameters of the
            wrapped optimizer.
        """
        params = [param for param_group in self._optimizer.param_groups for param in param_group['params']]
        return clip_grad_norm_fp32(params, clip_grad)

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self._optimizer.state

    def _set_state(self, value):
        self._optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self._optimizer.param_groups

    def _set_param_groups(self, value):
        self._optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)