mirror of https://github.com/hpcaitech/ColossalAI
[fp16] refactored fp16 optimizer (#392)
parent f8a0e7fb01
commit e79ea44247
@@ -1,13 +1,12 @@
+import inspect
 import torch.nn as nn
 from torch.optim import Optimizer
 from colossalai.utils import is_no_pp_or_last_stage

 from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
+from .grad_scaler import DynamicGradScaler, ConstantGradScaler


-def convert_to_naive_amp(model: nn.Module,
-                         optimizer: Optimizer,
-                         amp_config):
+def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
     """A helper function to wrap training components with naive AMP modules

     :param model: your model object
@@ -31,7 +30,19 @@ def convert_to_naive_amp(model: nn.Module,
     output_to_fp32 = is_no_pp_or_last_stage()
     model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)

-    optimizer = NaiveAMPOptimizer(optimizer, **amp_config)
+    use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
+    if use_dynamic_grad_scaler:
+        scaler_class = DynamicGradScaler
+    else:
+        scaler_class = ConstantGradScaler
+
+    sig = inspect.signature(scaler_class.__init__)
+    kwargs = dict()
+    for param in sig.parameters.values():
+        if param.name in amp_config:
+            kwargs[param.name] = amp_config.pop(param.name)
+    grad_scaler = scaler_class(**kwargs)
+    optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
     return model, optimizer

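For reference, a minimal usage sketch of the refactored helper, based only on the keys and signatures visible in this diff. The tiny model and Adam optimizer are placeholders, and an initialized Colossal-AI context (via colossalai.launch, as in the test added at the bottom of this commit) is assumed:

    import torch.nn as nn
    from torch.optim import Adam
    from colossalai.amp import convert_to_naive_amp  # import path used by the new test below

    model = nn.Linear(16, 16).cuda()
    optimizer = Adam(model.parameters(), lr=1e-3)

    # 'dynamic_grad_scale' chooses DynamicGradScaler (the default) or ConstantGradScaler;
    # keys matching the chosen scaler's __init__ are routed to it via inspect.signature,
    # and whatever remains (e.g. clip_grad_norm) is forwarded to NaiveAMPOptimizer.
    amp_config = dict(dynamic_grad_scale=True, initial_scale=2**16, clip_grad_norm=1.0)
    model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)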
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-

 import torch
+import torch.distributed as dist

 try:
     import colossal_C
@@ -9,41 +10,30 @@ except:
     print('Colossalai should be built with cuda extension to use the FP16 optimizer')

 from torch.optim import Optimizer

-from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
 from colossalai.logging import get_dist_logger
-from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
-                              clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)
+from colossalai.utils import (copy_tensor_parallel_attributes, clip_grad_norm_fp32, multi_tensor_applier)
+from torch.distributed import ProcessGroup
+from .grad_scaler import BaseGradScaler
+from ._utils import has_inf_or_nan, zero_gard_by_list

+__all__ = ['FP16Optimizer']

-def _zero_grad_group_helper(group, set_to_none):
-    """Zero out the gradient for a group of parameters.
-    Note: copied from torch.optim.optimizer."""
-    for param in group:
-        if param.grad is not None:
-            if set_to_none:
-                param.grad = None
-            else:
-                if param.grad.grad_fn is not None:
-                    param.grad.detach_()
-                else:
-                    param.grad.requires_grad_(False)
-                param.grad.zero_()
-

 def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
-    """Use multi-tensor-applier to copy values from one list to another.
+    """
+    adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM)
+
+    Use multi-tensor-applier to copy values from one list to another.
     We don't have a blfoat16 implementation so for now if the overflow_buf
     is not provided, we default back to simple loop copy to be compatible
-    with bfloat16."""
+    with bfloat16.
+    """
     if overflow_buf:
         overflow_buf.fill_(0)
         # Scaling with factor `1.0` is equivalent to copy.
-        multi_tensor_applier(colossal_C.multi_tensor_scale,
-                             overflow_buf,
-                             [this, that],
-                             1.0)
+        multi_tensor_applier(colossal_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
     else:
         for this_, that_ in zip(this, that):
             that_.copy_(this_)
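A small illustration of the copy helper above, using hypothetical tensor lists; when overflow_buf is None (or evaluates falsy) the fused colossal_C kernel is skipped and the plain element-wise copy loop is used instead. This assumes the function is called from within the module, since it is private:

    import torch

    # hypothetical fp32 sources and fp16 destinations
    src = [torch.randn(4, device='cuda') for _ in range(3)]
    dst = [torch.empty(4, device='cuda', dtype=torch.half) for _ in range(3)]

    # with overflow_buf=None the helper falls back to the simple loop of copy_() calls
    _multi_tensor_copy_this_to_that(this=src, that=dst, overflow_buf=None)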
@@ -111,8 +101,7 @@ class DynamicGradScaler:
             self._hysteresis_tracker -= 1
             # Now if we are out of hysteresis count, scale down the loss.
             if self._hysteresis_tracker <= 0:
-                self._scale = torch.max(self._scale * self.backoff_factor,
-                                        self.min_scale)
+                self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale)
             if self.verbose:
                 self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
         else:
@@ -127,12 +116,13 @@ class DynamicGradScaler:
             if self._max_scale is not None and self._scale >= self._max_scale:
                 if self.verbose:
                     self._logger.info(
-                        f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0])
+                        f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed',
+                        ranks=[0])
             else:
                 self._scale = self._scale * self.growth_factor
                 if self.verbose:
-                    self._logger.info(
-                        f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])
+                    self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}',
+                                      ranks=[0])

     def state_dict(self):
         state_dict = {}
@@ -173,326 +163,241 @@ class FP16Optimizer(Optimizer):
     """

     def __init__(self,
-                 optimizer,
-                 clip_grad=0,
-                 log_num_zeros_in_grad=False,
-                 initial_scale=2 ** 32,
-                 min_scale=1,
-                 growth_factor=2,
-                 backoff_factor=0.5,
-                 growth_interval=1000,
-                 hysteresis=2,
-                 max_scale: int = 2 ** 32,
-                 verbose: bool = False):
-        # default args for compatibility
-        bf16 = False
-        params_have_main_grad = False
+                 optimizer: Optimizer,
+                 grad_scaler: BaseGradScaler,
+                 verbose: bool = False,
+                 clip_grad_norm=0,
+                 dp_process_group: ProcessGroup = None,
+                 mp_process_group: ProcessGroup = None):

         # have a defaults for compatibility with pytorch optim
-        self.defaults = optimizer.defaults
+        self._optimizer = optimizer
+        self._defaults = optimizer.defaults

-        # log config
-        self._logger = get_dist_logger()
-        if verbose:
-            self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
-                              f"Optimizer: {optimizer.__class__.__name__}\n"
-                              f"clip_grad = {clip_grad}\n"
-                              f"log_num_zeros_in_grad = {log_num_zeros_in_grad}\n"
-                              f"initial_scale = {initial_scale}\n"
-                              f"min_scale = {min_scale}\n"
-                              f"growth_factor = {growth_factor}\n"
-                              f"backoff_factor = {backoff_factor}\n"
-                              f"growth_interval = {growth_interval}\n"
-                              f"hysteresis = {hysteresis}\n"
-                              f"==========================================", ranks=[0])
-
-        """Input optimizer is the base optimizer for example Adam."""
-        self.optimizer = optimizer
-        assert self.optimizer, 'no optimizer is provided.'
-        # Set gradient clipping and logging params.
-        self.clip_grad = clip_grad
-        self.log_num_zeros_in_grad = log_num_zeros_in_grad
-        self.params_have_main_grad = params_have_main_grad
-
-        self.bf16 = bf16
-        self.grad_scaler = DynamicGradScaler(
-            initial_scale=initial_scale,
-            min_scale=min_scale,
-            growth_factor=growth_factor,
-            backoff_factor=backoff_factor,
-            growth_interval=growth_interval,
-            hysteresis=hysteresis,
-            max_scale=max_scale,
-            verbose=verbose
-        )
-
-        # None grad scaler is only supported for bf16.
-        if self.grad_scaler is None:
-            assert self.bf16, 'fp16 expects a grad scaler.'
-
-        # Tensor used to determine if a nan/if has happend.
-        # Any non-zero value indicates inf/nan.
-        # Note that we keep this for the cases that grad scaler is none.
-        # We still record nan/inf if we have a bfloat16 with a grad scaler.
-        if self.grad_scaler:
-            self.found_inf = torch.cuda.FloatTensor([0.0])
-
-        # Dummy tensor needed for apex multi-apply tensor.
-        # For bfloat, we don't have multi-tensor apply and for now
-        # we set it to none so the multi-tensor apply gets ignored.
-        if bf16:
-            self._dummy_overflow_buf = None
-        else:
-            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+        # fp16-related params
+        assert isinstance(grad_scaler, BaseGradScaler)
+        self._grad_scaler = grad_scaler
+        self._found_overflow = torch.cuda.FloatTensor([0.0])
+        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

-        # In case grad scaler is not passed, define the unity scale.
-        if self.grad_scaler is None:
-            self._scale_one = torch.cuda.FloatTensor([1.0])
+        # misc params
+        self._clip_grad_max_norm = clip_grad_norm

-        # ======================
-        # main parameter stuff
-        # ======================
+        # get process group
+        def _get_process_group(parallel_mode):
+            if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA):
+                return gpc.get_group(ParallelMode.DATA)
+            else:
+                return None

-        # Three groups of parameters:
-        # float16_groups: original float16 parameters
-        # fp32_from_float16_groups: fp32 copy of float16 parameters
-        # fp32_from_fp32_groups: original fp32 parameters
-        self.float16_groups = []
-        self.fp32_from_float16_groups = []
-        self.fp32_from_fp32_groups = []
+        if dp_process_group is None:
+            dp_process_group = _get_process_group(ParallelMode.DATA)
+        if mp_process_group is None:
+            mp_process_group = _get_process_group(ParallelMode.MODEL)
+
+        self._dp_process_group = dp_process_group
+        self._mp_process_group = mp_process_group

+        # we maintain three groups of parameters
+        # so that the model can have a mixture
+        # of fp16 and fp32 params
+        # fp16_param_groups: the fp16 params of the model
+        # fp32_master_param_groups: the fp32 params cast from the fp16 param of the model
+        # fp32_param_groups: the fp32 params of the model
+        # NOTE:
+        # 1. fp16_param_groups and fp32_master_param_groups have one-to-one correspondence
+        # 2. fp32_param_groups and fp16_param_groups are exclusive of each other
+        self._fp16_param_groups = []
+        self._fp32_master_param_groups = []
+        self._fp32_param_groups = []

         # For all the groups in the original optimizer:
-        for param_group in self.optimizer.param_groups:
-            float16_params_this_group = []
-            fp32_params_this_group = []
-            fp32_from_float16_params_this_group = []
+        for param_group in self._optimizer.param_groups:
+            fp16_params = []
+            fp32_master_params = []
+            fp32_params = []
             # For all the parameters in this group:
             for i, param in enumerate(param_group['params']):
                 if param.requires_grad:
                     # float16 params:
-                    if param.type() in ['torch.cuda.HalfTensor',
-                                        'torch.cuda.BFloat16Tensor']:
-                        float16_params_this_group.append(param)
-                        # Create a copy
-                        main_param = param.detach().clone().float()
-                        # Copy tensor model parallel attributes.
-                        copy_tensor_parallel_attributes(param, main_param)
-
-                        # if hasattr(param, 'shared'):
-                        #     main_param.shared = param.shared
+                    if param.type() in ['torch.cuda.HalfTensor']:
+                        fp16_params.append(param)
+
+                        # Create a fp32 copy
+                        fp32_param = param.detach().clone().float()
+                        # Copy tensor model parallel attributes.
+                        copy_tensor_parallel_attributes(param, fp32_param)
+
                         # Replace the optimizer params with the new fp32 copy.
-                        param_group['params'][i] = main_param
-                        fp32_from_float16_params_this_group.append(main_param)
+                        param_group['params'][i] = fp32_param
+                        fp32_master_params.append(fp32_param)

                         # Reset existing state dict key to the new main param.
-                        if param in self.optimizer.state:
-                            self.optimizer.state[main_param] \
-                                = self.optimizer.state.pop(param)
+                        if param in self._optimizer.state:
+                            self._optimizer.state[fp32_param] = self._optimizer.state.pop(param)

                     # fp32 params.
                     elif param.type() == 'torch.cuda.FloatTensor':
-                        fp32_params_this_group.append(param)
-                        param_group['params'][i] = param
+                        fp32_params.append(param)
                     else:
-                        raise TypeError('Wrapped parameters must be one of '
-                                        'torch.cuda.FloatTensor, '
-                                        'torch.cuda.HalfTensor, or '
-                                        'torch.cuda.BFloat16Tensor. '
-                                        'Received {}'.format(param.type()))
+                        raise TypeError('Expected parameter of type torch.cuda.FloatTensor '
+                                        f'or torch.cuda.HalfTensor, but got {param.type()}')

-            self.float16_groups.append(float16_params_this_group)
-            self.fp32_from_float16_groups.append(
-                fp32_from_float16_params_this_group)
-            self.fp32_from_fp32_groups.append(fp32_params_this_group)
+            self._fp16_param_groups.append(fp16_params)
+            self._fp32_master_param_groups.append(fp32_master_params)
+            self._fp32_param_groups.append(fp32_params)

         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
-        self.optimizer.load_state_dict(self.optimizer.state_dict())
+        self._optimizer.load_state_dict(self._optimizer.state_dict())

-    def zero_grad(self, set_to_none=False):
-        """We only need to zero the model related parameters, i.e.,
-        float16_groups & fp32_from_fp32_groups."""
-        for group in self.float16_groups:
-            _zero_grad_group_helper(group, set_to_none)
-        for group in self.fp32_from_fp32_groups:
-            _zero_grad_group_helper(group, set_to_none)
+        # log config
+        self._logger = get_dist_logger()
+        if verbose:
+            self._logger.info(
+                f"\n========= FP16 Optimizer Config =========\n"
+                f"Optimizer: {optimizer.__class__.__name__}\n"
+                f"clip_grad_norm = {clip_grad_norm}\n"
+                f"grad_scaler = {self._grad_scaler.__class__.__name__}"
+                f"==========================================",
+                ranks=[0])

-    def get_loss_scale(self):
-        if self.grad_scaler is None:
-            return self._scale_one
-        return self.grad_scaler.scale
+    @property
+    def grad_scaler(self):
+        return self._grad_scaler

-    def _copy_model_grads_to_main_grads(self):
+    @property
+    def loss_scale(self):
+        return self._grad_scaler.scale
+
+    @property
+    def optimizer(self):
+        return self._optimizer
+
+    @property
+    def defaults(self):
+        return self._defaults
+
+    def _check_overflow(self):
+        # clear previous overflow record
+        self._found_overflow.fill_(0.0)
+
+        # check for overflow
+        for group in self._optimizer.param_groups:
+            for p in group['params']:
+                if has_inf_or_nan(p.grad):
+                    self._found_overflow.fill_(1.0)
+                    break
+
+        # all-reduce across dp group
+        if self._dp_process_group:
+            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_process_group)
+
+        # all-reduce over model parallel group
+        if self._mp_process_group:
+            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_process_group)
+
+        return self._found_overflow.item() > 0
+
+    def zero_grad(self, set_to_none=True):
+        # set_to_none = True can save some memory space
+        for param_group in self._optimizer.param_groups:
+            zero_gard_by_list(param_group['params'], set_to_none=set_to_none)
+
+    def _get_fp32_param_groups_to_update(self):
+        return self._fp32_master_param_groups + self._fp32_param_groups
+
+    def _unscale_grads(self):
+        for group in self._get_fp32_param_groups_to_update():
+            for p in group:
+                if p.grad is not None:
+                    p.grad.data.div_(self.loss_scale)
+
+    def _assign_grad_to_fp32_master_param(self):
         # This only needs to be done for the float16 group.
-        for model_group, main_group in zip(self.float16_groups,
-                                           self.fp32_from_float16_groups):
-            for model_param, main_param in zip(model_group, main_group):
-                if self.params_have_main_grad:
-                    main_param.grad = model_param.main_grad.float()
-                else:
-                    if model_param.grad is not None:
-                        main_param.grad = model_param.grad.float()
+        for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
+            for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group):
+                fp32_param.grad = fp16_param.grad.float()
+                # clear unneeded grad on fp16 param
+                fp16_param.grad = None

-        # For fp32 grads, we need to reset the grads to main grad.
-        if self.params_have_main_grad:
-            for model_group in self.fp32_from_fp32_groups:
-                for model_param in model_group:
-                    model_param.grad = model_param.main_grad
-
-    def _unscale_main_grads_and_check_for_nan(self):
-        main_grads = []
-        # fp32 params fromm float16 ones.
-        for main_group in self.fp32_from_float16_groups:
-            for main_param in main_group:
-                if main_param.grad is not None:
-                    main_grads.append(main_param.grad.data)
-        # Append fp32 parameters.
-        for main_group in self.fp32_from_fp32_groups:
-            for main_param in main_group:
-                if main_param.grad is not None:
-                    main_grads.append(main_param.grad.data)
-        # Reset found inf.
-        self.found_inf.fill_(0.0)
-        # Unscale and set found inf/nan
-        torch._amp_foreach_non_finite_check_and_unscale_(
-            main_grads, self.found_inf, self.grad_scaler.inv_scale)
-        # Update across all model parallel instances.
-        torch.distributed.all_reduce(self.found_inf,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=gpc.get_group(ParallelMode.MODEL))
-
-        # Check for nan.
-        found_inf_flag = (self.found_inf.item() > 0)
-        return found_inf_flag
-
-    def _get_model_and_main_params_data_float16(self):
-        model_data = []
-        main_data = []
-        for model_group, main_group in zip(self.float16_groups,
-                                           self.fp32_from_float16_groups):
-            for model_param, main_param in zip(model_group, main_group):
-                model_data.append(model_param.data)
-                main_data.append(main_param.data)
-        return model_data, main_data
-
-    def _copy_main_params_to_model_params(self):
-        # Only needed for the float16 params.
-        model_data, main_data = self._get_model_and_main_params_data_float16()
-        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
+    def _update_fp16_param_from_fp32_param(self):
+        fp16_param_data = []
+        fp32_master_param_data = []
+        for fp16_group, fp32_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
+            for fp16_param, fp32_param in zip(fp16_group, fp32_group):
+                fp16_param_data.append(fp16_param.data)
+                fp32_master_param_data.append(fp32_param.data)
+        _multi_tensor_copy_this_to_that(this=fp32_master_param_data,
+                                        that=fp16_param_data,
                                         overflow_buf=self._dummy_overflow_buf)

-    def _copy_model_params_to_main_params(self):
-        # Only needed for the float16 params.
-        model_data, main_data = self._get_model_and_main_params_data_float16()
-        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
-                                        overflow_buf=self._dummy_overflow_buf)
-
-    def reload_model_params(self):
-        self._copy_model_params_to_main_params()
-
-    @torch.no_grad()
     def step(self):
         # Copy gradients from model params to main params.
-        self._copy_model_grads_to_main_grads()
+        self._assign_grad_to_fp32_master_param()
+        self._unscale_grads()

-        # Do unscale, check for inf, and update grad scaler only for
-        # the case that grad scaler is provided.
-        if self.grad_scaler:
-
-            # Unscale and check for inf/nan.
-            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
-
-            # We are done with scaling gradients
-            # so we can update the loss scale.
-            self.grad_scaler.update(found_inf_flag)
-
-            # If we found inf/nan, skip the update.
-            if found_inf_flag:
-                return False, None, None
+        overflow = self._check_overflow()
+        self._grad_scaler.update(overflow)
+
+        if overflow:
+            self.zero_grad()
+            return False, None

         # Clip the main gradients.
         grad_norm = None
-        if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad)
+        if self._clip_grad_max_norm > 0.0:
+            grad_norm = self.clip_grad_norm(self._clip_grad_max_norm)

-        # count the zeros in the grads
-        num_zeros_in_grad = self.count_zeros() if \
-            self.log_num_zeros_in_grad else None
-
         # Step the optimizer.
-        self.optimizer.step()
+        self._optimizer.step()

         # Update params from main params.
-        self._copy_main_params_to_model_params()
+        self._update_fp16_param_from_fp32_param()

         # Successful update.
-        return True, grad_norm, num_zeros_in_grad
+        return True, grad_norm

+    def backward(self, loss):
+        scaled_loss = loss * self.grad_scaler.scale
+        scaled_loss.backward()
+
     def state_dict(self):
         state_dict = {}
-        state_dict['optimizer'] = self.optimizer.state_dict()
+        state_dict['optimizer'] = self._optimizer.state_dict()
         if self.grad_scaler:
             state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
+        state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups
         return state_dict

     def load_state_dict(self, state_dict):
         # Optimizer.
-        optimizer_key = 'optimizer'
-        if optimizer_key not in state_dict:
-            optimizer_key = 'optimizer_state_dict'
-            print_rank_0('***WARNING*** loading optimizer from '
-                         'an old checkpoint ...')
-        self.optimizer.load_state_dict(state_dict[optimizer_key])
+        self._optimizer.load_state_dict(state_dict['optimizer'])

         # Grad scaler.
-        if 'grad_scaler' not in state_dict:
-            print_rank_0('***WARNING*** found an old checkpoint, will not '
-                         'load grad scaler ...')
-        else:
-            if self.grad_scaler:
-                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
-            else:
-                print_rank_0('***WARNING*** fould the grad scaler in the '
-                             'checkpoint but it is None in the class. '
-                             'Skipping loading grad scaler ...')
+        if 'grad_scaler' in state_dict:
+            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])

         # Copy data for the main params.
-        fp32_from_float16_params_key = 'fp32_from_fp16_params'
-        if fp32_from_float16_params_key not in state_dict:
-            fp32_from_float16_params_key = 'fp32_from_fp16'
-        for current_group, saved_group in zip(
-                self.fp32_from_float16_groups,
-                state_dict[fp32_from_float16_params_key]):
-            for current_param, saved_param in zip(current_group, saved_group):
-                current_param.data.copy_(saved_param.data)
-
-    def get_parameters(self):
-        params = []
-        for param_group in self.optimizer.param_groups:
-            for param in param_group['params']:
-                params.append(param)
-        return params
+        if 'fp32_master_param_groups' in state_dict:
+            for current_group, ckpt_group in zip(self._fp32_master_param_groups,
+                                                 state_dict['fp32_master_param_groups']):
+                for current_param, ckpt_param in zip(current_group, ckpt_group):
+                    current_param.data.copy_(ckpt_param.data)

     def clip_grad_norm(self, clip_grad):
-        params = self.get_parameters()
+        params = []
+        for param_group in self._optimizer.param_groups:
+            for param in param_group['params']:
+                params.append(param)
         return clip_grad_norm_fp32(params, clip_grad)

-    def count_zeros(self):
-        params = self.get_parameters()
-        return count_zeros_fp32(params)
-
-    def scale_loss(self, loss):
-        """Simple scaling."""
-        return self.get_loss_scale() * loss
-
     # Promote state so it can be retrieved or set via
     # "optimizer_instance.state"
     def _get_state(self):
-        return self.optimizer.state
+        return self._optimizer.state

     def _set_state(self, value):
-        self.optimizer.state = value
+        self._optimizer.state = value

     state = property(_get_state, _set_state)
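Putting the refactored pieces together, a hedged sketch of how a training step flows through the new FP16Optimizer (as wrapped by NaiveAMPOptimizer further down); model, criterion and the dataloader are placeholders:

    for data, label in train_dataloader:
        output = model(data.cuda())
        loss = criterion(output, label.cuda())

        optimizer.zero_grad()
        # backward() multiplies the loss by grad_scaler.scale before calling .backward()
        optimizer.backward(loss)

        # step() copies fp16 grads onto the fp32 master params, unscales them,
        # all-reduces the overflow flag over the DP/MP groups, updates the scaler,
        # and returns (False, None) after zeroing grads if an inf/nan was found;
        # otherwise it clips, steps the inner optimizer, copies fp32 masters back
        # into the fp16 params, and returns (True, grad_norm).
        success, grad_norm = optimizer.step()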
@@ -500,9 +405,9 @@ class FP16Optimizer(Optimizer):
     # "optimizer_instance.param_groups"
     # (for example, to adjust the learning rate)
     def _get_param_groups(self):
-        return self.optimizer.param_groups
+        return self._optimizer.param_groups

     def _set_param_groups(self, value):
-        self.optimizer.param_groups = value
+        self._optimizer.param_groups = value

     param_groups = property(_get_param_groups, _set_param_groups)
@@ -0,0 +1,40 @@
+from typing import List
+from torch import Tensor
+
+
+def has_inf_or_nan(tensor):
+    try:
+        # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
+        # Pytorch's .sum() creates a one-element tensor of the same type as tensor
+        # (which is true for some recent version of pytorch).
+        tensor_sum = float(tensor.float().sum())
+        # More efficient version that can be used if .sum() returns a Python scalar
+        # tensor_sum = float(tensor.sum())
+    except RuntimeError as instance:
+        # We want to check if inst is actually an overflow exception.
+        # RuntimeError could come from a different error.
+        # If so, we still want the exception to propagate.
+        if "value cannot be converted" not in instance.args[0]:
+            raise
+        return True
+    else:
+        if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
+            return True
+        return False
+
+
+def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None:
+    """
+    Clear the gradient of a list of tensors,
+    Note: copied from torch.optim.optimizer.
+    """
+    for param in tensor_list:
+        if param.grad is not None:
+            if set_to_none:
+                param.grad = None
+            else:
+                if param.grad.grad_fn is not None:
+                    param.grad.detach_()
+                else:
+                    param.grad.requires_grad_(False)
+                param.grad.zero_()
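The new helpers can be exercised on their own; a quick sketch (CUDA tensors are used only to mirror how the optimizer calls them, CPU tensors work as well):

    import torch

    grads = torch.tensor([1.0, float('inf')], device='cuda')
    print(has_inf_or_nan(grads))                          # True
    print(has_inf_or_nan(torch.ones(4, device='cuda')))   # False

    # zero_gard_by_list (name as spelled in this commit) clears .grad in place,
    # either setting it to None or detaching and zeroing it
    p = torch.nn.Parameter(torch.ones(4, device='cuda'))
    p.grad = torch.ones_like(p)
    zero_gard_by_list([p], set_to_none=True)
    assert p.grad is None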
@@ -28,12 +28,10 @@ class BaseGradScaler(ABC):
     def inv_scale(self) -> Tensor:
         return self._scale.double().reciprocal().float()

-    @abstractmethod
     def state_dict(self) -> Dict:
         state_dict = dict()
         state_dict['scale'] = self.scale

-    @abstractmethod
     def load_state_dict(self, state_dict: Dict) -> None:
         self._scale = state_dict['scale']
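With the @abstractmethod decorators dropped, state_dict() and load_state_dict() now act as default implementations that subclasses may override. A hedged sketch of a custom scaler, assuming BaseGradScaler.__init__(initial_scale, verbose) and the update(overflow) hook implied elsewhere in this diff:

    class FixedGradScaler(BaseGradScaler):
        """Hypothetical scaler that never changes its scale."""

        def __init__(self, initial_scale: float = 2**16, verbose: bool = False):
            super().__init__(initial_scale, verbose)

        def update(self, overflow: bool) -> None:
            # a constant scale ignores overflow feedback entirely
            pass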
@@ -16,11 +16,19 @@ class DynamicGradScaler(BaseGradScaler):
                  growth_interval: int = 1000,
                  min_scale: int = None,
                  max_scale: int = None,
-                 hysteresis: int = None,
+                 hysteresis: int = 2,
                  verbose: bool = False):
         super().__init__(initial_scale, verbose)
-        self._min_scale = min_scale
-        self._max_scale = max_scale
+        if min_scale:
+            self._min_scale = torch.cuda.FloatTensor([min_scale])
+        else:
+            self._min_scale = None
+
+        if max_scale:
+            self._max_scale = torch.cuda.FloatTensor([max_scale])
+        else:
+            self._max_scale = None
+
         self._growth_factor = growth_factor
         self._backoff_factor = backoff_factor
         self._growth_interval = growth_interval
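A construction sketch using the keyword arguments implied by this constructor (min_scale and max_scale are wrapped into CUDA tensors, so a GPU is assumed; the values are illustrative):

    scaler = DynamicGradScaler(initial_scale=2**16,
                               growth_factor=2,
                               backoff_factor=0.5,
                               growth_interval=1000,
                               hysteresis=2,
                               min_scale=1,
                               max_scale=2**24,
                               verbose=False)

    scaler.update(True)    # overflow: backs off the scale, subject to hysteresis
    scaler.update(False)   # no overflow: counts toward the next growth_interval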
@@ -26,17 +26,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
     """

     def __init__(self, optim: Optimizer, *args, **kwargs):
-        optim = FP16Optimizer(optimizer=optim, *args, **kwargs)
+        optim = FP16Optimizer(optim, *args, **kwargs)
         super().__init__(optim)

     def backward(self, loss: Tensor):
-        """Backward with gradient scaler
-
-        :param loss: loss computed by a loss function
-        :type loss: torch.Tensor
-        """
-        loss = self.optim.scale_loss(loss)
-        loss.backward()
+        self.optim.backward(loss)

     def step(self):
         return self.optim.step()
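Since the inner optimizer is now passed positionally and loss scaling lives inside FP16Optimizer, wrapping an existing torch optimizer by hand would look roughly like this (a sketch, assuming the FP16Optimizer signature shown earlier; model and loss are placeholders):

    from torch.optim import Adam

    torch_optim = Adam(model.parameters(), lr=1e-3)
    scaler = DynamicGradScaler(initial_scale=2**16)
    optim = NaiveAMPOptimizer(torch_optim, grad_scaler=scaler, clip_grad_norm=1.0)

    optim.backward(loss)   # delegates to FP16Optimizer.backward (loss * scale, then backward)
    optim.step()           # delegates to FP16Optimizer.step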
@@ -304,7 +304,7 @@ def initialize(model: nn.Module,
     if is_using_pp():
         assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
     if amp_mode == AMP_TYPE.NAIVE:
-        cfg_['clip_grad'] = clip_grad_norm
+        cfg_['clip_grad_norm'] = clip_grad_norm
     model, optimizer, criterion = convert_to_amp(model=model,
                                                  optimizer=optimizer,
                                                  criterion=criterion,
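The renamed key means the clipping value now reaches FP16Optimizer as clip_grad_norm (it was stored as 'clip_grad' before). A hedged sketch of the AMP part of a config that ends up in this path; where exactly clip_grad_norm is declared is not shown in this hunk, so treat its placement below as an assumption:

    from colossalai.amp import AMP_TYPE

    # hypothetical config contents
    fp16 = dict(
        mode=AMP_TYPE.NAIVE,
        initial_scale=2**16,   # forwarded to the grad scaler
    )
    clip_grad_norm = 1.0       # surfaces in FP16Optimizer under its new name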
@@ -1,4 +1,3 @@
-from itertools import groupby
 from colossalai.utils.cuda import get_current_device
 import torch
 import torch.distributed as dist
@@ -7,7 +6,7 @@ from torch.optim import Optimizer
 from .bookkeeping import ParameterStore, GradientStore, BucketStore, TensorBucket
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from ._utils import (move_tensor, flatten, get_grad_accumulate_object, split_half_float_double, reduce_tensor,
                      release_param_grad, calculate_global_norm_from_list, compute_norm, sync_param, has_inf_or_nan)
@@ -16,11 +15,8 @@ from functools import partial

 class ShardedOptimizer(ColossalaiOptimizer):

-    def __init__(
-            self,
+    def __init__(self,
                  optimizer: Optimizer,
-
-                 # grad scaler config
                  initial_scale=2**32,
                  min_scale=1,
                  growth_factor=2,
@@ -28,23 +24,14 @@ class ShardedOptimizer(ColossalaiOptimizer):
                  growth_interval=1000,
                  hysteresis=2,
                  max_scale: int = 2**32,
-
-                 # grad clipping
                  clip_grad_norm=2.0,
                  verbose=False,
-
-                 # communication
                  reduce_bucket_size=500000000,
                  communication_dtype=torch.float16,
                  overlap_communication=False,
-
-                 # stage 2
                  partition_grad=False,
-
                  dp_parallel_mode=ParallelMode.DATA,
                  mp_parallel_mode=ParallelMode.MODEL,
-
-                 # cpu offload
                  cpu_offload=False,
                  cpu_fp16_param=False,
                  cpu_fp16_grad=False):
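For orientation, a construction sketch using the keyword list above (values are illustrative and torch_optim is a placeholder; an initialized distributed context is assumed):

    sharded_optim = ShardedOptimizer(torch_optim,
                                     initial_scale=2**32,
                                     growth_factor=2,
                                     backoff_factor=0.5,
                                     growth_interval=1000,
                                     hysteresis=2,
                                     clip_grad_norm=2.0,
                                     reduce_bucket_size=500000000,
                                     overlap_communication=True,
                                     partition_grad=False,
                                     cpu_offload=False)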
@@ -263,6 +250,7 @@ class ShardedOptimizer(ColossalaiOptimizer):
             # args here is not grad, but allow_unreacable and accumulate_grad
             def reduce_grad_hook(*args):
                 reduction_func()
+
             accum_grad_obj.register_hook(reduce_grad_hook)

         _define_and_attach(param, reduce_rank)
@@ -444,7 +432,6 @@ class ShardedOptimizer(ColossalaiOptimizer):
             self._grad_store._averaged_gradients[group_id] = []
-            self._grad_store._averaged_gradients[group_id] = []

         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
         self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)
@@ -4,7 +4,7 @@ from typing import Callable, Dict, Optional, Union
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColossalaiOptimizer
@@ -0,0 +1,83 @@
+import torch
+import colossalai
+import copy
+import pytest
+import torch.multiprocessing as mp
+from colossalai.amp import convert_to_naive_amp
+from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.utils import free_port
+from functools import partial
+
+
+def check_equal(a, b):
+    """
+    This function checks if two tensors are equal within tolerance
+    """
+    assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}'
+
+
+def run_naive_amp():
+    """
+    In this test, we compare the naive fp16 optimizer implemented in colossalai
+    and fp32 torch optimizer
+    """
+
+    # create layer
+    test_models = ['repeated_computed_layers', 'nested_model']
+    for test_name in test_models:
+        get_component_func = non_distributed_component_funcs.get_callable(test_name)
+        model_builder, train_dataloader, _, optim_builder, _ = get_component_func()
+
+        # create model
+        amp_model = model_builder(checkpoint=True).cuda()
+        torch_model = copy.deepcopy(amp_model)
+
+        # create optimizer
+        amp_optimizer = optim_builder(amp_model)
+        torch_optimizer = optim_builder(torch_model)
+
+        # inject naive amp
+        amp_config = dict(initial_scale=1)
+        amp_model, amp_optimizer = convert_to_naive_amp(amp_model, amp_optimizer, amp_config)
+
+        # create data
+        data_iter = iter(train_dataloader)
+        data, label = next(data_iter)
+        data = data.cuda()
+
+        # forward pass
+        amp_output = amp_model(data)
+        torch_output = torch_model(data)
+        assert torch.allclose(amp_output, torch_output, rtol=1e-3, atol=1e-3), f'{amp_output} vs {torch_output}'
+
+        # backward
+        amp_optimizer.backward(amp_output.mean())
+        torch_output.mean().backward()
+
+        # check grad
+        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
+            torch.allclose(amp_param.grad, torch_param.grad.half(), rtol=1e-3, atol=1e-3)
+
+        # step
+        amp_optimizer.step()
+        torch_optimizer.step()
+
+        # check updated param
+        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
+            torch.allclose(amp_param, torch_param.half(), rtol=1e-3, atol=1e-3)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+    run_naive_amp()
+
+
+@pytest.mark.dist
+def test_naive_amp():
+    world_size = 1
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_naive_amp()