[fp16] refactored fp16 optimizer (#392)

pull/417/head
Frank Lee 2022-03-15 10:05:38 +08:00 committed by GitHub
parent f8a0e7fb01
commit e79ea44247
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 371 additions and 345 deletions


@@ -1,13 +1,12 @@
+import inspect
 import torch.nn as nn
 from torch.optim import Optimizer
 from colossalai.utils import is_no_pp_or_last_stage
 from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
+from .grad_scaler import DynamicGradScaler, ConstantGradScaler


-def convert_to_naive_amp(model: nn.Module,
-                         optimizer: Optimizer,
-                         amp_config):
+def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
     """A helper function to wrap training components with naive AMP modules

     :param model: your model object
@@ -31,7 +30,19 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
     output_to_fp32 = is_no_pp_or_last_stage()
     model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)

-    optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
+    use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
+    if use_dynamic_grad_scaler:
+        scaler_class = DynamicGradScaler
+    else:
+        scaler_class = ConstantGradScaler
+
+    sig = inspect.signature(scaler_class.__init__)
+    kwargs = dict()
+    for param in sig.parameters.values():
+        if param.name in amp_config:
+            kwargs[param.name] = amp_config.pop(param.name)
+    grad_scaler = scaler_class(**kwargs)
+
+    optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
     return model, optimizer
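
With this change, `convert_to_naive_amp` builds the gradient scaler itself: `dynamic_grad_scale` picks the scaler class, keys matching the scaler's signature are popped off `amp_config`, and whatever remains (e.g. `clip_grad_norm`) is forwarded to `NaiveAMPOptimizer`. A minimal usage sketch, not taken from this PR; the model and optimizer are illustrative and the call assumes `colossalai.launch` has already initialized the global context:

import torch.nn as nn
from torch.optim import Adam
from colossalai.amp import convert_to_naive_amp

# illustrative model and optimizer; any nn.Module and torch optimizer would do
model = nn.Linear(16, 16).cuda()
optimizer = Adam(model.parameters(), lr=1e-3)

amp_config = dict(
    dynamic_grad_scale=True,    # False would select ConstantGradScaler instead
    initial_scale=2**16,        # matches DynamicGradScaler.__init__ and is popped for the scaler
    growth_interval=1000,       # likewise routed to the grad scaler
    clip_grad_norm=1.0,         # stays in amp_config and is forwarded to NaiveAMPOptimizer
)
model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)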


@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-

 import torch
+import torch.distributed as dist

 try:
     import colossal_C
@@ -9,41 +10,30 @@ except:
     print('Colossalai should be built with cuda extension to use the FP16 optimizer')

 from torch.optim import Optimizer
-from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
 from colossalai.logging import get_dist_logger
-from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
-                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)
+from colossalai.utils import (copy_tensor_parallel_attributes, clip_grad_norm_fp32, multi_tensor_applier)
+from torch.distributed import ProcessGroup
+from .grad_scaler import BaseGradScaler
+from ._utils import has_inf_or_nan, zero_gard_by_list

+__all__ = ['FP16Optimizer']

-def _zero_grad_group_helper(group, set_to_none):
-    """Zero out the gradient for a group of parameters.
-    Note: copied from torch.optim.optimizer."""
-    for param in group:
-        if param.grad is not None:
-            if set_to_none:
-                param.grad = None
-            else:
-                if param.grad.grad_fn is not None:
-                    param.grad.detach_()
-                else:
-                    param.grad.requires_grad_(False)
-                param.grad.zero_()


 def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
-    """Use multi-tensor-applier to copy values from one list to another.
+    """
+    adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM)
+    Use multi-tensor-applier to copy values from one list to another.
     We don't have a blfoat16 implementation so for now if the overflow_buf
     is not provided, we default back to simple loop copy to be compatible
-    with bfloat16."""
+    with bfloat16.
+    """
     if overflow_buf:
         overflow_buf.fill_(0)
         # Scaling with factor `1.0` is equivalent to copy.
-        multi_tensor_applier(colossal_C.multi_tensor_scale,
-                             overflow_buf,
-                             [this, that],
-                             1.0)
+        multi_tensor_applier(colossal_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
     else:
         for this_, that_ in zip(this, that):
             that_.copy_(this_)
@@ -111,8 +101,7 @@ class DynamicGradScaler:
             self._hysteresis_tracker -= 1
             # Now if we are out of hysteresis count, scale down the loss.
             if self._hysteresis_tracker <= 0:
-                self._scale = torch.max(self._scale * self.backoff_factor,
-                                        self.min_scale)
+                self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale)
                 if self.verbose:
                     self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
         else:
@@ -127,12 +116,13 @@ class DynamicGradScaler:
             if self._max_scale is not None and self._scale >= self._max_scale:
                 if self.verbose:
                     self._logger.info(
-                        f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0])
+                        f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed',
+                        ranks=[0])
             else:
                 self._scale = self._scale * self.growth_factor
                 if self.verbose:
-                    self._logger.info(
-                        f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])
+                    self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}',
+                                      ranks=[0])

     def state_dict(self):
         state_dict = {}
@@ -173,326 +163,241 @@ class FP16Optimizer(Optimizer):
     """

     def __init__(self,
-                 optimizer,
-                 clip_grad=0,
-                 log_num_zeros_in_grad=False,
-                 initial_scale=2 ** 32,
-                 min_scale=1,
-                 growth_factor=2,
-                 backoff_factor=0.5,
-                 growth_interval=1000,
-                 hysteresis=2,
-                 max_scale: int = 2 ** 32,
-                 verbose: bool = False):
-        # default args for compatibility
-        bf16 = False
-        params_have_main_grad = False
-
+                 optimizer: Optimizer,
+                 grad_scaler: BaseGradScaler,
+                 verbose: bool = False,
+                 clip_grad_norm=0,
+                 dp_process_group: ProcessGroup = None,
+                 mp_process_group: ProcessGroup = None):
         # have a defaults for compatibility with pytorch optim
-        self.defaults = optimizer.defaults
+        self._optimizer = optimizer
+        self._defaults = optimizer.defaults

-        # log config
-        self._logger = get_dist_logger()
-        if verbose:
-            self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
-                              f"Optimizer: {optimizer.__class__.__name__}\n"
-                              f"clip_grad = {clip_grad}\n"
-                              f"log_num_zeros_in_grad = {log_num_zeros_in_grad}\n"
-                              f"initial_scale = {initial_scale}\n"
-                              f"min_scale = {min_scale}\n"
-                              f"growth_factor = {growth_factor}\n"
-                              f"backoff_factor = {backoff_factor}\n"
-                              f"growth_interval = {growth_interval}\n"
-                              f"hysteresis = {hysteresis}\n"
-                              f"==========================================", ranks=[0])
-
-        """Input optimizer is the base optimizer for example Adam."""
-        self.optimizer = optimizer
-        assert self.optimizer, 'no optimizer is provided.'
-
-        # Set gradient clipping and logging params.
-        self.clip_grad = clip_grad
-        self.log_num_zeros_in_grad = log_num_zeros_in_grad
-        self.params_have_main_grad = params_have_main_grad
-        self.bf16 = bf16
-        self.grad_scaler = DynamicGradScaler(
-            initial_scale=initial_scale,
-            min_scale=min_scale,
-            growth_factor=growth_factor,
-            backoff_factor=backoff_factor,
-            growth_interval=growth_interval,
-            hysteresis=hysteresis,
-            max_scale=max_scale,
-            verbose=verbose
-        )
-
-        # None grad scaler is only supported for bf16.
-        if self.grad_scaler is None:
-            assert self.bf16, 'fp16 expects a grad scaler.'
-
-        # Tensor used to determine if a nan/if has happend.
-        # Any non-zero value indicates inf/nan.
-        # Note that we keep this for the cases that grad scaler is none.
-        # We still record nan/inf if we have a bfloat16 with a grad scaler.
-        if self.grad_scaler:
-            self.found_inf = torch.cuda.FloatTensor([0.0])
-
-        # Dummy tensor needed for apex multi-apply tensor.
-        # For bfloat, we don't have multi-tensor apply and for now
-        # we set it to none so the multi-tensor apply gets ignored.
-        if bf16:
-            self._dummy_overflow_buf = None
-        else:
-            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+        # fp16-related params
+        assert isinstance(grad_scaler, BaseGradScaler)
+        self._grad_scaler = grad_scaler
+        self._found_overflow = torch.cuda.FloatTensor([0.0])
+        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

-        # In case grad scaler is not passed, define the unity scale.
-        if self.grad_scaler is None:
-            self._scale_one = torch.cuda.FloatTensor([1.0])
+        # misc params
+        self._clip_grad_max_norm = clip_grad_norm

-        # ======================
-        # main parameter stuff
-        # ======================
+        # get process group
+        def _get_process_group(parallel_mode):
+            if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode):
+                return gpc.get_group(parallel_mode)
+            else:
+                return None

-        # Three groups of parameters:
-        #   float16_groups: original float16 parameters
-        #   fp32_from_float16_groups: fp32 copy of float16 parameters
-        #   fp32_from_fp32_groups: original fp32 parameters
-        self.float16_groups = []
-        self.fp32_from_float16_groups = []
-        self.fp32_from_fp32_groups = []
+        if dp_process_group is None:
+            dp_process_group = _get_process_group(ParallelMode.DATA)
+        if mp_process_group is None:
+            mp_process_group = _get_process_group(ParallelMode.MODEL)
+
+        self._dp_process_group = dp_process_group
+        self._mp_process_group = mp_process_group
+
+        # we maintain three groups of parameters
+        # so that the model can have a mixture
+        # of fp16 and fp32 params
+        # fp16_param_groups: the fp16 params of the model
+        # fp32_master_param_groups: the fp32 params cast from the fp16 param of the model
+        # fp32_param_groups: the fp32 params of the model
+        # NOTE:
+        # 1. fp16_param_groups and fp32_master_param_groups have one-to-one correspondence
+        # 2. fp32_param_groups and fp16_param_groups are exclusive of each other
+        self._fp16_param_groups = []
+        self._fp32_master_param_groups = []
+        self._fp32_param_groups = []

         # For all the groups in the original optimizer:
-        for param_group in self.optimizer.param_groups:
-            float16_params_this_group = []
-            fp32_params_this_group = []
-            fp32_from_float16_params_this_group = []
+        for param_group in self._optimizer.param_groups:
+            fp16_params = []
+            fp32_master_params = []
+            fp32_params = []
             # For all the parameters in this group:
             for i, param in enumerate(param_group['params']):
                 if param.requires_grad:
                     # float16 params:
-                    if param.type() in ['torch.cuda.HalfTensor',
-                                        'torch.cuda.BFloat16Tensor']:
-                        float16_params_this_group.append(param)
-                        # Create a copy
-                        main_param = param.detach().clone().float()
-                        # Copy tensor model parallel attributes.
-                        copy_tensor_parallel_attributes(param, main_param)
-                        # if hasattr(param, 'shared'):
-                        #     main_param.shared = param.shared
+                    if param.type() in ['torch.cuda.HalfTensor']:
+                        fp16_params.append(param)
+                        # Create a fp32 copy
+                        fp32_param = param.detach().clone().float()
+                        # Copy tensor model parallel attributes.
+                        copy_tensor_parallel_attributes(param, fp32_param)
                         # Replace the optimizer params with the new fp32 copy.
-                        param_group['params'][i] = main_param
-                        fp32_from_float16_params_this_group.append(main_param)
+                        param_group['params'][i] = fp32_param
+                        fp32_master_params.append(fp32_param)
                         # Reset existing state dict key to the new main param.
-                        if param in self.optimizer.state:
-                            self.optimizer.state[main_param] \
-                                = self.optimizer.state.pop(param)
+                        if param in self._optimizer.state:
+                            self._optimizer.state[fp32_param] = self._optimizer.state.pop(param)
                     # fp32 params.
                     elif param.type() == 'torch.cuda.FloatTensor':
-                        fp32_params_this_group.append(param)
-                        param_group['params'][i] = param
+                        fp32_params.append(param)
                     else:
-                        raise TypeError('Wrapped parameters must be one of '
-                                        'torch.cuda.FloatTensor, '
-                                        'torch.cuda.HalfTensor, or '
-                                        'torch.cuda.BFloat16Tensor. '
-                                        'Received {}'.format(param.type()))
+                        raise TypeError('Expected parameter of type torch.cuda.FloatTensor '
+                                        f'or torch.cuda.HalfTensor, but got {param.type()}')

-            self.float16_groups.append(float16_params_this_group)
-            self.fp32_from_float16_groups.append(
-                fp32_from_float16_params_this_group)
-            self.fp32_from_fp32_groups.append(fp32_params_this_group)
+            self._fp16_param_groups.append(fp16_params)
+            self._fp32_master_param_groups.append(fp32_master_params)
+            self._fp32_param_groups.append(fp32_params)

         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
-        self.optimizer.load_state_dict(self.optimizer.state_dict())
+        self._optimizer.load_state_dict(self._optimizer.state_dict())

-    def zero_grad(self, set_to_none=False):
-        """We only need to zero the model related parameters, i.e.,
-        float16_groups & fp32_from_fp32_groups."""
-        for group in self.float16_groups:
-            _zero_grad_group_helper(group, set_to_none)
-        for group in self.fp32_from_fp32_groups:
-            _zero_grad_group_helper(group, set_to_none)
-
-    def get_loss_scale(self):
-        if self.grad_scaler is None:
-            return self._scale_one
-        return self.grad_scaler.scale
-
-    def _copy_model_grads_to_main_grads(self):
+        # log config
+        self._logger = get_dist_logger()
+        if verbose:
+            self._logger.info(
+                f"\n========= FP16 Optimizer Config =========\n"
+                f"Optimizer: {optimizer.__class__.__name__}\n"
+                f"clip_grad_norm = {clip_grad_norm}\n"
+                f"grad_scaler = {self._grad_scaler.__class__.__name__}"
+                f"==========================================",
+                ranks=[0])
+
+    @property
+    def grad_scaler(self):
+        return self._grad_scaler
+
+    @property
+    def loss_scale(self):
+        return self._grad_scaler.scale
+
+    @property
+    def optimizer(self):
+        return self._optimizer
+
+    @property
+    def defaults(self):
+        return self._defaults
+
+    def _check_overflow(self):
+        # clear previous overflow record
+        self._found_overflow.fill_(0.0)
+
+        # check for overflow
+        for group in self._optimizer.param_groups:
+            for p in group['params']:
+                if has_inf_or_nan(p.grad):
+                    self._found_overflow.fill_(1.0)
+                    break
+
+        # all-reduce across dp group
+        if self._dp_process_group:
+            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_process_group)
+
+        # all-reduce over model parallel group
+        if self._mp_process_group:
+            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_process_group)
+
+        return self._found_overflow.item() > 0
+
+    def zero_grad(self, set_to_none=True):
+        # set_to_none = True can save some memory space
+        for param_group in self._optimizer.param_groups:
+            zero_gard_by_list(param_group['params'], set_to_none=set_to_none)
+
+    def _get_fp32_param_groups_to_update(self):
+        return self._fp32_master_param_groups + self._fp32_param_groups
+
+    def _unscale_grads(self):
+        for group in self._get_fp32_param_groups_to_update():
+            for p in group:
+                if p.grad is not None:
+                    p.grad.data.div_(self.loss_scale)
+
+    def _assign_grad_to_fp32_master_param(self):
         # This only needs to be done for the float16 group.
-        for model_group, main_group in zip(self.float16_groups,
-                                           self.fp32_from_float16_groups):
-            for model_param, main_param in zip(model_group, main_group):
-                if self.params_have_main_grad:
-                    main_param.grad = model_param.main_grad.float()
-                else:
-                    if model_param.grad is not None:
-                        main_param.grad = model_param.grad.float()
-
-        # For fp32 grads, we need to reset the grads to main grad.
-        if self.params_have_main_grad:
-            for model_group in self.fp32_from_fp32_groups:
-                for model_param in model_group:
-                    model_param.grad = model_param.main_grad
-
-    def _unscale_main_grads_and_check_for_nan(self):
-        main_grads = []
-        # fp32 params fromm float16 ones.
-        for main_group in self.fp32_from_float16_groups:
-            for main_param in main_group:
-                if main_param.grad is not None:
-                    main_grads.append(main_param.grad.data)
-        # Append fp32 parameters.
-        for main_group in self.fp32_from_fp32_groups:
-            for main_param in main_group:
-                if main_param.grad is not None:
-                    main_grads.append(main_param.grad.data)
-        # Reset found inf.
-        self.found_inf.fill_(0.0)
-        # Unscale and set found inf/nan
-        torch._amp_foreach_non_finite_check_and_unscale_(
-            main_grads, self.found_inf, self.grad_scaler.inv_scale)
-        # Update across all model parallel instances.
-        torch.distributed.all_reduce(self.found_inf,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=gpc.get_group(ParallelMode.MODEL))
-        # Check for nan.
-        found_inf_flag = (self.found_inf.item() > 0)
-        return found_inf_flag
-
-    def _get_model_and_main_params_data_float16(self):
-        model_data = []
-        main_data = []
-        for model_group, main_group in zip(self.float16_groups,
-                                           self.fp32_from_float16_groups):
-            for model_param, main_param in zip(model_group, main_group):
-                model_data.append(model_param.data)
-                main_data.append(main_param.data)
-        return model_data, main_data
-
-    def _copy_main_params_to_model_params(self):
-        # Only needed for the float16 params.
-        model_data, main_data = self._get_model_and_main_params_data_float16()
-        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
-                                        overflow_buf=self._dummy_overflow_buf)
-
-    def _copy_model_params_to_main_params(self):
-        # Only needed for the float16 params.
-        model_data, main_data = self._get_model_and_main_params_data_float16()
-        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
-                                        overflow_buf=self._dummy_overflow_buf)
-
-    def reload_model_params(self):
-        self._copy_model_params_to_main_params()
-
-    @torch.no_grad()
+        for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
+            for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group):
+                fp32_param.grad = fp16_param.grad.float()
+                # clear unneeded grad on fp16 param
+                fp16_param.grad = None
+
+    def _update_fp16_param_from_fp32_param(self):
+        fp16_param_data = []
+        fp32_master_param_data = []
+        for fp16_group, fp32_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
+            for fp16_param, fp32_param in zip(fp16_group, fp32_group):
+                fp16_param_data.append(fp16_param.data)
+                fp32_master_param_data.append(fp32_param.data)
+        _multi_tensor_copy_this_to_that(this=fp32_master_param_data,
+                                        that=fp16_param_data,
+                                        overflow_buf=self._dummy_overflow_buf)

     def step(self):
         # Copy gradients from model params to main params.
-        self._copy_model_grads_to_main_grads()
+        self._assign_grad_to_fp32_master_param()
+        self._unscale_grads()

-        # Do unscale, check for inf, and update grad scaler only for
-        # the case that grad scaler is provided.
-        if self.grad_scaler:
-            # Unscale and check for inf/nan.
-            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
-
-            # We are done with scaling gradients
-            # so we can update the loss scale.
-            self.grad_scaler.update(found_inf_flag)
-
-            # If we found inf/nan, skip the update.
-            if found_inf_flag:
-                return False, None, None
+        overflow = self._check_overflow()
+        self._grad_scaler.update(overflow)
+
+        if overflow:
+            self.zero_grad()
+            return False, None

         # Clip the main gradients.
         grad_norm = None
-        if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad)
-
-        # count the zeros in the grads
-        num_zeros_in_grad = self.count_zeros() if \
-            self.log_num_zeros_in_grad else None
+        if self._clip_grad_max_norm > 0.0:
+            grad_norm = self.clip_grad_norm(self._clip_grad_max_norm)

         # Step the optimizer.
-        self.optimizer.step()
+        self._optimizer.step()

         # Update params from main params.
-        self._copy_main_params_to_model_params()
+        self._update_fp16_param_from_fp32_param()

         # Successful update.
-        return True, grad_norm, num_zeros_in_grad
+        return True, grad_norm
+
+    def backward(self, loss):
+        scaled_loss = loss * self.grad_scaler.scale
+        scaled_loss.backward()
     def state_dict(self):
         state_dict = {}
-        state_dict['optimizer'] = self.optimizer.state_dict()
+        state_dict['optimizer'] = self._optimizer.state_dict()
         if self.grad_scaler:
             state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
+        state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups
         return state_dict

     def load_state_dict(self, state_dict):
         # Optimizer.
-        optimizer_key = 'optimizer'
-        if optimizer_key not in state_dict:
-            optimizer_key = 'optimizer_state_dict'
-            print_rank_0('***WARNING*** loading optimizer from '
-                         'an old checkpoint ...')
-        self.optimizer.load_state_dict(state_dict[optimizer_key])
+        self._optimizer.load_state_dict(state_dict['optimizer'])

         # Grad scaler.
-        if 'grad_scaler' not in state_dict:
-            print_rank_0('***WARNING*** found an old checkpoint, will not '
-                         'load grad scaler ...')
-        else:
-            if self.grad_scaler:
-                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
-            else:
-                print_rank_0('***WARNING*** fould the grad scaler in the '
-                             'checkpoint but it is None in the class. '
-                             'Skipping loading grad scaler ...')
+        if 'grad_scaler' in state_dict:
+            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])

         # Copy data for the main params.
-        fp32_from_float16_params_key = 'fp32_from_fp16_params'
-        if fp32_from_float16_params_key not in state_dict:
-            fp32_from_float16_params_key = 'fp32_from_fp16'
-        for current_group, saved_group in zip(
-                self.fp32_from_float16_groups,
-                state_dict[fp32_from_float16_params_key]):
-            for current_param, saved_param in zip(current_group, saved_group):
-                current_param.data.copy_(saved_param.data)
+        if 'fp32_master_param_groups' in state_dict:
+            for current_group, ckpt_group in zip(self._fp32_master_param_groups,
+                                                 state_dict['fp32_master_param_groups']):
+                for current_param, ckpt_param in zip(current_group, ckpt_group):
+                    current_param.data.copy_(ckpt_param.data)
-    def get_parameters(self):
-        params = []
-        for param_group in self.optimizer.param_groups:
-            for param in param_group['params']:
-                params.append(param)
-        return params
-
     def clip_grad_norm(self, clip_grad):
-        params = self.get_parameters()
+        params = []
+        for param_group in self._optimizer.param_groups:
+            for param in param_group['params']:
+                params.append(param)
         return clip_grad_norm_fp32(params, clip_grad)

-    def count_zeros(self):
-        params = self.get_parameters()
-        return count_zeros_fp32(params)
-
-    def scale_loss(self, loss):
-        """Simple scaling."""
-        return self.get_loss_scale() * loss
-
     # Promote state so it can be retrieved or set via
     # "optimizer_instance.state"
     def _get_state(self):
-        return self.optimizer.state
+        return self._optimizer.state

     def _set_state(self, value):
-        self.optimizer.state = value
+        self._optimizer.state = value

     state = property(_get_state, _set_state)
@@ -500,9 +405,9 @@ class FP16Optimizer(Optimizer):
     # "optimizer_instance.param_groups"
     # (for example, to adjust the learning rate)
     def _get_param_groups(self):
-        return self.optimizer.param_groups
+        return self._optimizer.param_groups

     def _set_param_groups(self, value):
-        self.optimizer.param_groups = value
+        self._optimizer.param_groups = value

     param_groups = property(_get_param_groups, _set_param_groups)
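
The refactored optimizer exposes a simpler contract: backward() multiplies the loss by the current scale, and step() assigns fp16 grads to the fp32 master copies, unscales them, all-reduces an overflow flag across the DP and MP groups, and returns (success, grad_norm) instead of the old three-element tuple. A sketch of how a training step might drive it; the helper function is hypothetical and assumes a model and optimizer already wrapped as in this PR:

from colossalai.amp.naive_amp._fp16_optimizer import FP16Optimizer

def train_step(model, criterion, optimizer: FP16Optimizer, data, label):
    # zero both fp16 grads and fp32 master grads (set_to_none=True is the new default)
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, label)
    # multiplies the loss by the current loss scale before calling backward()
    optimizer.backward(loss)
    # on overflow the update is skipped, grads are cleared and the scaler backs off;
    # otherwise grads are clipped (if configured) and the wrapped optimizer steps
    success, grad_norm = optimizer.step()
    return success, grad_norm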


@@ -0,0 +1,40 @@
+from typing import List
+from torch import Tensor
+
+
+def has_inf_or_nan(tensor):
+    try:
+        # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
+        # Pytorch's .sum() creates a one-element tensor of the same type as tensor
+        # (which is true for some recent version of pytorch).
+        tensor_sum = float(tensor.float().sum())
+        # More efficient version that can be used if .sum() returns a Python scalar
+        # tensor_sum = float(tensor.sum())
+    except RuntimeError as instance:
+        # We want to check if inst is actually an overflow exception.
+        # RuntimeError could come from a different error.
+        # If so, we still want the exception to propagate.
+        if "value cannot be converted" not in instance.args[0]:
+            raise
+        return True
+    else:
+        if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
+            return True
+        return False
+
+
+def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None:
+    """
+    Clear the gradient of a list of tensors,
+    Note: copied from torch.optim.optimizer.
+    """
+    for param in tensor_list:
+        if param.grad is not None:
+            if set_to_none:
+                param.grad = None
+            else:
+                if param.grad.grad_fn is not None:
+                    param.grad.detach_()
+                else:
+                    param.grad.requires_grad_(False)
+                param.grad.zero_()
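
Both helpers are self-contained and easy to exercise on their own. A small sketch, assuming they remain importable from colossalai.amp.naive_amp._utils as added in this file:

import torch
from colossalai.amp.naive_amp._utils import has_inf_or_nan, zero_gard_by_list

print(has_inf_or_nan(torch.tensor([1.0, 2.0])))           # False
print(has_inf_or_nan(torch.tensor([1.0, float('inf')])))  # True

# clearing grads with set_to_none=True releases the gradient memory
params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
for p in params:
    p.grad = torch.ones_like(p)
zero_gard_by_list(params, set_to_none=True)
assert all(p.grad is None for p in params)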


@@ -28,12 +28,10 @@ class BaseGradScaler(ABC):
     def inv_scale(self) -> Tensor:
         return self._scale.double().reciprocal().float()

-    @abstractmethod
     def state_dict(self) -> Dict:
         state_dict = dict()
         state_dict['scale'] = self.scale
         return state_dict

-    @abstractmethod
     def load_state_dict(self, state_dict: Dict) -> None:
         self._scale = state_dict['scale']
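
Dropping @abstractmethod from state_dict/load_state_dict means a concrete scaler only has to define its scale-update policy and inherits checkpointing for free. A hypothetical subclass is sketched below; it assumes update(overflow) is the hook a scaler overrides (it is what FP16Optimizer.step calls in this PR) and that BaseGradScaler.__init__ takes (initial_scale, verbose), as DynamicGradScaler's super().__init__ call suggests:

from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler

class HalveOnOverflowGradScaler(BaseGradScaler):
    """Illustrative scaler: halve the scale on overflow, never grow it."""

    def __init__(self, initial_scale: float, verbose: bool = False):
        super().__init__(initial_scale, verbose)

    def update(self, overflow: bool) -> None:
        # back off on overflow; otherwise keep the current scale
        if overflow:
            self._scale = self._scale * 0.5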


@@ -16,11 +16,19 @@ class DynamicGradScaler(BaseGradScaler):
                  growth_interval: int = 1000,
                  min_scale: int = None,
                  max_scale: int = None,
-                 hysteresis: int = None,
+                 hysteresis: int = 2,
                  verbose: bool = False):
         super().__init__(initial_scale, verbose)
-        self._min_scale = min_scale
-        self._max_scale = max_scale
+        if min_scale:
+            self._min_scale = torch.cuda.FloatTensor([min_scale])
+        else:
+            self._min_scale = None
+
+        if max_scale:
+            self._max_scale = torch.cuda.FloatTensor([max_scale])
+        else:
+            self._max_scale = None
+
         self._growth_factor = growth_factor
         self._backoff_factor = backoff_factor
         self._growth_interval = growth_interval
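
min_scale and max_scale are now wrapped into CUDA float tensors so they can be compared and combined (e.g. via torch.max) with the tensor-valued running scale. A construction sketch using keyword arguments taken from the signature above (argument order is not assumed; a CUDA device is required since the bounds become torch.cuda.FloatTensor):

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler

scaler = DynamicGradScaler(initial_scale=2**16,
                           growth_factor=2,
                           backoff_factor=0.5,
                           growth_interval=1000,
                           min_scale=1,
                           max_scale=2**32,
                           hysteresis=2)
# scaler.scale is expected to hold 65536 as a CUDA float tensor
print(scaler.scale)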


@@ -26,17 +26,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
     """

     def __init__(self, optim: Optimizer, *args, **kwargs):
-        optim = FP16Optimizer(optimizer=optim, *args, **kwargs)
+        optim = FP16Optimizer(optim, *args, **kwargs)
         super().__init__(optim)

     def backward(self, loss: Tensor):
-        """Backward with gradient scaler
-
-        :param loss: loss computed by a loss function
-        :type loss: torch.Tensor
-        """
-        loss = self.optim.scale_loss(loss)
-        loss.backward()
+        self.optim.backward(loss)

     def step(self):
         return self.optim.step()


@@ -304,7 +304,7 @@ def initialize(model: nn.Module,
         if is_using_pp():
             assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
         if amp_mode == AMP_TYPE.NAIVE:
-            cfg_['clip_grad'] = clip_grad_norm
+            cfg_['clip_grad_norm'] = clip_grad_norm
         model, optimizer, criterion = convert_to_amp(model=model,
                                                      optimizer=optimizer,
                                                      criterion=criterion,
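
The naive-AMP config key is renamed to match FP16Optimizer's new clip_grad_norm argument. A sketch of a config that exercises this path; the surrounding field names follow the usual colossalai config convention and are assumptions here, not part of this diff:

# config.py passed to colossalai.launch / read by colossalai.initialize (sketch)
from colossalai.amp import AMP_TYPE

clip_grad_norm = 1.0            # assumed to be picked up by initialize() and written into cfg_['clip_grad_norm']
fp16 = dict(
    mode=AMP_TYPE.NAIVE,        # pipeline parallelism requires naive AMP, per the assert above
    initial_scale=2**16,        # remaining keys flow into convert_to_naive_amp's amp_config
    dynamic_grad_scale=True,
)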


@@ -1,4 +1,3 @@
-from itertools import groupby
 from colossalai.utils.cuda import get_current_device
 import torch
 import torch.distributed as dist
@@ -7,7 +6,7 @@ from torch.optim import Optimizer
 from .bookkeeping import ParameterStore, GradientStore, BucketStore, TensorBucket
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from ._utils import (move_tensor, flatten, get_grad_accumulate_object, split_half_float_double, reduce_tensor,
                      release_param_grad, calculate_global_norm_from_list, compute_norm, sync_param, has_inf_or_nan)
@@ -16,11 +15,8 @@ from functools import partial

 class ShardedOptimizer(ColossalaiOptimizer):

-    def __init__(
-        self,
-        optimizer: Optimizer,
-
-        # grad scaler config
-        initial_scale=2**32,
-        min_scale=1,
-        growth_factor=2,
+    def __init__(self,
+                 optimizer: Optimizer,
+                 initial_scale=2**32,
+                 min_scale=1,
+                 growth_factor=2,
@@ -28,23 +24,14 @@ class ShardedOptimizer(ColossalaiOptimizer):
                  growth_interval=1000,
                  hysteresis=2,
                  max_scale: int = 2**32,
-
-                 # grad clipping
                  clip_grad_norm=2.0,
                  verbose=False,
-
-                 # communication
                  reduce_bucket_size=500000000,
                  communication_dtype=torch.float16,
                  overlap_communication=False,
-
-                 # stage 2
                  partition_grad=False,
                  dp_parallel_mode=ParallelMode.DATA,
                  mp_parallel_mode=ParallelMode.MODEL,
-
-                 # cpu offload
                  cpu_offload=False,
                  cpu_fp16_param=False,
                  cpu_fp16_grad=False):
@@ -263,6 +250,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
             # args here is not grad, but allow_unreacable and accumulate_grad
             def reduce_grad_hook(*args):
                 reduction_func()
+
             accum_grad_obj.register_hook(reduce_grad_hook)

         _define_and_attach(param, reduce_rank)
@@ -444,7 +432,6 @@ class ShardedOptimizer(ColossalaiOptimizer):
             self._grad_store._averaged_gradients[group_id] = []
             self._grad_store._averaged_gradients[group_id] = []

-
         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
         self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)


@@ -4,7 +4,7 @@ from typing import Callable, Dict, Optional, Union
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColossalaiOptimizer


@@ -0,0 +1,83 @@
+import torch
+import colossalai
+import copy
+import pytest
+import torch.multiprocessing as mp
+from colossalai.amp import convert_to_naive_amp
+from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.utils import free_port
+from functools import partial
+
+
+def check_equal(a, b):
+    """
+    This function checks if two tensors are equal within tolerance
+    """
+    assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}'
+
+
+def run_naive_amp():
+    """
+    In this test, we compare the naive fp16 optimizer implemented in colossalai
+    and fp32 torch optimizer
+    """
+    # create layer
+    test_models = ['repeated_computed_layers', 'nested_model']
+    for test_name in test_models:
+        get_component_func = non_distributed_component_funcs.get_callable(test_name)
+        model_builder, train_dataloader, _, optim_builder, _ = get_component_func()
+
+        # create model
+        amp_model = model_builder(checkpoint=True).cuda()
+        torch_model = copy.deepcopy(amp_model)
+
+        # create optimizer
+        amp_optimizer = optim_builder(amp_model)
+        torch_optimizer = optim_builder(torch_model)
+
+        # inject naive amp
+        amp_config = dict(initial_scale=1)
+        amp_model, amp_optimizer = convert_to_naive_amp(amp_model, amp_optimizer, amp_config)
+
+        # create data
+        data_iter = iter(train_dataloader)
+        data, label = next(data_iter)
+        data = data.cuda()
+
+        # forward pass
+        amp_output = amp_model(data)
+        torch_output = torch_model(data)
+        assert torch.allclose(amp_output, torch_output, rtol=1e-3, atol=1e-3), f'{amp_output} vs {torch_output}'
+
+        # backward
+        amp_optimizer.backward(amp_output.mean())
+        torch_output.mean().backward()
+
+        # check grad
+        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
+            assert torch.allclose(amp_param.grad, torch_param.grad.half(), rtol=1e-3, atol=1e-3)
+
+        # step
+        amp_optimizer.step()
+        torch_optimizer.step()
+
+        # check updated param
+        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
+            assert torch.allclose(amp_param, torch_param.half(), rtol=1e-3, atol=1e-3)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+    run_naive_amp()
+
+
+@pytest.mark.dist
+def test_naive_amp():
+    world_size = 1
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_naive_amp()