#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import math
from abc import ABC, abstractmethod
from typing import Dict, Optional

import torch
import torch.distributed as dist
from torch import Tensor
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_model_parallel_parameter

logger = get_logger(__file__)

try:
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier

    APEX_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
    # `warn` is a deprecated alias; `warning` is what the rest of this module uses.
    logger.warning("The torch implementation for cal_l2norm is slower than apex. Please note this!")
    APEX_AVAILABLE = False

inf = math.inf


def flatten(input_):
    """Flatten a list of dense tensors into one contiguous 1-D tensor (thin wrapper over torch's helper)."""
    return _flatten_dense_tensors(input_)


def unflatten(flat, tensors):
    """Split ``flat`` back into tensors shaped like ``tensors`` (inverse of :func:`flatten`)."""
    return _unflatten_dense_tensors(flat, tensors)


def get_grad_accumulate_object(tensor):
    """
    Return the AccumulateGrad of the input tensor.

    NOTE(review): ``tensor`` presumably must require grad; otherwise ``expand_as`` yields a tensor whose
    ``grad_fn`` is None and the attribute access below raises AttributeError.
    """

    # grad_fn reference:
    # https://discuss.pytorch.org/t/in-the-grad-fn-i-find-a-next-functions-but-i-dont-understand-the-meaning-of-the-attribute/24463
    # expand_as reference: https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html#torch.Tensor.expand
    #
    # `next_functions` will return the backward graph where
    # the first element is the AccumulateGrad of the leaf nodes.
    # we want to get the AccumulateGrad of the input tensor instead of the leaf
    # node in the whole computation graph.
    # Therefore, we call expand_as to create a dummy graph
    # where tensor_tmp and tensor indeed point to the same object.
    # You can check this by print(tensor.data_ptr() == tensor_tmp.data_ptr())
    tensor_tmp = tensor.expand_as(tensor)
    grad_acc_obj = tensor_tmp.grad_fn.next_functions[0][0]
    return grad_acc_obj


def split_half_float_double(tensor_list):
    """Group tensors into buckets by CUDA tensor type.

    Buckets are produced in the fixed order half, float, double, bfloat16, and empty buckets are omitted.
    Tensors whose ``tensor.type()`` is none of these CUDA types (e.g. CPU tensors) are not included in
    any bucket.
    """
    dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
    buckets = []
    for dtype in dtypes:
        bucket = [t for t in tensor_list if t.type() == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets


def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA):
    """
    Reduce the tensor in the process group of ``parallel_mode``, averaging it over the group size.

    :param tensor: A tensor object to reduce/all-reduce
    :param dtype: The data type used in communication. Defaults to ``tensor.dtype``.
    :param dst_rank: The destination rank (local rank within the ``parallel_mode`` group) for reduce.
        If dst_rank is None, all-reduce will be used instead of reduce. Default is None.
    :param parallel_mode: Communication parallel mode. Default is ``ParallelMode.DATA``.

    :type tensor: torch.Tensor
    :type dtype: torch.dtype, optional
    :type dst_rank: int, optional
    :type parallel_mode: ParallelMode, optional

    :return: ``tensor``. When ``dtype`` matches ``tensor.dtype`` the communication happens in place on
        ``tensor`` (so it is divided by the world size on every rank); otherwise the result is cast back
        into ``tensor`` only on ranks that receive it (all ranks for all-reduce, ``dst_rank`` for reduce).
    """
    # use the original dtype
    if dtype is None:
        dtype = tensor.dtype

    # cast the data to specified dtype for reduce/all-reduce
    if tensor.dtype != dtype:
        tensor_to_reduce = tensor.to(dtype)
    else:
        tensor_to_reduce = tensor

    world_size = gpc.get_world_size(parallel_mode)
    group = gpc.get_group(parallel_mode)
    # Pre-divide so the SUM reduction yields the mean.
    tensor_to_reduce.div_(world_size)

    # if rank is None, all reduce will be used
    # else, reduce is used
    use_all_reduce = dst_rank is None

    if use_all_reduce:
        dist.all_reduce(tensor_to_reduce, group=group)
    else:
        # dist.reduce expects a global rank, so translate the group-local dst_rank.
        ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
        global_rank = ranks_in_group[dst_rank]
        dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)

    # recover the original dtype
    if tensor.dtype != dtype and tensor is not tensor_to_reduce:
        local_rank = gpc.get_local_rank(parallel_mode)
        if use_all_reduce or dst_rank == local_rank:
            tensor.copy_(tensor_to_reduce)

    return tensor


def has_inf_or_nan(tensor):
    """Return True if ``tensor`` contains an inf or NaN value.

    The tensor is summed in float32; any inf/NaN element makes the scalar sum non-finite. A RuntimeError
    whose message contains "value cannot be converted" is also treated as an overflow; any other
    RuntimeError propagates.
    """
    try:
        # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
        # Pytorch's .sum() creates a one-element tensor of the same type as tensor
        # (which is true for some recent version of pytorch).
        tensor_sum = float(tensor.float().sum())
        # More efficient version that can be used if .sum() returns a Python scalar
        # tensor_sum = float(tensor.sum())
    except RuntimeError as instance:
        # We want to check if inst is actually an overflow exception.
        # RuntimeError could come from a different error.
        # If so, we still want the exception to propagate.
        if "value cannot be converted" not in instance.args[0]:
            raise
        return True
    else:
        # NaN must be checked explicitly: it compares unequal to +/-inf, and a mix of +inf and -inf
        # elements also sums to NaN.
        if math.isinf(tensor_sum) or math.isnan(tensor_sum):
            return True
        return False


def release_param_grad(tensor_list):
    """Drop the ``.grad`` of every tensor in ``tensor_list`` by setting it to None."""
    for tensor in tensor_list:
        tensor.grad = None


def sync_param(flat_tensor, tensor_list):
    """
    Synchronize the flattened tensor and unflattened tensor list. When
    a list of tensor are flattened with `torch._utils._flatten_dense_tensors`,
    a new tensor is created. Thus, the flat tensor and original tensor list do not
    share the same memory space. This function will update the tensor list so that
    they point to the same value.

    :param flat_tensor: A flat tensor obtained by calling `torch._utils._flatten_dense_tensors` on a tensor list
    :param tensor_list: A list of tensors corresponding to the flattened tensor
    :type flat_tensor: torch.Tensor
    :type tensor_list: List[torch.Tensor]
    """
    updated_params = unflatten(flat_tensor, tensor_list)

    # update the tensor data
    for p, q in zip(tensor_list, updated_params):
        p.data = q.data


def multi_tensor_l2norm_torch(tensor_list, per_tensor):
    """Pure-torch fallback for ``amp_C.multi_tensor_l2norm``, used when apex is unavailable.

    :param tensor_list: Non-empty list of tensors (``torch.stack`` fails on an empty list).
    :param per_tensor: Whether to also return the individual per-tensor norms.
    :return: ``(l2_norm, per_tensor_norm)`` where ``l2_norm`` is a float32 tensor of shape ``(1,)`` holding
        the L2 norm over all tensors, and ``per_tensor_norm`` holds the per-tensor norms if ``per_tensor``
        is True, otherwise an empty tensor on the same device.
    """
    # Convert tensor_list elements to torch.float32
    tensor_list = [tensor.float() for tensor in tensor_list]
    norms_tensor = torch.stack([torch.norm(tensor, p=2) for tensor in tensor_list])
    l2_norm = torch.norm(norms_tensor, p=2).unsqueeze(0)

    if per_tensor:
        per_tensor_norm = norms_tensor
    else:
        per_tensor_norm = torch.Tensor([]).to(norms_tensor.device)

    return l2_norm, per_tensor_norm


def calc_l2_norm(grads):
    """Return the L2 norm over all tensors in ``grads`` (a 1-element tensor), or ``0.0`` if ``grads`` is empty.

    Uses apex's fused ``multi_tensor_l2norm`` when apex is importable, otherwise the torch fallback.
    """
    norm = 0.0
    if len(grads) > 0:
        if APEX_AVAILABLE:
            dummy_overflow_buf = torch.cuda.IntTensor([0])
            norm, _ = multi_tensor_applier(
                amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False  # no per-parameter norm
            )
        else:
            norm, _ = multi_tensor_l2norm_torch(grads, False)
    return norm


def calc_lp(grads, norm_type):
    """Return ``sum(||g||_p ** p for g in grads)`` with ``p = norm_type``; ``0.0`` if ``grads`` is empty."""
    norm = 0.0
    for grad in grads:
        grad_norm = torch.norm(grad, norm_type)
        norm += grad_norm**norm_type
    return norm


def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2):
    """Get the norm

    Arguments:
        gradients (Sequence[Tensor]): The gradient value. Must be non-empty (``gradients[0]`` selects the device).
        parameters (Iterable[Tensor]): The parameter each gradient corresponds to.
        last_stage (bool): If False, return the local norm tensor immediately, before any cross-rank
            reduction. If True, fold in ``previous_norm`` and reduce across ranks. Defaults to False.
        previous_norm (Tensor, optional): A norm to combine with this one before the reduction (max for the
            infinity norm, sum otherwise) — presumably one returned by an earlier ``last_stage=False`` call.
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.

    Returns:
        If ``last_stage`` is False: the local norm as a tensor. Otherwise a Python number: for the infinity
        norm, the max absolute gradient value; for other p, the reduced sum of ``|g| ** p``, so callers need
        ``total_norm ** (1 / norm_type)`` before using it. A non-finite result (inf or NaN) is returned as -1.
    """
    enable_cuda_kernels = gradients[0].device.type == "cuda"
    # Norm parameters.
    norm_type = float(norm_type)

    # Calculate norm.
    if norm_type == inf:
        total_norm = max(g.data.abs().max() for g in gradients)
        # `torch.tensor` works for any device; the legacy `torch.FloatTensor(..., device=...)` constructor
        # raises a RuntimeError when given a CUDA device.
        total_norm_cuda = torch.tensor([float(total_norm)], dtype=torch.float, device=gradients[0].device)

        if last_stage is False:
            return total_norm_cuda

        if previous_norm is not None:
            total_norm_cuda = max(total_norm_cuda, previous_norm)

        # Take max across all model-parallel GPUs.
        if gpc.get_world_size(ParallelMode.MODEL) > 1:
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
        total_norm = total_norm_cuda[0].item()
    else:
        # Collect the gradients this rank is responsible for, so that each gradient is counted exactly
        # once when the per-rank partial norms are summed below.
        tensor_parallel_grads = []
        for g, p in zip(gradients, parameters):
            # TODO: consider the pipeline shared parameter
            if (
                gpc.is_initialized(ParallelMode.PIPELINE)
                and hasattr(p, "pipeline_shared_module_pg")
                and dist.get_rank(p.pipeline_shared_module_pg) == 0
            ):
                # if shared between different pipe, only count once
                tensor_parallel_grads.append(g.data.float())
            elif (
                gpc.is_initialized(ParallelMode.PIPELINE)
                and hasattr(p, "pipeline_shared_module_pg")
                and dist.get_rank(p.pipeline_shared_module_pg) != 0
            ):
                continue
            elif (
                gpc.is_initialized(ParallelMode.TENSOR)
                and not is_model_parallel_parameter(p)
                and gpc.get_local_rank(ParallelMode.TENSOR) == 0
            ):
                # if not used in each chunk, such as layernorm
                tensor_parallel_grads.append(g.data.float())
            elif is_model_parallel_parameter(p):
                tensor_parallel_grads.append(g.data.float())
            elif gpc.get_local_rank(ParallelMode.TENSOR) != 0:
                continue
            else:
                raise RuntimeError("Should not arrive here")

        if norm_type == 2.0 and enable_cuda_kernels:
            tensor_parallel_norm = calc_l2_norm(tensor_parallel_grads) ** norm_type
        else:
            tensor_parallel_norm = calc_lp(tensor_parallel_grads, norm_type)

        # If norm is type of float, then we convert them into torch.Tensor.
        tensor_parallel_norm = get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
        # If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors
        if not enable_cuda_kernels:
            tensor_parallel_norm = move_norm_to_cuda(tensor_parallel_norm)

        total_norm = tensor_parallel_norm

        if last_stage is False:
            return total_norm

        if previous_norm is not None:
            total_norm = total_norm + previous_norm

        # Sum across all model-parallel GPUs.
        if gpc.is_initialized(ParallelMode.MODEL):
            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))

        # This is because we use zero1, so we need to use this reduction.
        # TODO: Check zero group to be a subset of dp group.
        dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))

        if torch.is_tensor(total_norm):
            total_norm = total_norm.item()

    # Map any non-finite result (inf or NaN) to the sentinel -1 (presumably interpreted as overflow by
    # callers — confirm). `total_norm != total_norm` is the NaN test.
    if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm:
        total_norm = -1

    return total_norm


class BaseGradScaler(ABC):
    """A base class for the gradient scaler.

    Args:
        initial_scale (float): the initial loss scale
    """

    def __init__(self, initial_scale: float):
        assert initial_scale > 0
        # Stored as a 1-element CUDA float tensor (requires a CUDA device).
        self._scale = torch.cuda.FloatTensor([initial_scale])

    @property
    def scale(self) -> Tensor:
        """Returns the loss scale."""

        return self._scale

    @property
    def inv_scale(self) -> Tensor:
        """Returns the inverse of the loss scale."""

        # The reciprocal is computed in double precision, then cast back to float.
        return self._scale.double().reciprocal().float()

    def state_dict(self) -> Dict:
        """Returns the states of the gradient scaler as a dict object."""

        state_dict = dict()
        state_dict["scale"] = self.scale
        return state_dict

    def load_state_dict(self, state_dict: Dict) -> None:
        """Load the states of the gradient scaler from a dict object.

        Args:
            state_dict (dict): the states of the gradient scaler
        """

        self._scale = state_dict["scale"]

    @abstractmethod
    def update(self, overflow: bool) -> None:
        """Update the loss scale.

        Args:
            overflow (bool): whether overflow occurs
        """

        pass


class DynamicGradScaler(BaseGradScaler):
    """A gradient scaler which uses dynamic loss scale

    Args:
        initial_scale (float): the initial loss scale, defaults to 2**16
        growth_factor (float): the multiplication factor for increasing loss scale, defaults to 2
        backoff_factor (float): the multiplication factor for decreasing loss scale, defaults to 0.5
        growth_interval (int): the number of steps to increase loss scale when no overflow occurs, defaults to 1000
        min_scale (float): the minimum loss scale, defaults to None
        max_scale (float): the maximum loss scale, defaults to None
        hysteresis (int): the number of overflows before decreasing loss scale, defaults to 2
    """

    def __init__(
        self,
        initial_scale: float = 2**16,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        min_scale: Optional[float] = None,
        max_scale: Optional[float] = None,
        hysteresis: int = 2,
    ):
        super().__init__(initial_scale)
        # A falsy bound (None or 0) means "no bound".
        if min_scale:
            self._min_scale = torch.cuda.FloatTensor([min_scale])
        else:
            self._min_scale = None

        if max_scale:
            self._max_scale = torch.cuda.FloatTensor([max_scale])
        else:
            self._max_scale = None

        self._growth_factor = growth_factor
        self._backoff_factor = backoff_factor
        self._growth_interval = growth_interval
        # Consecutive non-overflow steps since the last overflow or growth.
        self._growth_step = 0
        self._hysteresis = hysteresis
        # Overflows counted since the last growth.
        self._hysteresis_step = 0
        self._sanity_checks()

    def _sanity_checks(self) -> None:
        """Check if the arguments are correct."""

        if self._min_scale:
            assert self._min_scale > 0, "The minimum gradient scale cannot be zero or negative"
        if self._max_scale:
            # Validate the max bound itself (previously this asserted on `_min_scale`, which raised a
            # TypeError when only `max_scale` was given and never checked `max_scale`).
            assert self._max_scale > 0, "The maximum gradient scale cannot be zero or negative"
        assert self._growth_factor > 1, "The growth factor cannot be equal or smaller than 1"
        assert self._backoff_factor < 1 and self._backoff_factor > 0, "The backoff factor must be between 0 and 1"
        assert self._hysteresis >= 0, "The hysteresis cannot be negative"

    def update(self, overflow: bool) -> None:
        """Update the loss scale.

        On overflow, the growth counter resets and the scale backs off once ``hysteresis`` overflows have
        accumulated since the last growth. Otherwise, after ``growth_interval`` consecutive clean steps the
        scale grows and both counters reset.

        Args:
            overflow (bool): whether overflow occurs
        """
        if overflow:
            self._hysteresis_step += 1
            self._growth_step = 0

            if self._hysteresis_step >= self._hysteresis:
                self._backoff_scale()
                if gpc.is_rank_for_log():
                    logger.warning(f"Overflow occurs, the loss scale is adjusted to {self.scale.item()}")
        else:
            self._growth_step += 1
            if self._growth_step == self._growth_interval:
                self._growth_step = 0
                self._hysteresis_step = 0
                self._grow_scale()
                if gpc.is_rank_for_log():
                    logger.warning(
                        f"No overflow for consecutive {self._growth_interval} steps, "
                        f"the loss scale is adjusted to {self.scale.item()}",
                    )

    def _backoff_scale(self) -> None:
        """Decrease the loss scale, clamped from below by ``min_scale`` if set."""

        self._scale = self._scale * self._backoff_factor
        if self._min_scale:
            self._scale = torch.max(self._scale, self._min_scale)

    def _grow_scale(self) -> None:
        """Increase the loss scale, clamped from above by ``max_scale`` if set."""

        self._scale = self._scale * self._growth_factor
        if self._max_scale:
            self._scale = torch.min(self._scale, self._max_scale)

    def state_dict(self):
        """Returns the states of the gradient scaler as a dict object."""

        state_dict = dict()
        # The scale is stored as a Python float, not a tensor.
        state_dict["_scale"] = self._scale.item()
        state_dict["_growth_step"] = self._growth_step
        state_dict["_hysteresis_step"] = self._hysteresis_step

        return state_dict

    def load_state_dict(self, state_dict):
        """Load the states of the gradient scaler from a dict object.

        Args:
            state_dict (dict): the states of the gradient scaler
        """

        # Write into the existing tensor in place, keeping its device and dtype.
        self._scale = self._scale.fill_(state_dict["_scale"])
        self._growth_step = state_dict["_growth_step"]
        self._hysteresis_step = state_dict["_hysteresis_step"]