#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod
from typing import Dict, Optional

import torch
import torch.distributed as dist
from torch import Tensor
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger

logger = get_logger(__file__)


def flatten(input_):
    """Flatten a list of dense tensors into a single contiguous 1-D tensor."""
    return _flatten_dense_tensors(input_)


def unflatten(flat, tensors):
    """Recover views into `flat` shaped like the tensors in `tensors`."""
    return _unflatten_dense_tensors(flat, tensors)


def get_grad_accumulate_object(tensor):
    """
    Return the AccumulateGrad object of the input tensor.
    """

    # grad_fn reference:
    # https://discuss.pytorch.org/t/in-the-grad-fn-i-find-a-next-functions-but-i-dont-understand-the-meaning-of-the-attribute/24463
    # expand_as reference: https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html#torch.Tensor.expand
    #
    # `next_functions` returns the backward graph, whose first element is the
    # AccumulateGrad of the leaf nodes. We want the AccumulateGrad of the input
    # tensor itself rather than of some other leaf node in the computation graph,
    # so we call expand_as to create a dummy graph node; tensor_tmp and tensor
    # still point to the same storage, which you can verify with
    # print(tensor.data_ptr() == tensor_tmp.data_ptr()).
    tensor_tmp = tensor.expand_as(tensor)
    grad_acc_obj = tensor_tmp.grad_fn.next_functions[0][0]
    return grad_acc_obj
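

# A minimal usage sketch (not part of the original module), showing why one would
# want the AccumulateGrad object: it lets you register a hook that fires each time
# autograd accumulates a gradient into the parameter. The function and hook names
# below are illustrative assumptions, not InternLM APIs.
def _example_grad_ready_hook(param: torch.nn.Parameter):
    accum_grad = get_grad_accumulate_object(param)

    def _on_grad_ready(*unused):
        # called after autograd has written/accumulated param.grad for this step
        logger.info(f"gradient ready, norm={param.grad.norm().item():.4f}")

    accum_grad.register_hook(_on_grad_ready)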


def split_half_float_double(tensor_list):
    """Group CUDA tensors into buckets by dtype (fp16, fp32, fp64 and bf16)."""
    dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
    buckets = []
    for dtype in dtypes:
        bucket = [t for t in tensor_list if t.type() == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets


def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA):
    """
    Reduce the tensor in the data parallel process group.

    :param tensor: A tensor object to reduce/all-reduce
    :param dtype: The data type used in communication
    :param dst_rank: The destination rank for reduce. If dst_rank is None,
        all-reduce will be used instead of reduce. Default is None.
    :param parallel_mode: Communication parallel mode

    :type tensor: torch.Tensor
    :type dtype: torch.dtype, optional
    :type dst_rank: int, optional
    :type parallel_mode: ParallelMode, optional
    """
    # use the original dtype
    if dtype is None:
        dtype = tensor.dtype

    # cast the data to the specified dtype for reduce/all-reduce
    if tensor.dtype != dtype:
        tensor_to_reduce = tensor.to(dtype)
    else:
        tensor_to_reduce = tensor

    world_size = gpc.get_world_size(parallel_mode)
    group = gpc.get_group(parallel_mode)
    tensor_to_reduce.div_(world_size)

    # if dst_rank is None, all-reduce is used;
    # otherwise, reduce to dst_rank is used
    use_all_reduce = dst_rank is None

    if use_all_reduce:
        dist.all_reduce(tensor_to_reduce, group=group)
    else:
        ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
        global_rank = ranks_in_group[dst_rank]
        dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)

    # recover the original dtype
    if tensor.dtype != dtype and tensor is not tensor_to_reduce:
        local_rank = gpc.get_local_rank(parallel_mode)
        if use_all_reduce or dst_rank == local_rank:
            tensor.copy_(tensor_to_reduce)

    return tensor
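

# A minimal sketch (not part of the original module) of how the helpers above are
# typically combined: bucket gradients by dtype, flatten each bucket, average it
# across the data parallel group with reduce_tensor, and copy the result back into
# the original gradient tensors. It assumes torch.distributed and the InternLM
# global context (gpc) are already initialized; `grads` is an illustrative name.
def _example_allreduce_gradient_buckets(grads):
    for bucket in split_half_float_double(grads):
        flat = flatten(bucket)
        # dst_rank=None -> all-reduce; reduce_tensor also divides by the world size
        reduce_tensor(flat, parallel_mode=ParallelMode.DATA)
        # scatter the averaged values back into the original gradient tensors
        for grad, synced in zip(bucket, unflatten(flat, bucket)):
            grad.copy_(synced)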


def has_inf_or_nan(tensor):
    """Return True if `tensor` contains any inf or NaN value."""
    try:
        # if tensor is half, the .float() incurs an additional deep copy, but it is necessary
        # because Pytorch's .sum() creates a one-element tensor of the same type as tensor
        # (which is true for some recent versions of pytorch).
        tensor_sum = float(tensor.float().sum())
        # More efficient version that can be used if .sum() returns a Python scalar
        # tensor_sum = float(tensor.sum())
    except RuntimeError as instance:
        # We want to check if instance is actually an overflow exception.
        # RuntimeError could come from a different error.
        # If so, we still want the exception to propagate.
        if "value cannot be converted" not in instance.args[0]:
            raise
        return True
    else:
        # tensor_sum != tensor_sum catches NaN, since NaN never compares equal to itself
        if tensor_sum == float("inf") or tensor_sum == -float("inf") or tensor_sum != tensor_sum:
            return True
        return False
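

# A minimal sketch (not part of the original module) showing how has_inf_or_nan is
# typically used: scan a model's gradients before the optimizer step and report
# whether any of them overflowed. `model` is an illustrative argument name.
def _example_found_overflow(model: torch.nn.Module) -> bool:
    for param in model.parameters():
        if param.grad is not None and has_inf_or_nan(param.grad):
            return True
    return False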


def release_param_grad(tensor_list):
    """Release gradient memory by setting .grad to None for every given tensor."""
    for tensor in tensor_list:
        tensor.grad = None


def sync_param(flat_tensor, tensor_list):
    """
    Synchronize the flattened tensor and the unflattened tensor list. When
    a list of tensors is flattened with `torch._utils._flatten_dense_tensors`,
    a new tensor is created. Thus, the flat tensor and the original tensor list
    do not share the same memory space. This function updates the tensor list
    so that its tensors point into the flat tensor's storage.

    :param flat_tensor: A flat tensor obtained by calling `torch._utils._flatten_dense_tensors` on a tensor list
    :param tensor_list: A list of tensors corresponding to the flattened tensor
    :type flat_tensor: torch.Tensor
    :type tensor_list: List[torch.Tensor]
    """
    updated_params = unflatten(flat_tensor, tensor_list)

    # update the tensor data
    for p, q in zip(tensor_list, updated_params):
        p.data = q.data
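

# A minimal sketch (not part of the original module) of the flatten/sync_param
# round trip: flatten a group of parameters into one contiguous buffer, then make
# the original parameters view into that buffer so an in-place update of the flat
# buffer is immediately visible through the parameter list. `params` is an
# illustrative argument name.
def _example_flatten_and_sync(params):
    flat = flatten([p.data for p in params])
    sync_param(flat_tensor=flat, tensor_list=params)
    # from here on, writing into `flat` (e.g. via a fused optimizer or a broadcast)
    # also updates every tensor in `params`
    return flat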


class BaseGradScaler(ABC):
    """A base class for the gradient scaler.

    Args:
        initial_scale (float): the initial loss scale
    """

    def __init__(self, initial_scale: float):
        assert initial_scale > 0
        self._scale = torch.cuda.FloatTensor([initial_scale])

    @property
    def scale(self) -> Tensor:
        """Returns the loss scale."""

        return self._scale

    @property
    def inv_scale(self) -> Tensor:
        """Returns the inverse of the loss scale."""

        return self._scale.double().reciprocal().float()

    def state_dict(self) -> Dict:
        """Returns the states of the gradient scaler as a dict object."""

        state_dict = dict()
        state_dict["scale"] = self.scale
        return state_dict

    def load_state_dict(self, state_dict: Dict) -> None:
        """Load the states of the gradient scaler from a dict object.

        Args:
            state_dict (dict): the states of the gradient scaler
        """

        self._scale = state_dict["scale"]

    @abstractmethod
    def update(self, overflow: bool) -> None:
        """Update the loss scale.

        Args:
            overflow (bool): whether overflow occurs
        """

        pass
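

# A minimal sketch (not part of the original module) of how BaseGradScaler is meant
# to be subclassed: the only abstract method is update(). This hypothetical
# constant scaler keeps a fixed loss scale and ignores overflow signals; the real
# dynamic behaviour lives in DynamicGradScaler below.
class _ExampleConstantGradScaler(BaseGradScaler):
    def update(self, overflow: bool) -> None:
        # a constant scaler never changes its scale
        pass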


class DynamicGradScaler(BaseGradScaler):
    """A gradient scaler which uses a dynamic loss scale.

    Args:
        initial_scale (float): the initial loss scale, defaults to 2**16
        growth_factor (float): the multiplication factor for increasing loss scale, defaults to 2
        backoff_factor (float): the multiplication factor for decreasing loss scale, defaults to 0.5
        growth_interval (int): the number of steps to increase loss scale when no overflow occurs, defaults to 1000
        min_scale (float): the minimum loss scale, defaults to None
        max_scale (float): the maximum loss scale, defaults to None
        hysteresis (int): the number of overflows before decreasing loss scale, defaults to 2
    """

    def __init__(
        self,
        initial_scale: float = 2**16,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        min_scale: Optional[float] = None,
        max_scale: Optional[float] = None,
        hysteresis: int = 2,
    ):
        super().__init__(initial_scale)
        if min_scale:
            self._min_scale = torch.cuda.FloatTensor([min_scale])
        else:
            self._min_scale = None

        if max_scale:
            self._max_scale = torch.cuda.FloatTensor([max_scale])
        else:
            self._max_scale = None

        self._growth_factor = growth_factor
        self._backoff_factor = backoff_factor
        self._growth_interval = growth_interval
        self._growth_step = 0
        self._hysteresis = hysteresis
        self._hysteresis_step = 0
        self._sanity_checks()

    def _sanity_checks(self) -> None:
        """Check if the arguments are correct."""

        if self._min_scale:
            assert self._min_scale > 0, "The minimum gradient scale cannot be zero or negative"
        if self._max_scale:
            assert self._max_scale > 0, "The maximum gradient scale cannot be zero or negative"
        assert self._growth_factor > 1, "The growth factor cannot be equal or smaller than 1"
        assert self._backoff_factor < 1 and self._backoff_factor > 0, "The backoff factor must be between 0 and 1"
        assert self._hysteresis >= 0, "The hysteresis cannot be negative"

    def update(self, overflow: bool) -> None:
        """Update the loss scale.

        Args:
            overflow (bool): whether overflow occurs
        """
        if overflow:
            self._hysteresis_step += 1
            self._growth_step = 0

            if self._hysteresis_step >= self._hysteresis:
                self._backoff_scale()
                if gpc.is_rank_for_log():
                    logger.warning(f"Overflow occurs, the loss scale is adjusted to {self.scale.item()}")
        else:
            self._growth_step += 1
            if self._growth_step == self._growth_interval:
                self._growth_step = 0
                self._hysteresis_step = 0
                self._grow_scale()
                if gpc.is_rank_for_log():
                    logger.warning(
                        f"No overflow for {self._growth_interval} consecutive steps, "
                        f"the loss scale is adjusted to {self.scale.item()}",
                    )

    def _backoff_scale(self) -> None:
        """Decrease the loss scale"""

        self._scale = self._scale * self._backoff_factor
        if self._min_scale:
            self._scale = torch.max(self._scale, self._min_scale)

    def _grow_scale(self) -> None:
        """Increase the loss scale"""

        self._scale = self._scale * self._growth_factor
        if self._max_scale:
            self._scale = torch.min(self._scale, self._max_scale)

    def state_dict(self):
        """Returns the states of the gradient scaler as a dict object."""

        state_dict = dict()
        state_dict["_scale"] = self._scale.item()
        state_dict["_growth_step"] = self._growth_step
        state_dict["_hysteresis_step"] = self._hysteresis_step

        return state_dict

    def load_state_dict(self, state_dict):
        """Load the states of the gradient scaler from a dict object.

        Args:
            state_dict (dict): the states of the gradient scaler
        """

        self._scale = self._scale.fill_(state_dict["_scale"])
        self._growth_step = state_dict["_growth_step"]
        self._hysteresis_step = state_dict["_hysteresis_step"]
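

# A minimal sketch (not part of the original module) of how DynamicGradScaler is
# typically driven in a mixed-precision training step: scale the loss before the
# backward pass, check the gradients for overflow, feed that signal into update(),
# and only step the optimizer when no overflow occurred. `model`, `optimizer` and
# `loss` are illustrative assumptions; InternLM's real optimizer wraps this logic.
def _example_scaled_backward_step(model, optimizer, loss, scaler: DynamicGradScaler):
    # scale the loss so small fp16 gradients do not underflow
    (loss * scaler.scale.item()).backward()

    overflow = any(
        param.grad is not None and has_inf_or_nan(param.grad) for param in model.parameters()
    )
    scaler.update(overflow)

    if overflow:
        # drop this step: clear the (invalid) gradients and keep the old weights
        release_param_grad(list(model.parameters()))
        return False

    # unscale the gradients before the optimizer applies them
    for param in model.parameters():
        if param.grad is not None:
            param.grad.mul_(scaler.inv_scale.item())
    optimizer.step()
    optimizer.zero_grad()
    return True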