#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Optional

import torch

try:
    import colossal_C
except ImportError:
    # Best effort: the module still imports without the CUDA extension, but
    # the fused multi-tensor copy path will not be usable.
    print('Colossalai should be built with cuda extension to use the FP16 optimizer')

from torch.optim import Optimizer

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
                              clip_grad_norm_fp32, count_zeros_fp32,
                              multi_tensor_applier)


def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters.

    Note: copied from torch.optim.optimizer.

    Args:
        group: iterable of parameters whose ``.grad`` should be cleared.
        set_to_none: if true, ``.grad`` is set to ``None``; otherwise the
            gradient is detached from the graph (or marked as not requiring
            grad) and zeroed in place.
    """
    for param in group:
        if param.grad is None:
            continue
        if set_to_none:
            param.grad = None
            continue
        if param.grad.grad_fn is not None:
            param.grad.detach_()
        else:
            param.grad.requires_grad_(False)
        param.grad.zero_()


def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
    """Copy the values of every tensor in ``this`` into the matching tensor in ``that``.

    Uses multi-tensor-applier when an overflow buffer is given. We don't have
    a bfloat16 implementation, so if ``overflow_buf`` is not provided we fall
    back to a simple loop copy to stay compatible with bfloat16.

    Args:
        this: list of source tensors.
        that: list of destination tensors, paired positionally with ``this``.
        overflow_buf: optional buffer handed to ``multi_tensor_applier``.
    """
    # NOTE(review): this is a truthiness test on the tensor itself, so a
    # one-element buffer holding 0 evaluates to False and the fallback loop
    # below is taken. Confirm whether `overflow_buf is not None` was intended
    # before changing it -- the fused kernel path may not support every dtype
    # accepted by FP16Optimizer.
    if overflow_buf:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(colossal_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
    else:
        for this_, that_ in zip(this, that):
            that_.copy_(this_)


class DynamicGradScaler:
    """Grad scaler with a dynamic scale that gets adjusted during training.

    The scale is multiplied by ``backoff_factor`` (bounded below by
    ``min_scale``) once ``hysteresis`` overflows have been observed, and
    multiplied by ``growth_factor`` (unless ``max_scale`` has been reached)
    after ``growth_interval`` consecutive updates without an overflow.
    """

    def __init__(self,
                 initial_scale,
                 min_scale,
                 growth_factor,
                 backoff_factor,
                 growth_interval,
                 hysteresis,
                 max_scale: Optional[int] = None):
        """Validate the configuration and allocate the scale tensors on the GPU.

        Args:
            initial_scale: starting loss scale; must be > 0.
            min_scale: lower bound on the scale; must be > 0 and
                <= ``initial_scale``.
            growth_factor: multiplier applied when growing; must be > 1.
            backoff_factor: multiplier applied when backing off; must lie
                strictly between 0 and 1.
            growth_interval: number of consecutive overflow-free updates
                required before the scale grows; must be > 0.
            hysteresis: number of overflows tolerated before the scale is
                reduced; must be > 0.
            max_scale: optional upper bound on the scale; when given it must
                be > 1 and >= ``initial_scale``.
        """
        assert initial_scale > 0.0
        self._scale = torch.cuda.FloatTensor([initial_scale])

        # Lower bound on the scale.
        assert min_scale > 0.0
        assert min_scale <= initial_scale
        self.min_scale = torch.cuda.FloatTensor([min_scale])

        # Growth and backoff factors for the scale.
        assert growth_factor > 1.0
        self.growth_factor = torch.cuda.FloatTensor([growth_factor])
        assert backoff_factor < 1.0
        assert backoff_factor > 0.0
        self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])

        # Interval over which if we don't see any inf/nan,
        # we will scale the grad scale by the growth factor.
        assert growth_interval > 0
        self.growth_interval = growth_interval

        # Number of inf/nans we should see before scaling down
        # the grad scale by the backoff factor.
        assert hysteresis > 0
        self.hysteresis = hysteresis

        if max_scale is not None:
            assert max_scale > 1 and initial_scale <= max_scale
        # Always assigned (possibly to None) because update() and
        # state_dict() read it unconditionally.
        self._max_scale = max_scale

        # Trackers.
        self._growth_tracker = 0
        self._hysteresis_tracker = self.hysteresis

        self._logger = get_dist_logger()

    @property
    def scale(self):
        """Current loss scale as a one-element tensor."""
        return self._scale

    @property
    def inv_scale(self):
        """Reciprocal of the current scale, computed in float64 and cast back to float32."""
        return self._scale.double().reciprocal().float()

    def update(self, found_inf):
        """Adjust the scale after one step.

        Args:
            found_inf: truthy if an inf/nan was found in the gradients of
                this step.
        """
        # If we have an inf/nan, growth tracker is set to 0
        # and hysteresis tracker is reduced by 1.
        if found_inf:
            self._growth_tracker = 0
            self._hysteresis_tracker -= 1
            # Now if we are out of hysteresis count, scale down the loss.
            if self._hysteresis_tracker <= 0:
                self._scale = torch.max(self._scale * self.backoff_factor,
                                        self.min_scale)
                self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
        else:
            # If there is no nan/inf, increment the growth tracker.
            self._growth_tracker += 1
            # If we have had enough consecutive intervals with no nan/inf:
            if self._growth_tracker == self.growth_interval:
                # Reset the tracker and hysteresis trackers,
                self._growth_tracker = 0
                self._hysteresis_tracker = self.hysteresis
                # and scale up the loss scale.
                if self._max_scale is not None and self._scale >= self._max_scale:
                    self._logger.info(
                        f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed',
                        ranks=[0])
                else:
                    self._scale = self._scale * self.growth_factor
                    self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])

    def state_dict(self):
        """Return the scale, its upper bound and both trackers as a dict."""
        state_dict = {}
        state_dict['max_scale'] = self._max_scale
        state_dict['scale'] = self._scale
        state_dict['growth_tracker'] = self._growth_tracker
        state_dict['hysteresis_tracker'] = self._hysteresis_tracker
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore the state produced by :meth:`state_dict`.

        The saved scale tensor is moved to the current CUDA device.
        """
        self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
        self._growth_tracker = state_dict['growth_tracker']
        self._hysteresis_tracker = state_dict['hysteresis_tracker']
        self._max_scale = state_dict['max_scale']


class FP16Optimizer(Optimizer):
    """Float16 optimizer wrapper that keeps fp32 main copies of half-precision parameters.

    The wrapped optimizer steps on fp32 "main" parameters; after each
    successful step the main parameters are copied back into the fp16/bf16
    model parameters. Gradients are unscaled with a :class:`DynamicGradScaler`
    and the step is skipped when an inf/nan is found.

    This wrapper always runs with ``bf16 = False`` and
    ``params_have_main_grad = True``, i.e. gradients are read from each model
    parameter's ``main_grad`` attribute rather than from ``grad``.

    Arguments:
        optimizer: base optimizer such as Adam or SGD.
        clip_grad: clip gradients with this global L2 norm. Note that
            clipping is ignored if clip_grad == 0.
        log_num_zeros_in_grad: if true, ``step`` also returns the number of
            zeros in the gradients.
        initial_scale: initial loss scale of the grad scaler.
        min_scale: lower bound of the loss scale.
        growth_factor: factor by which the loss scale grows.
        backoff_factor: factor by which the loss scale shrinks.
        growth_interval: number of consecutive overflow-free steps before
            the loss scale grows.
        hysteresis: number of overflows tolerated before the loss scale
            shrinks.
        max_scale: upper bound of the loss scale.
    """

    def __init__(self,
                 optimizer,
                 clip_grad=0,
                 log_num_zeros_in_grad=False,
                 initial_scale=2 ** 32,
                 min_scale=1,
                 growth_factor=2,
                 backoff_factor=0.5,
                 growth_interval=1000,
                 hysteresis=2,
                 max_scale: int = 2 ** 32):
        # Validate before the optimizer is first dereferenced below.
        assert optimizer, 'no optimizer is provided.'

        # default args for compatibility
        bf16 = False
        params_have_main_grad = True

        # have a defaults for compatibility with pytorch optim
        self.defaults = optimizer.defaults

        # log config
        self._logger = get_dist_logger()
        self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
                          f"Optimizer: {optimizer.__class__.__name__}\n"
                          f"clip_grad = {clip_grad}\n"
                          f"log_num_zeros_in_grad = {log_num_zeros_in_grad}\n"
                          f"initial_scale = {initial_scale}\n"
                          f"min_scale = {min_scale}\n"
                          f"growth_factor = {growth_factor}\n"
                          f"backoff_factor = {backoff_factor}\n"
                          f"growth_interval = {growth_interval}\n"
                          f"hysteresis = {hysteresis}\n"
                          f"==========================================", ranks=[0])

        # Input optimizer is the base optimizer, for example Adam.
        self.optimizer = optimizer

        # Set gradient clipping and logging params.
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad
        self.bf16 = bf16
        self.grad_scaler = DynamicGradScaler(
            initial_scale=initial_scale,
            min_scale=min_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            max_scale=max_scale
        )

        # None grad scaler is only supported for bf16.
        if self.grad_scaler is None:
            assert self.bf16, 'fp16 expects a grad scaler.'

        # Tensor used to determine if a nan/inf has happened.
        # Any non-zero value indicates inf/nan.
        # Note that we keep this for the cases that grad scaler is none.
        # We still record nan/inf if we have a bfloat16 with a grad scaler.
        if self.grad_scaler:
            self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-apply tensor.
        # For bfloat, we don't have multi-tensor apply and for now
        # we set it to none so the multi-tensor apply gets ignored.
        if bf16:
            self._dummy_overflow_buf = None
        else:
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # In case grad scaler is not passed, define the unity scale.
        if self.grad_scaler is None:
            self._scale_one = torch.cuda.FloatTensor([1.0])

        # ======================
        # main parameter stuff
        # ======================

        # Three groups of parameters:
        #   float16_groups: original float16 parameters
        #   fp32_from_float16_groups: fp32 copy of float16 parameters
        #   fp32_from_fp32_groups: original fp32 parameters
        self.float16_groups = []
        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        # For all the groups in the original optimizer:
        for param_group in self.optimizer.param_groups:
            float16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_float16_params_this_group = []
            # For all the parameters in this group:
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:
                    # float16 params:
                    if param.type() in ['torch.cuda.HalfTensor',
                                        'torch.cuda.BFloat16Tensor']:
                        float16_params_this_group.append(param)
                        # Create a copy
                        main_param = param.detach().clone().float()
                        # Copy tensor model parallel attributes.
                        copy_tensor_parallel_attributes(param, main_param)

                        # if hasattr(param, 'shared'):
                        #     main_param.shared = param.shared

                        # Replace the optimizer params with the new fp32 copy.
                        param_group['params'][i] = main_param
                        fp32_from_float16_params_this_group.append(main_param)
                        # Reset existing state dict key to the new main param.
                        if param in self.optimizer.state:
                            self.optimizer.state[main_param] \
                                = self.optimizer.state.pop(param)

                    # fp32 params.
                    elif param.type() == 'torch.cuda.FloatTensor':
                        fp32_params_this_group.append(param)
                        param_group['params'][i] = param
                    else:
                        raise TypeError('Wrapped parameters must be one of '
                                        'torch.cuda.FloatTensor,  '
                                        'torch.cuda.HalfTensor, or '
                                        'torch.cuda.BFloat16Tensor. '
                                        'Received {}'.format(param.type()))

            self.float16_groups.append(float16_params_this_group)
            self.fp32_from_float16_groups.append(
                fp32_from_float16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())

    def zero_grad(self, set_to_none=False):
        """Clear the gradients of the model parameters.

        We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_from_fp32_groups.
        """
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)

    def get_loss_scale(self):
        """Return the current loss scale (unity if there is no grad scaler)."""
        if self.grad_scaler is None:
            return self._scale_one
        return self.grad_scaler.scale

    def _copy_model_grads_to_main_grads(self):
        """Populate the ``grad`` of the parameters the wrapped optimizer steps on."""
        # This only needs to be done for the float16 group.
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if self.params_have_main_grad:
                    main_param.grad = model_param.main_grad.float()
                else:
                    if model_param.grad is not None:
                        main_param.grad = model_param.grad.float()
        # For fp32 grads, we need to reset the grads to main grad.
        if self.params_have_main_grad:
            for model_group in self.fp32_from_fp32_groups:
                for model_param in model_group:
                    model_param.grad = model_param.main_grad

    def _unscale_main_grads_and_check_for_nan(self):
        """Unscale all main gradients in place and report whether any inf/nan was seen.

        Returns:
            bool: True if an inf/nan was found on any rank of the tensor
            parallel group.
        """
        main_grads = []
        # fp32 params from float16 ones.
        for main_group in self.fp32_from_float16_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)
        # Append fp32 parameters.
        for main_group in self.fp32_from_fp32_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)
        # Reset found inf.
        self.found_inf.fill_(0.0)
        # Unscale and set found inf/nan
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)
        # Update across all model parallel instances.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=gpc.get_group(ParallelMode.TENSOR))

        # Check for nan.
        found_inf_flag = (self.found_inf.item() > 0)
        return found_inf_flag

    def _get_model_and_main_params_data_float16(self):
        """Return two parallel lists: float16 model param data and their fp32 main param data."""
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_main_params_to_model_params(self):
        """Write the fp32 main parameter values back into the float16 model parameters."""
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def _copy_model_params_to_main_params(self):
        """Refresh the fp32 main parameters from the float16 model parameters."""
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def reload_model_params(self):
        """Re-sync the main parameters from the current model parameters."""
        self._copy_model_params_to_main_params()

    @torch.no_grad()
    def step(self):
        """Run one optimizer step with unscaling, overflow check and optional clipping.

        Returns:
            tuple: ``(success, grad_norm, num_zeros_in_grad)``. ``success`` is
            False (and the other two are None) when an inf/nan was found and
            the update was skipped. ``grad_norm`` is None unless
            ``clip_grad > 0``; ``num_zeros_in_grad`` is None unless
            ``log_num_zeros_in_grad`` is set.
        """
        # Copy gradients from model params to main params.
        self._copy_model_grads_to_main_grads()

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler:

            # Unscale and check for inf/nan.
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()

            # We are done with scaling gradients
            # so we can update the loss scale.
            self.grad_scaler.update(found_inf_flag)

            # If we found inf/nan, skip the update.
            if found_inf_flag:
                return False, None, None

        # Clip the main gradients.
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)

        # count the zeros in the grads
        num_zeros_in_grad = self.count_zeros() if \
            self.log_num_zeros_in_grad else None

        # Step the optimizer.
        self.optimizer.step()

        # Update params from main params.
        self._copy_main_params_to_model_params()

        # Successful update.
        return True, grad_norm, num_zeros_in_grad

    def state_dict(self):
        """Return the wrapped optimizer state, the grad scaler state and the fp32 main params."""
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore optimizer, grad scaler and main params, accepting legacy checkpoint keys."""
        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** fould the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        fp32_from_float16_params_key = 'fp32_from_fp16_params'
        if fp32_from_float16_params_key not in state_dict:
            fp32_from_float16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(
                self.fp32_from_float16_groups,
                state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)

    def get_parameters(self):
        """Return a flat list of every parameter held by the wrapped optimizer."""
        params = []
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                params.append(param)
        return params

    def clip_grad_norm(self, clip_grad):
        """Clip the gradients of all optimizer parameters and return what ``clip_grad_norm_fp32`` returns."""
        params = self.get_parameters()
        return clip_grad_norm_fp32(params, clip_grad)

    def count_zeros(self):
        """Return the result of ``count_zeros_fp32`` over all optimizer parameters."""
        params = self.get_parameters()
        return count_zeros_fp32(params)

    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)