From e79ea44247bbb6457cd2a2c30454208017fd31ef Mon Sep 17 00:00:00 2001 From: Frank Lee <somerlee.9@gmail.com> Date: Tue, 15 Mar 2022 10:05:38 +0800 Subject: [PATCH] [fp16] refactored fp16 optimizer (#392) --- colossalai/amp/naive_amp/__init__.py | 21 +- colossalai/amp/naive_amp/_fp16_optimizer.py | 473 +++++++----------- colossalai/amp/naive_amp/_utils.py | 40 ++ .../naive_amp/grad_scaler/base_grad_scaler.py | 2 - .../grad_scaler/dynamic_grad_scaler.py | 14 +- colossalai/amp/naive_amp/naive_amp.py | 10 +- colossalai/initialize.py | 2 +- .../zero/sharded_optim/sharded_optim.py | 69 ++- .../zero/sharded_optim/sharded_optim_v2.py | 2 +- tests/test_amp/test_naive_fp16.py | 83 +++ 10 files changed, 371 insertions(+), 345 deletions(-) create mode 100644 colossalai/amp/naive_amp/_utils.py create mode 100644 tests/test_amp/test_naive_fp16.py diff --git a/colossalai/amp/naive_amp/__init__.py b/colossalai/amp/naive_amp/__init__.py index 32ea3469a..2390c199e 100644 --- a/colossalai/amp/naive_amp/__init__.py +++ b/colossalai/amp/naive_amp/__init__.py @@ -1,13 +1,12 @@ +import inspect import torch.nn as nn from torch.optim import Optimizer from colossalai.utils import is_no_pp_or_last_stage - from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel +from .grad_scaler import DynamicGradScaler, ConstantGradScaler -def convert_to_naive_amp(model: nn.Module, - optimizer: Optimizer, - amp_config): +def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config): """A helper function to wrap training components with naive AMP modules :param model: your model object @@ -31,7 +30,19 @@ def convert_to_naive_amp(model: nn.Module, output_to_fp32 = is_no_pp_or_last_stage() model = NaiveAMPModel(model, output_to_fp32=output_to_fp32) - optimizer = NaiveAMPOptimizer(optimizer, **amp_config) + use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True) + if use_dynamic_grad_scaler: + scaler_class = DynamicGradScaler + else: + scaler_class = ConstantGradScaler + + sig = inspect.signature(scaler_class.__init__) + kwargs = dict() + for param in sig.parameters.values(): + if param.name in amp_config: + kwargs[param.name] = amp_config.pop(param.name) + grad_scaler = scaler_class(**kwargs) + optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config) return model, optimizer diff --git a/colossalai/amp/naive_amp/_fp16_optimizer.py b/colossalai/amp/naive_amp/_fp16_optimizer.py index 01842590f..98bb1e639 100644 --- a/colossalai/amp/naive_amp/_fp16_optimizer.py +++ b/colossalai/amp/naive_amp/_fp16_optimizer.py @@ -2,6 +2,7 @@ # -*- encoding: utf-8 -*- import torch +import torch.distributed as dist try: import colossal_C @@ -9,41 +10,30 @@ except: print('Colossalai should be built with cuda extension to use the FP16 optimizer') from torch.optim import Optimizer - -from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode from colossalai.logging import get_dist_logger -from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes, - clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier) +from colossalai.utils import (copy_tensor_parallel_attributes, clip_grad_norm_fp32, multi_tensor_applier) +from torch.distributed import ProcessGroup +from .grad_scaler import BaseGradScaler +from ._utils import has_inf_or_nan, zero_gard_by_list - -def _zero_grad_group_helper(group, set_to_none): - """Zero out the gradient for a group of parameters. 
- Note: copied from torch.optim.optimizer.""" - for param in group: - if param.grad is not None: - if set_to_none: - param.grad = None - else: - if param.grad.grad_fn is not None: - param.grad.detach_() - else: - param.grad.requires_grad_(False) - param.grad.zero_() +__all__ = ['FP16Optimizer'] def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): - """Use multi-tensor-applier to copy values from one list to another. + """ + adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM) + + Use multi-tensor-applier to copy values from one list to another. We don't have a blfoat16 implementation so for now if the overflow_buf is not provided, we default back to simple loop copy to be compatible - with bfloat16.""" + with bfloat16. + """ if overflow_buf: overflow_buf.fill_(0) # Scaling with factor `1.0` is equivalent to copy. - multi_tensor_applier(colossal_C.multi_tensor_scale, - overflow_buf, - [this, that], - 1.0) + multi_tensor_applier(colossal_C.multi_tensor_scale, overflow_buf, [this, that], 1.0) else: for this_, that_ in zip(this, that): that_.copy_(this_) @@ -111,8 +101,7 @@ class DynamicGradScaler: self._hysteresis_tracker -= 1 # Now if we are out of hysteresis count, scale down the loss. if self._hysteresis_tracker <= 0: - self._scale = torch.max(self._scale * self.backoff_factor, - self.min_scale) + self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale) if self.verbose: self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0]) else: @@ -127,12 +116,13 @@ class DynamicGradScaler: if self._max_scale is not None and self._scale >= self._max_scale: if self.verbose: self._logger.info( - f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0]) + f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', + ranks=[0]) else: self._scale = self._scale * self.growth_factor if self.verbose: - self._logger.info( - f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0]) + self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}', + ranks=[0]) def state_dict(self): state_dict = {} @@ -173,326 +163,241 @@ class FP16Optimizer(Optimizer): """ def __init__(self, - optimizer, - clip_grad=0, - log_num_zeros_in_grad=False, - initial_scale=2 ** 32, - min_scale=1, - growth_factor=2, - backoff_factor=0.5, - growth_interval=1000, - hysteresis=2, - max_scale: int = 2 ** 32, - verbose: bool = False): - # default args for compatibility - bf16 = False - params_have_main_grad = False - + optimizer: Optimizer, + grad_scaler: BaseGradScaler, + verbose: bool = False, + clip_grad_norm=0, + dp_process_group: ProcessGroup = None, + mp_process_group: ProcessGroup = None): # have a defaults for compatibility with pytorch optim - self.defaults = optimizer.defaults + self._optimizer = optimizer + self._defaults = optimizer.defaults - # log config - self._logger = get_dist_logger() - if verbose: - self._logger.info(f"\n========= FP16 Optimizer Config =========\n" - f"Optimizer: {optimizer.__class__.__name__}\n" - f"clip_grad = {clip_grad}\n" - f"log_num_zeros_in_grad = {log_num_zeros_in_grad}\n" - f"initial_scale = {initial_scale}\n" - f"min_scale = {min_scale}\n" - f"growth_factor = {growth_factor}\n" - f"backoff_factor = {backoff_factor}\n" - f"growth_interval = {growth_interval}\n" - f"hysteresis = {hysteresis}\n" - f"==========================================", ranks=[0]) + # fp16-related params + assert isinstance(grad_scaler, 
BaseGradScaler) + self._grad_scaler = grad_scaler + self._found_overflow = torch.cuda.FloatTensor([0.0]) + self._dummy_overflow_buf = torch.cuda.IntTensor([0]) - """Input optimizer is the base optimizer for example Adam.""" - self.optimizer = optimizer - assert self.optimizer, 'no optimizer is provided.' - # Set gradient clipping and logging params. - self.clip_grad = clip_grad - self.log_num_zeros_in_grad = log_num_zeros_in_grad - self.params_have_main_grad = params_have_main_grad + # misc params + self._clip_grad_max_norm = clip_grad_norm - self.bf16 = bf16 - self.grad_scaler = DynamicGradScaler( - initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale, - verbose=verbose - ) + # get process group + def _get_process_group(parallel_mode): + if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA): + return gpc.get_group(ParallelMode.DATA) + else: + return None - # None grad scaler is only supported for bf16. - if self.grad_scaler is None: - assert self.bf16, 'fp16 expects a grad scaler.' + if dp_process_group is None: + dp_process_group = _get_process_group(ParallelMode.DATA) + if mp_process_group is None: + mp_process_group = _get_process_group(ParallelMode.MODEL) - # Tensor used to determine if a nan/if has happend. - # Any non-zero value indicates inf/nan. - # Note that we keep this for the cases that grad scaler is none. - # We still record nan/inf if we have a bfloat16 with a grad scaler. - if self.grad_scaler: - self.found_inf = torch.cuda.FloatTensor([0.0]) + self._dp_process_group = dp_process_group + self._mp_process_group = mp_process_group - # Dummy tensor needed for apex multi-apply tensor. - # For bfloat, we don't have multi-tensor apply and for now - # we set it to none so the multi-tensor apply gets ignored. - if bf16: - self._dummy_overflow_buf = None - else: - self._dummy_overflow_buf = torch.cuda.IntTensor([0]) - - # In case grad scaler is not passed, define the unity scale. - if self.grad_scaler is None: - self._scale_one = torch.cuda.FloatTensor([1.0]) - - # ====================== - # main parameter stuff - # ====================== - - # Three groups of parameters: - # float16_groups: original float16 parameters - # fp32_from_float16_groups: fp32 copy of float16 parameters - # fp32_from_fp32_groups: original fp32 parameters - self.float16_groups = [] - self.fp32_from_float16_groups = [] - self.fp32_from_fp32_groups = [] + # we maintain three groups of parameters + # so that the model can have a mixture + # of fp16 and fp32 params + # fp16_param_groups: the fp16 params of the model + # fp32_master_param_groups: the fp32 params cast from the fp16 param of the model + # fp32_param_groups: the fp32 params of the model + # NOTE: + # 1. fp16_param_groups and fp32_master_param_groups have one-to-one correspondence + # 2. 
fp32_param_groups and fp16_param_groups are exclusive of each other + self._fp16_param_groups = [] + self._fp32_master_param_groups = [] + self._fp32_param_groups = [] # For all the groups in the original optimizer: - for param_group in self.optimizer.param_groups: - float16_params_this_group = [] - fp32_params_this_group = [] - fp32_from_float16_params_this_group = [] + for param_group in self._optimizer.param_groups: + fp16_params = [] + fp32_master_params = [] + fp32_params = [] # For all the parameters in this group: for i, param in enumerate(param_group['params']): if param.requires_grad: # float16 params: - if param.type() in ['torch.cuda.HalfTensor', - 'torch.cuda.BFloat16Tensor']: - float16_params_this_group.append(param) - # Create a copy - main_param = param.detach().clone().float() - # Copy tensor model parallel attributes. - copy_tensor_parallel_attributes(param, main_param) + if param.type() in ['torch.cuda.HalfTensor']: + fp16_params.append(param) - # if hasattr(param, 'shared'): - # main_param.shared = param.shared + # Create a fp32 copy + fp32_param = param.detach().clone().float() + # Copy tensor model parallel attributes. + copy_tensor_parallel_attributes(param, fp32_param) # Replace the optimizer params with the new fp32 copy. - param_group['params'][i] = main_param - fp32_from_float16_params_this_group.append(main_param) + param_group['params'][i] = fp32_param + fp32_master_params.append(fp32_param) + # Reset existing state dict key to the new main param. - if param in self.optimizer.state: - self.optimizer.state[main_param] \ - = self.optimizer.state.pop(param) + if param in self._optimizer.state: + self._optimizer.state[fp32_param] = self._optimizer.state.pop(param) # fp32 params. elif param.type() == 'torch.cuda.FloatTensor': - fp32_params_this_group.append(param) - param_group['params'][i] = param + fp32_params.append(param) else: - raise TypeError('Wrapped parameters must be one of ' - 'torch.cuda.FloatTensor, ' - 'torch.cuda.HalfTensor, or ' - 'torch.cuda.BFloat16Tensor. 
' - 'Received {}'.format(param.type())) + raise TypeError('Expected parameter of type torch.cuda.FloatTensor ' + f'or torch.cuda.HalfTensor, but got {param.type()}') - self.float16_groups.append(float16_params_this_group) - self.fp32_from_float16_groups.append( - fp32_from_float16_params_this_group) - self.fp32_from_fp32_groups.append(fp32_params_this_group) + self._fp16_param_groups.append(fp16_params) + self._fp32_master_param_groups.append(fp32_master_params) + self._fp32_param_groups.append(fp32_params) # Leverage state_dict() and load_state_dict() to # recast preexisting per-param state tensors - self.optimizer.load_state_dict(self.optimizer.state_dict()) + self._optimizer.load_state_dict(self._optimizer.state_dict()) - def zero_grad(self, set_to_none=False): - """We only need to zero the model related parameters, i.e., - float16_groups & fp32_from_fp32_groups.""" - for group in self.float16_groups: - _zero_grad_group_helper(group, set_to_none) - for group in self.fp32_from_fp32_groups: - _zero_grad_group_helper(group, set_to_none) + # log config + self._logger = get_dist_logger() + if verbose: + self._logger.info( + f"\n========= FP16 Optimizer Config =========\n" + f"Optimizer: {optimizer.__class__.__name__}\n" + f"clip_grad_norm = {clip_grad_norm}\n" + f"grad_scaler = {self._grad_scaler.__class__.__name__}" + f"==========================================", + ranks=[0]) - def get_loss_scale(self): - if self.grad_scaler is None: - return self._scale_one - return self.grad_scaler.scale + @property + def grad_scaler(self): + return self._grad_scaler - def _copy_model_grads_to_main_grads(self): + @property + def loss_scale(self): + return self._grad_scaler.scale + + @property + def optimizer(self): + return self._optimizer + + @property + def defaults(self): + return self._defaults + + def _check_overflow(self): + # clear previous overflow record + self._found_overflow.fill_(0.0) + + # check for overflow + for group in self._optimizer.param_groups: + for p in group['params']: + if has_inf_or_nan(p.grad): + self._found_overflow.fill_(1.0) + break + + # all-reduce across dp group + if self._dp_process_group: + dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_process_group) + + # all-reduce over model parallel group + if self._mp_process_group: + dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_process_group) + + return self._found_overflow.item() > 0 + + def zero_grad(self, set_to_none=True): + # set_to_none = True can save some memory space + for param_group in self._optimizer.param_groups: + zero_gard_by_list(param_group['params'], set_to_none=set_to_none) + + def _get_fp32_param_groups_to_update(self): + return self._fp32_master_param_groups + self._fp32_param_groups + + def _unscale_grads(self): + for group in self._get_fp32_param_groups_to_update(): + for p in group: + if p.grad is not None: + p.grad.data.div_(self.loss_scale) + + def _assign_grad_to_fp32_master_param(self): # This only needs to be done for the float16 group. 
- for model_group, main_group in zip(self.float16_groups, - self.fp32_from_float16_groups): - for model_param, main_param in zip(model_group, main_group): - if self.params_have_main_grad: - main_param.grad = model_param.main_grad.float() - else: - if model_param.grad is not None: - main_param.grad = model_param.grad.float() + for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups, self._fp32_master_param_groups): + for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group): + fp32_param.grad = fp16_param.grad.float() + # clear unneeded grad on fp16 param + fp16_param.grad = None - # For fp32 grads, we need to reset the grads to main grad. - if self.params_have_main_grad: - for model_group in self.fp32_from_fp32_groups: - for model_param in model_group: - model_param.grad = model_param.main_grad - - def _unscale_main_grads_and_check_for_nan(self): - main_grads = [] - # fp32 params fromm float16 ones. - for main_group in self.fp32_from_float16_groups: - for main_param in main_group: - if main_param.grad is not None: - main_grads.append(main_param.grad.data) - # Append fp32 parameters. - for main_group in self.fp32_from_fp32_groups: - for main_param in main_group: - if main_param.grad is not None: - main_grads.append(main_param.grad.data) - # Reset found inf. - self.found_inf.fill_(0.0) - # Unscale and set found inf/nan - torch._amp_foreach_non_finite_check_and_unscale_( - main_grads, self.found_inf, self.grad_scaler.inv_scale) - # Update across all model parallel instances. - torch.distributed.all_reduce(self.found_inf, - op=torch.distributed.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.MODEL)) - - # Check for nan. - found_inf_flag = (self.found_inf.item() > 0) - return found_inf_flag - - def _get_model_and_main_params_data_float16(self): - model_data = [] - main_data = [] - for model_group, main_group in zip(self.float16_groups, - self.fp32_from_float16_groups): - for model_param, main_param in zip(model_group, main_group): - model_data.append(model_param.data) - main_data.append(main_param.data) - return model_data, main_data - - def _copy_main_params_to_model_params(self): - # Only needed for the float16 params. - model_data, main_data = self._get_model_and_main_params_data_float16() - _multi_tensor_copy_this_to_that(this=main_data, that=model_data, + def _update_fp16_param_from_fp32_param(self): + fp16_param_data = [] + fp32_master_param_data = [] + for fp16_group, fp32_group in zip(self._fp16_param_groups, self._fp32_master_param_groups): + for fp16_param, fp32_param in zip(fp16_group, fp32_group): + fp16_param_data.append(fp16_param.data) + fp32_master_param_data.append(fp32_param.data) + _multi_tensor_copy_this_to_that(this=fp32_master_param_data, + that=fp16_param_data, overflow_buf=self._dummy_overflow_buf) - def _copy_model_params_to_main_params(self): - # Only needed for the float16 params. - model_data, main_data = self._get_model_and_main_params_data_float16() - _multi_tensor_copy_this_to_that(this=model_data, that=main_data, - overflow_buf=self._dummy_overflow_buf) - - def reload_model_params(self): - self._copy_model_params_to_main_params() - - @torch.no_grad() def step(self): # Copy gradients from model params to main params. - self._copy_model_grads_to_main_grads() + self._assign_grad_to_fp32_master_param() + self._unscale_grads() - # Do unscale, check for inf, and update grad scaler only for - # the case that grad scaler is provided. 
- if self.grad_scaler: + overflow = self._check_overflow() + self._grad_scaler.update(overflow) - # Unscale and check for inf/nan. - found_inf_flag = self._unscale_main_grads_and_check_for_nan() - - # We are done with scaling gradients - # so we can update the loss scale. - self.grad_scaler.update(found_inf_flag) - - # If we found inf/nan, skip the update. - if found_inf_flag: - return False, None, None + if overflow: + self.zero_grad() + return False, None # Clip the main gradients. grad_norm = None - if self.clip_grad > 0.0: - grad_norm = self.clip_grad_norm(self.clip_grad) - - # count the zeros in the grads - num_zeros_in_grad = self.count_zeros() if \ - self.log_num_zeros_in_grad else None + if self._clip_grad_max_norm > 0.0: + grad_norm = self.clip_grad_norm(self._clip_grad_max_norm) # Step the optimizer. - self.optimizer.step() + self._optimizer.step() # Update params from main params. - self._copy_main_params_to_model_params() + self._update_fp16_param_from_fp32_param() # Successful update. - return True, grad_norm, num_zeros_in_grad + return True, grad_norm + + def backward(self, loss): + scaled_loss = loss * self.grad_scaler.scale + scaled_loss.backward() def state_dict(self): state_dict = {} - state_dict['optimizer'] = self.optimizer.state_dict() + state_dict['optimizer'] = self._optimizer.state_dict() if self.grad_scaler: state_dict['grad_scaler'] = self.grad_scaler.state_dict() - state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups + state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups return state_dict def load_state_dict(self, state_dict): # Optimizer. - optimizer_key = 'optimizer' - if optimizer_key not in state_dict: - optimizer_key = 'optimizer_state_dict' - print_rank_0('***WARNING*** loading optimizer from ' - 'an old checkpoint ...') - self.optimizer.load_state_dict(state_dict[optimizer_key]) + self._optimizer.load_state_dict(state_dict['optimizer']) # Grad scaler. - if 'grad_scaler' not in state_dict: - print_rank_0('***WARNING*** found an old checkpoint, will not ' - 'load grad scaler ...') - else: - if self.grad_scaler: - self.grad_scaler.load_state_dict(state_dict['grad_scaler']) - else: - print_rank_0('***WARNING*** fould the grad scaler in the ' - 'checkpoint but it is None in the class. ' - 'Skipping loading grad scaler ...') + if 'grad_scaler' in state_dict: + self.grad_scaler.load_state_dict(state_dict['grad_scaler']) # Copy data for the main params. 
- fp32_from_float16_params_key = 'fp32_from_fp16_params' - if fp32_from_float16_params_key not in state_dict: - fp32_from_float16_params_key = 'fp32_from_fp16' - for current_group, saved_group in zip( - self.fp32_from_float16_groups, - state_dict[fp32_from_float16_params_key]): - for current_param, saved_param in zip(current_group, saved_group): - current_param.data.copy_(saved_param.data) - - def get_parameters(self): - params = [] - for param_group in self.optimizer.param_groups: - for param in param_group['params']: - params.append(param) - return params + if 'fp32_master_param_groups' in state_dict: + for current_group, ckpt_group in zip(self._fp32_master_param_groups, + state_dict['fp32_master_param_groups']): + for current_param, ckpt_param in zip(current_group, ckpt_group): + current_param.data.copy_(ckpt_param.data) def clip_grad_norm(self, clip_grad): - params = self.get_parameters() + params = [] + for param_group in self._optimizer.param_groups: + for param in param_group['params']: + params.append(param) return clip_grad_norm_fp32(params, clip_grad) - def count_zeros(self): - params = self.get_parameters() - return count_zeros_fp32(params) - - def scale_loss(self, loss): - """Simple scaling.""" - return self.get_loss_scale() * loss - # Promote state so it can be retrieved or set via # "optimizer_instance.state" def _get_state(self): - return self.optimizer.state + return self._optimizer.state def _set_state(self, value): - self.optimizer.state = value + self._optimizer.state = value state = property(_get_state, _set_state) @@ -500,9 +405,9 @@ class FP16Optimizer(Optimizer): # "optimizer_instance.param_groups" # (for example, to adjust the learning rate) def _get_param_groups(self): - return self.optimizer.param_groups + return self._optimizer.param_groups def _set_param_groups(self, value): - self.optimizer.param_groups = value + self._optimizer.param_groups = value param_groups = property(_get_param_groups, _set_param_groups) diff --git a/colossalai/amp/naive_amp/_utils.py b/colossalai/amp/naive_amp/_utils.py new file mode 100644 index 000000000..5d87135a8 --- /dev/null +++ b/colossalai/amp/naive_amp/_utils.py @@ -0,0 +1,40 @@ +from typing import List +from torch import Tensor + + +def has_inf_or_nan(tensor): + try: + # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as tensor + # (which is true for some recent version of pytorch). + tensor_sum = float(tensor.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # tensor_sum = float(tensor.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum: + return True + return False + + +def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None: + """ + Clear the gradient of a list of tensors, + Note: copied from torch.optim.optimizer. 
+ """ + for param in tensor_list: + if param.grad is not None: + if set_to_none: + param.grad = None + else: + if param.grad.grad_fn is not None: + param.grad.detach_() + else: + param.grad.requires_grad_(False) + param.grad.zero_() diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py index fb279baf6..2d3e3700d 100644 --- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py @@ -28,12 +28,10 @@ class BaseGradScaler(ABC): def inv_scale(self) -> Tensor: return self._scale.double().reciprocal().float() - @abstractmethod def state_dict(self) -> Dict: state_dict = dict() state_dict['scale'] = self.scale - @abstractmethod def load_state_dict(self, state_dict: Dict) -> None: self._scale = state_dict['scale'] diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py index 79fd0f3a3..49f155f06 100644 --- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py @@ -16,11 +16,19 @@ class DynamicGradScaler(BaseGradScaler): growth_interval: int = 1000, min_scale: int = None, max_scale: int = None, - hysteresis: int = None, + hysteresis: int = 2, verbose: bool = False): super().__init__(initial_scale, verbose) - self._min_scale = min_scale - self._max_scale = max_scale + if min_scale: + self._min_scale = torch.cuda.FloatTensor([min_scale]) + else: + self._min_scale = None + + if max_scale: + self._max_scale = torch.cuda.FloatTensor([max_scale]) + else: + self._max_scale = None + self._growth_factor = growth_factor self._backoff_factor = backoff_factor self._growth_interval = growth_interval diff --git a/colossalai/amp/naive_amp/naive_amp.py b/colossalai/amp/naive_amp/naive_amp.py index c4e950f68..1ee34931f 100644 --- a/colossalai/amp/naive_amp/naive_amp.py +++ b/colossalai/amp/naive_amp/naive_amp.py @@ -26,17 +26,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer): """ def __init__(self, optim: Optimizer, *args, **kwargs): - optim = FP16Optimizer(optimizer=optim, *args, **kwargs) + optim = FP16Optimizer(optim, *args, **kwargs) super().__init__(optim) def backward(self, loss: Tensor): - """Backward with gradient scaler - - :param loss: loss computed by a loss function - :type loss: torch.Tensor - """ - loss = self.optim.scale_loss(loss) - loss.backward() + self.optim.backward(loss) def step(self): return self.optim.step() diff --git a/colossalai/initialize.py b/colossalai/initialize.py index d87f9658b..011859881 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -304,7 +304,7 @@ def initialize(model: nn.Module, if is_using_pp(): assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently' if amp_mode == AMP_TYPE.NAIVE: - cfg_['clip_grad'] = clip_grad_norm + cfg_['clip_grad_norm'] = clip_grad_norm model, optimizer, criterion = convert_to_amp(model=model, optimizer=optimizer, criterion=criterion, diff --git a/colossalai/zero/sharded_optim/sharded_optim.py b/colossalai/zero/sharded_optim/sharded_optim.py index 9dff355db..2ea2feaf6 100644 --- a/colossalai/zero/sharded_optim/sharded_optim.py +++ b/colossalai/zero/sharded_optim/sharded_optim.py @@ -1,4 +1,3 @@ -from itertools import groupby from colossalai.utils.cuda import get_current_device import torch import torch.distributed as dist @@ -7,7 +6,7 @@ from torch.optim import Optimizer from .bookkeeping import ParameterStore, 
GradientStore, BucketStore, TensorBucket from colossalai.context import ParallelMode from colossalai.core import global_context as gpc -from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler +from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.nn.optimizer import ColossalaiOptimizer from ._utils import (move_tensor, flatten, get_grad_accumulate_object, split_half_float_double, reduce_tensor, release_param_grad, calculate_global_norm_from_list, compute_norm, sync_param, has_inf_or_nan) @@ -16,38 +15,26 @@ from functools import partial class ShardedOptimizer(ColossalaiOptimizer): - def __init__( - self, - optimizer: Optimizer, - - # grad scaler config - initial_scale=2**32, - min_scale=1, - growth_factor=2, - backoff_factor=0.5, - growth_interval=1000, - hysteresis=2, - max_scale: int = 2**32, - - # grad clipping - clip_grad_norm=2.0, - verbose=False, - - # communication - reduce_bucket_size=500000000, - communication_dtype=torch.float16, - overlap_communication=False, - - # stage 2 - partition_grad=False, - - dp_parallel_mode=ParallelMode.DATA, - mp_parallel_mode=ParallelMode.MODEL, - - # cpu offload - cpu_offload=False, - cpu_fp16_param=False, - cpu_fp16_grad=False): + def __init__(self, + optimizer: Optimizer, + initial_scale=2**32, + min_scale=1, + growth_factor=2, + backoff_factor=0.5, + growth_interval=1000, + hysteresis=2, + max_scale: int = 2**32, + clip_grad_norm=2.0, + verbose=False, + reduce_bucket_size=500000000, + communication_dtype=torch.float16, + overlap_communication=False, + partition_grad=False, + dp_parallel_mode=ParallelMode.DATA, + mp_parallel_mode=ParallelMode.MODEL, + cpu_offload=False, + cpu_fp16_param=False, + cpu_fp16_grad=False): # TODO: add support for # 1. fp16 master weights @@ -257,12 +244,13 @@ class ShardedOptimizer(ColossalaiOptimizer): reduction_func = partial(self._reduce_and_remove_grads_by_bucket, param=param, reduce_rank=reduce_rank) - + # define hook # NOT IMPORTANT BUT GOOD TO KNOW: # args here is not grad, but allow_unreacable and accumulate_grad def reduce_grad_hook(*args): reduction_func() + accum_grad_obj.register_hook(reduce_grad_hook) _define_and_attach(param, reduce_rank) @@ -293,8 +281,8 @@ class ShardedOptimizer(ColossalaiOptimizer): def _reduce_grads_in_bucket(self, reduce_rank=None): # reduce grads self._reduce_grads_by_rank(reduce_rank=reduce_rank, - grads=self._bucket_store.get_grad(reduce_rank=reduce_rank), - bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank)) + grads=self._bucket_store.get_grad(reduce_rank=reduce_rank), + bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank)) # use communication stream if overlapping # communication with computation @@ -323,7 +311,7 @@ class ShardedOptimizer(ColossalaiOptimizer): # we do not keep the gradient after reduction if self._partition_grads and not self._param_store.belongs_to_current_rank(param): if self._overlap_communication: - # we need to keep this gradient for now as reduction may + # we need to keep this gradient for now as reduction may # be completed yet since it is using a different cuda stream self._param_store.add_previous_reduced_param(param) else: @@ -444,7 +432,6 @@ class ShardedOptimizer(ColossalaiOptimizer): self._grad_store._averaged_gradients[group_id] = [] self._grad_store._averaged_gradients[group_id] = [] - # unscale and clip grads global_norm = calculate_global_norm_from_list(norm_list=norm_groups) self._unscale_and_clip_grads(single_grad_partition_groups, global_norm) @@ -501,7 +488,7 @@ class 
ShardedOptimizer(ColossalaiOptimizer): def _unscale_and_clip_grads(self, grad_groups_flat, total_norm): # compute combined scale factor for this group combined_scale = self.loss_scale - + if self._clip_grad_norm > 0.: # norm is in fact norm*scale clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm @@ -562,7 +549,7 @@ class ShardedOptimizer(ColossalaiOptimizer): for param in param_group: if param.grad is not None: self._reduce_and_remove_grads_by_bucket(param) - + # we need to reduce the gradients # left in the communication bucket self._reduce_grads_in_bucket() diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py index d78ac3ecc..14b670a88 100644 --- a/colossalai/zero/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py @@ -4,7 +4,7 @@ from typing import Callable, Dict, Optional, Union import torch import torch.distributed as dist import torch.nn as nn -from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler +from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.nn.optimizer import ColossalaiOptimizer diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_amp/test_naive_fp16.py new file mode 100644 index 000000000..c777d1587 --- /dev/null +++ b/tests/test_amp/test_naive_fp16.py @@ -0,0 +1,83 @@ +import torch +import colossalai +import copy +import pytest +import torch.multiprocessing as mp +from colossalai.amp import convert_to_naive_amp +from tests.components_to_test.registry import non_distributed_component_funcs +from colossalai.utils import free_port +from functools import partial + + +def check_equal(a, b): + """ + This function checks if two tensors are equal within tolerance + """ + assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}' + + +def run_naive_amp(): + """ + In this test, we compare the naive fp16 optimizer implemented in colossalai + and fp32 torch optimizer + """ + + # create layer + test_models = ['repeated_computed_layers', 'nested_model'] + for test_name in test_models: + get_component_func = non_distributed_component_funcs.get_callable(test_name) + model_builder, train_dataloader, _, optim_builder, _ = get_component_func() + + # create model + amp_model = model_builder(checkpoint=True).cuda() + torch_model = copy.deepcopy(amp_model) + + # create optimizer + amp_optimizer = optim_builder(amp_model) + torch_optimizer = optim_builder(torch_model) + + # inject naive amp + amp_config = dict(initial_scale=1) + amp_model, amp_optimizer = convert_to_naive_amp(amp_model, amp_optimizer, amp_config) + + # create data + data_iter = iter(train_dataloader) + data, label = next(data_iter) + data = data.cuda() + + # forward pass + amp_output = amp_model(data) + torch_output = torch_model(data) + assert torch.allclose(amp_output, torch_output, rtol=1e-3, atol=1e-3), f'{amp_output} vs {torch_output}' + + # backward + amp_optimizer.backward(amp_output.mean()) + torch_output.mean().backward() + + # check grad + for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()): + torch.allclose(amp_param.grad, torch_param.grad.half(), rtol=1e-3, atol=1e-3) + + # step + amp_optimizer.step() + torch_optimizer.step() + + # check updated param + for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()): + torch.allclose(amp_param, 
torch_param.half(), rtol=1e-3, atol=1e-3) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + run_naive_amp() + + +@pytest.mark.dist +def test_naive_amp(): + world_size = 1 + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_naive_amp()
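
Usage sketch (not part of the patch): the new test above already exercises the refactored API end to end; the fragment below condenses it into a minimal single-process training step. The Linear model, SGD optimizer, learning rate, and random batch are illustrative placeholders; the amp_config keys (dynamic_grad_scale, initial_scale, clip_grad_norm), the convert_to_naive_amp call, and the launch arguments are taken from this patch. A single CUDA device and a ColossalAI build with the CUDA extension (needed by the multi-tensor kernels) are assumed.

    import torch
    import torch.nn as nn
    import colossalai
    from colossalai.amp import convert_to_naive_amp
    from colossalai.utils import free_port

    # initialize the distributed context the same way the new test's run_dist does (single process)
    colossalai.launch(config=dict(), rank=0, world_size=1, host='localhost', port=free_port())

    # placeholder model and optimizer; convert_to_naive_amp wraps them with
    # NaiveAMPModel / NaiveAMPOptimizer (which holds the refactored FP16Optimizer)
    model = nn.Linear(16, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # dynamic_grad_scale selects DynamicGradScaler (default) vs ConstantGradScaler;
    # keys matching the scaler's __init__ signature go to the scaler,
    # the remaining ones (e.g. clip_grad_norm) are forwarded to FP16Optimizer
    amp_config = dict(dynamic_grad_scale=True, initial_scale=2**16, clip_grad_norm=1.0)
    model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)

    data = torch.randn(8, 16, device='cuda')
    loss = model(data).mean()

    optimizer.zero_grad()
    optimizer.backward(loss)               # multiplies the loss by the current loss scale before backward
    success, grad_norm = optimizer.step()  # returns (False, None) and skips the update when overflow is detected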