#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.optim import Optimizer

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.kernel import fused_optim
from colossalai.logging import get_dist_logger
from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier

from ._utils import has_inf_or_nan, zero_gard_by_list
from .grad_scaler import BaseGradScaler

__all__ = ['FP16Optimizer']


def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
    """
    adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM)

    Copy values from the tensors in ``this`` into the tensors in ``that`` (paired by position).

    When ``overflow_buf`` is provided and there is something to copy, the fused
    multi-tensor-applier kernel is used. We don't have a bfloat16 implementation of that
    kernel, so if ``overflow_buf`` is not provided we fall back to a simple loop copy,
    which is compatible with bfloat16.

    Args:
        this (list[torch.Tensor]): source tensors.
        that (list[torch.Tensor]): destination tensors, same length/order as ``this``.
        overflow_buf (torch.Tensor, optional): flag buffer passed to the fused kernel;
            it is reset to 0 before the copy.
    """
    # NOTE: compare against None explicitly. The buffer handed in by FP16Optimizer is a
    # one-element tensor holding 0, so ``if overflow_buf:`` is False and would make the
    # fused path unreachable. Empty lists use the (no-op) loop path, since there is
    # nothing to copy and the fused launcher may reject an empty tensor list.
    if overflow_buf is not None and len(this) > 0:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
    else:
        for this_, that_ in zip(this, that):
            that_.copy_(this_)


class FP16Optimizer(Optimizer):
    """Float16 optimizer wrapper that keeps fp32 master copies of fp16 parameters.

    The wrapped optimizer's fp16 parameters are replaced by fp32 master copies; gradients
    are moved onto those copies, unscaled and checked for overflow before each update, and
    the updated fp32 values are copied back into the fp16 model parameters.

    NOTE: only parameters of type ``torch.cuda.HalfTensor`` (fp16) or
    ``torch.cuda.FloatTensor`` (fp32) are currently accepted; any other type (including
    bf16) raises ``TypeError`` at construction.

    Args:
        optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
        grad_scaler (BaseGradScaler): grad scaler for gradient, chosen from
            ``constant_grad_scaler`` or ``dynamic_grad_scaler``.
        verbose (bool, optional): if set to `True`, will print debug info. Default False.
        clip_grad_norm (float, optional): clip gradients with this global L2 norm. Default 0.
            Note that clipping is skipped if ``clip_grad_norm`` is 0.
        dp_process_group (ProcessGroup, optional): group over which the overflow flag is
            all-reduced. If None, the DATA parallel group of the global context is used
            when it has been initialized.
        mp_process_group (ProcessGroup, optional): as ``dp_process_group``, but for the
            MODEL parallel group.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 grad_scaler: BaseGradScaler,
                 verbose: bool = False,
                 clip_grad_norm=0,
                 dp_process_group: ProcessGroup = None,
                 mp_process_group: ProcessGroup = None):
        # have a defaults for compatibility with pytorch optim
        self._optimizer = optimizer
        self._defaults = optimizer.defaults

        # fp16-related params
        assert isinstance(grad_scaler, BaseGradScaler)
        self._grad_scaler = grad_scaler
        # 1-element flag set to 1.0 by _check_overflow when any gradient has inf/nan.
        self._found_overflow = torch.cuda.FloatTensor([0.0])
        # flag buffer handed to the fused copy kernel in _update_fp16_param_from_fp32_param.
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # misc params
        self._clip_grad_max_norm = clip_grad_norm

        # get process group
        def _get_process_group(parallel_mode):
            # NOTE(review): any positive world size is truthy here, so a size-1 group is
            # still returned and used for (trivial) all-reduces.
            if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode):
                return gpc.get_group(parallel_mode)
            else:
                return None

        if dp_process_group is None:
            dp_process_group = _get_process_group(ParallelMode.DATA)
        if mp_process_group is None:
            mp_process_group = _get_process_group(ParallelMode.MODEL)

        self._dp_process_group = dp_process_group
        self._mp_process_group = mp_process_group

        # we maintain three groups of parameters
        # so that the model can have a mixture
        # of fp16 and fp32 params
        # fp16_param_groups: the fp16 params of the model
        # fp32_master_param_groups: the fp32 params cast from the fp16 param of the model
        # fp32_param_groups: the fp32 params of the model
        # NOTE:
        # 1. fp16_param_groups and fp32_master_param_groups have one-to-one correspondence
        # 2. fp32_param_groups and fp16_param_groups are exclusive of each other
        self._fp16_param_groups = []
        self._fp32_master_param_groups = []
        self._fp32_param_groups = []

        # For all the groups in the original optimizer:
        for param_group in self._optimizer.param_groups:
            fp16_params = []
            fp32_master_params = []
            fp32_params = []
            # For all the parameters in this group:
            # (params with requires_grad=False stay in the optimizer group but are not
            # tracked in any of the three lists above)
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:
                    # float16 params:
                    if param.type() in ['torch.cuda.HalfTensor']:
                        fp16_params.append(param)

                        # Create a fp32 copy
                        fp32_param = param.detach().clone().float()
                        # Copy tensor model parallel attributes.
                        copy_tensor_parallel_attributes(param, fp32_param)

                        # Replace the optimizer params with the new fp32 copy.
                        param_group['params'][i] = fp32_param
                        fp32_master_params.append(fp32_param)

                        # Reset existing state dict key to the new main param.
                        if param in self._optimizer.state:
                            self._optimizer.state[fp32_param] = self._optimizer.state.pop(param)

                    # fp32 params.
                    elif param.type() == 'torch.cuda.FloatTensor':
                        fp32_params.append(param)
                    else:
                        raise TypeError('Expected parameter of type torch.cuda.FloatTensor '
                                        f'or torch.cuda.HalfTensor, but got {param.type()}')

            self._fp16_param_groups.append(fp16_params)
            self._fp32_master_param_groups.append(fp32_master_params)
            self._fp32_param_groups.append(fp32_params)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        self._optimizer.load_state_dict(self._optimizer.state_dict())

        # log config
        self._logger = get_dist_logger()
        if verbose:
            self._logger.info(
                f"\n========= FP16 Optimizer Config =========\n"
                f"Optimizer: {optimizer.__class__.__name__}\n"
                f"clip_grad_norm = {clip_grad_norm}\n"
                f"grad_scaler = {self._grad_scaler.__class__.__name__}\n"
                f"==========================================",
                ranks=[0])

    @property
    def max_norm(self):
        """Returns the maximum norm of gradient clipping.
        """
        return self._clip_grad_max_norm

    @property
    def grad_scaler(self):
        """Returns the gradient scaler.

        Returns:
            :class:`BaseGradScaler`: gradient scaler.
        """

        return self._grad_scaler

    @property
    def loss_scale(self):
        """Returns the loss scale.

        Returns:
            The current ``scale`` of the grad scaler (int or tensor, depending on the scaler).
        """
        return self._grad_scaler.scale

    @property
    def optimizer(self):
        """Returns the optimizer.

        Returns:
            :class:`torch.optim.Optimizer`: the optimizer object wrapped.
        """
        return self._optimizer

    @property
    def defaults(self):
        """Returns the default arguments of optimizer.

        Returns:
            dict: optimizer arguments saved in defaults of the optimizer wrapped.
        """
        return self._defaults

    def _check_overflow(self):
        """Return True if any gradient held by the wrapped optimizer contains inf/nan on
        any rank of the configured data/model parallel groups."""
        # clear previous overflow record
        self._found_overflow.fill_(0.0)

        # check for overflow; any() stops at the first offending gradient across all groups
        local_overflow = any(
            p.grad is not None and has_inf_or_nan(p.grad)
            for group in self._optimizer.param_groups
            for p in group['params'])
        if local_overflow:
            self._found_overflow.fill_(1.0)

        # all-reduce across dp group
        if self._dp_process_group:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_process_group)

        # all-reduce over model parallel group
        if self._mp_process_group:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_process_group)

        return self._found_overflow.item() > 0

    def zero_grad(self, set_to_none=True):
        """Set gradient to zero.

        Clears the gradients of the parameters held by the wrapped optimizer (fp32 master
        copies and native fp32 params) as well as those of the fp16 model params, whose
        gradients are otherwise only released when ``step`` moves them onto the master copies.

        Args:
            set_to_none (bool): Whether set the gradient to None.
        """

        # set_to_none = True can save some memory space
        for param_group in self._optimizer.param_groups:
            zero_gard_by_list(param_group['params'], set_to_none=set_to_none)
        # The fp16 model params were replaced by fp32 master copies in the wrapped
        # optimizer's param groups, so they have to be cleared separately.
        for fp16_param_group in self._fp16_param_groups:
            zero_gard_by_list(fp16_param_group, set_to_none=set_to_none)

    def _get_fp32_param_groups_to_update(self):
        """Return the fp32 master groups followed by the native fp32 groups."""
        return self._fp32_master_param_groups + self._fp32_param_groups

    def _unscale_grads(self):
        """Divide every fp32 gradient in place by the current loss scale."""
        for group in self._get_fp32_param_groups_to_update():
            for p in group:
                if p.grad is not None:
                    p.grad.data.div_(self.loss_scale)

    def _assign_grad_to_fp32_master_param(self):
        """Move fp16 gradients onto their fp32 master params (as fp32) and drop the fp16 grads."""
        # This only needs to be done for the float16 group.
        for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups,
                                                             self._fp32_master_param_groups):
            for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group):
                if fp16_param.grad is not None:
                    fp32_param.grad = fp16_param.grad.float()
                    # clear unneeded grad on fp16 param
                    fp16_param.grad = None

    def _update_fp16_param_from_fp32_param(self):
        """Copy the values of the fp32 master params back into the fp16 model params."""
        fp16_param_data = []
        fp32_master_param_data = []
        for fp16_group, fp32_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
            for fp16_param, fp32_param in zip(fp16_group, fp32_group):
                fp16_param_data.append(fp16_param.data)
                fp32_master_param_data.append(fp32_param.data)
        _multi_tensor_copy_this_to_that(this=fp32_master_param_data,
                                        that=fp16_param_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def step(self):
        """Update the model parameters.

        Moves the fp16 gradients onto the fp32 master params, unscales them, checks for
        inf/nan, and updates the grad scaler. On overflow the gradients are cleared and the
        parameter update is skipped.

        Returns:
            tuple: ``(True, grad_norm)`` after a successful update, where ``grad_norm`` is
            the result of ``clip_grad_norm`` (None when clipping is disabled), or
            ``(False, None)`` when an overflow was found.
        """

        # Copy gradients from model params to main params.
        self._assign_grad_to_fp32_master_param()
        self._unscale_grads()

        overflow = self._check_overflow()
        self._grad_scaler.update(overflow)
        if overflow:
            self.zero_grad()

        # Clip the main gradients.
        # NOTE(review): this runs even when an overflow was found; presumably so that all
        # ranks make the same calls into clip_grad_norm_fp32 — confirm before moving it
        # under `if not overflow`.
        grad_norm = None
        if self._clip_grad_max_norm > 0.0:
            grad_norm = self.clip_grad_norm(self._clip_grad_max_norm)

        if not overflow:
            # Step the optimizer.
            self._optimizer.step()

            # Update params from main params.
            self._update_fp16_param_from_fp32_param()

            # Successful update.
            return True, grad_norm
        else:
            return False, None

    def backward(self, loss):
        """Execute backward pass.

        Args:
            loss (:class:`torch.Tensor`): the loss value.
        """

        scaled_loss = loss * self.grad_scaler.scale
        scaled_loss.backward()

    def state_dict(self):
        """Returns the states of the fp16 optimizer as a dict object.
        """

        state_dict = {}
        state_dict['optimizer'] = self._optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """Load the states of the fp16 optimizer from a dict object.

        Args:
            state_dict (dict): the states of the fp16 optimizer
        """

        # Optimizer.
        self._optimizer.load_state_dict(state_dict['optimizer'])

        # Grad scaler.
        if 'grad_scaler' in state_dict:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])

        # Copy data for the main params.
        # NOTE: only the fp32 master copies are restored here; the fp16 model params are
        # not re-synchronised from them by this method.
        if 'fp32_master_param_groups' in state_dict:
            for current_group, ckpt_group in zip(self._fp32_master_param_groups,
                                                 state_dict['fp32_master_param_groups']):
                for current_param, ckpt_param in zip(current_group, ckpt_group):
                    current_param.data.copy_(ckpt_param.data)

    def clip_grad_norm(self, clip_grad):
        """Clip gradients by norm.

        Args:
            clip_grad (float): the max norm for clipping

        Returns:
            Whatever ``clip_grad_norm_fp32`` returns for the wrapped optimizer's params.
        """
        params = []
        for param_group in self._optimizer.param_groups:
            for param in param_group['params']:
                params.append(param)
        return clip_grad_norm_fp32(params, clip_grad)

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self._optimizer.state

    def _set_state(self, value):
        self._optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self._optimizer.param_groups

    def _set_param_groups(self, value):
        self._optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)