2022-03-01 10:17:01 +00:00
|
|
|
import math
|
2023-01-13 02:05:58 +00:00
|
|
|
from typing import Optional
|
2022-11-29 05:00:30 +00:00
|
|
|
|
2022-03-01 10:17:01 +00:00
|
|
|
import torch
|
2022-11-29 05:00:30 +00:00
|
|
|
import torch.distributed as dist
|
2023-07-04 09:41:28 +00:00
|
|
|
from torch import Tensor, inf
|
2022-03-01 10:17:01 +00:00
|
|
|
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
2023-07-04 09:41:28 +00:00
|
|
|
from torch.distributed import ProcessGroup
|
2022-11-29 05:00:30 +00:00
|
|
|
|
2023-01-29 07:09:57 +00:00
|
|
|
from colossalai.tensor import ColoParameter
|
2022-03-01 10:17:01 +00:00
|
|
|
from colossalai.utils import is_model_parallel_parameter
|
|
|
|
|
|
|
|
|
|
|
|
def flatten(input_):
    """Concatenate a sequence of dense tensors into a single flat tensor.

    Thin wrapper over ``torch._utils._flatten_dense_tensors``.
    """
    flat = _flatten_dense_tensors(input_)
    return flat
|
|
|
|
|
|
|
|
|
|
|
|
def unflatten(flat, tensors):
    """Split ``flat`` back into views shaped like each tensor in ``tensors``.

    Thin wrapper over ``torch._utils._unflatten_dense_tensors``; ``tensors``
    only supplies the target shapes and sizes.
    """
    pieces = _unflatten_dense_tensors(flat, tensors)
    return pieces
|
|
|
|
|
|
|
|
|
|
|
|
def count_numel(tensor_list):
    """Return the total number of elements across all tensors in ``tensor_list`` (0 if empty)."""
    return sum(t.numel() for t in tensor_list)
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_padding(numel, unit_size):
    """Return how many elements to append so ``numel`` becomes a multiple of ``unit_size``.

    Returns 0 when ``numel`` is already a multiple of ``unit_size``.
    """
    leftover = numel % unit_size
    if not leftover:
        return leftover
    return unit_size - leftover
|
|
|
|
|
|
|
|
|
|
|
|
def shuffle_by_round_robin(tensor_list, num_partitions):
    """Reorder tensors so those dealt to the same round-robin partition are contiguous.

    Tensor ``i`` is dealt to partition ``i % num_partitions``. The returned
    list holds partition 0's tensors (in their original relative order), then
    partition 1's, and so on.

    Returns:
        tuple: ``(new_tensor_list, tensor_index_mapping)`` where
        ``tensor_index_mapping[original_index]`` is that tensor's position in
        ``new_tensor_list``.
    """
    dealt = {}
    for original_idx, tensor in enumerate(tensor_list):
        dealt.setdefault(original_idx % num_partitions, []).append((original_idx, tensor))

    reordered = []
    position_of = {}
    # Partition ids are exactly 0..len(dealt)-1, so walk them in order.
    for bucket_id in range(len(dealt)):
        for original_idx, tensor in dealt[bucket_id]:
            position_of[original_idx] = len(reordered)
            reordered.append(tensor)
    return reordered, position_of
|
|
|
|
|
|
|
|
|
|
|
|
def flatten_dense_tensors_with_padding(tensor_list, unit_size):
    """Create a flat tensor aligned at the ``unit_size`` boundary.

    All tensors are flattened into one tensor. If their total element count is
    not a multiple of ``unit_size``, zeros (with the dtype and device of the
    first tensor) are appended so that it is.

    Args:
        tensor_list (Sequence[torch.Tensor]): Dense tensors to flatten. Any
            sequence (list or tuple) is accepted.
        unit_size (int): Alignment unit, in elements.

    Returns:
        torch.Tensor: The flattened, zero-padded tensor.
    """
    num_elements = count_numel(tensor_list)
    padding = calculate_padding(num_elements, unit_size=unit_size)

    if padding > 0:
        first = tensor_list[0]
        pad_tensor = torch.zeros(padding, device=first.device, dtype=first.dtype)
        # Build a fresh list: works for tuples as well as lists and never
        # mutates the caller's container.
        padded_tensor_list = [*tensor_list, pad_tensor]
    else:
        padded_tensor_list = tensor_list

    return flatten(padded_tensor_list)
|
|
|
|
|
|
|
|
|
|
|
|
def is_nccl_aligned(tensor):
    """Return True if ``tensor``'s data pointer is a multiple of 4 bytes."""
    address = tensor.data_ptr()
    return not address % 4
|
|
|
|
|
2022-04-11 15:13:02 +00:00
|
|
|
|
2022-03-01 10:17:01 +00:00
|
|
|
def get_grad_accumulate_object(tensor):
    """
    Return the AccumulateGrad of the input tensor.

    ``tensor.expand_as(tensor)`` creates a one-op autograd graph on top of
    ``tensor``. The result shares storage with ``tensor``, so
    ``tensor.data_ptr() == view.data_ptr()``. The first entry of that op's
    ``next_functions`` is the node feeding its input. When ``tensor`` is a
    leaf that requires grad, this is the AccumulateGrad object for ``tensor``
    itself, not for some other leaf in a larger graph.

    NOTE(review): presumably only called on leaf parameters with
    ``requires_grad=True``. Otherwise ``grad_fn`` may be None (AttributeError)
    or the returned node is not an AccumulateGrad.

    References:
        https://discuss.pytorch.org/t/in-the-grad-fn-i-find-a-next-functions-but-i-dont-understand-the-meaning-of-the-attribute/24463
        https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html#torch.Tensor.expand
    """
    dummy_view = tensor.expand_as(tensor)
    # next_functions holds (node, input_nr) pairs; take the node of the first.
    first_edge = dummy_view.grad_fn.next_functions[0]
    accumulate_node = first_edge[0]
    return accumulate_node
|
|
|
|
|
|
|
|
|
2023-04-27 10:43:14 +00:00
|
|
|
def split_by_dtype(tensor_list):
    """
    Splits a list of PyTorch tensors into sublists based on their tensor type.

    Tensors are grouped by ``Tensor.type()``, which encodes both the device type
    and the dtype (e.g. ``"torch.cuda.HalfTensor"``). The output is ordered as
    follows:

    - Groups for the CUDA half, float, double and bfloat16 types come first, in
      that order.
    - Groups for any other type (for example CPU tensors) follow, in order of
      first appearance. These tensors used to be silently dropped.

    Empty groups are omitted. Every input tensor appears in exactly one group,
    keeping its original relative order.

    :param tensor_list: A list of PyTorch tensors.
    :type tensor_list: list[torch.Tensor]
    :return: A list of sublists, where each sublist contains tensors of a single type.
    :rtype: list[list[torch.Tensor]]
    """
    preferred_order = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]

    # dict preserves insertion order, i.e. first appearance of each type.
    groups = {}
    for t in tensor_list:
        groups.setdefault(t.type(), []).append(t)

    buckets = [groups.pop(type_name) for type_name in preferred_order if type_name in groups]
    buckets.extend(groups.values())
    return buckets
|
|
|
|
|
|
|
|
|
2023-01-13 06:56:17 +00:00
|
|
|
def reduce_tensor_dp_group(tensor: torch.Tensor,
                           dtype: Optional[torch.dtype] = None,
                           dst_local_rank: Optional[int] = None,
                           dst_global_rank: Optional[int] = None,
                           group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    """
    Average ``tensor`` across ranks of a process group via reduce or all-reduce.

    The tensor is divided by the group's world size and then summed with a
    collective, so the result is the mean over ranks.

    :param tensor: The tensor to reduce/all-reduce; it is updated in place.
    :param dtype: The data type used for communication. Defaults to ``tensor.dtype``.
        If it differs, a cast copy is communicated and the result is copied back
        into ``tensor`` on the ranks that receive it.
    :param dst_local_rank: Rank within ``group`` that receives the result. If None,
        all-reduce is used instead of reduce. Default is None.
    :param dst_global_rank: Global rank of that same destination, passed as ``dst``
        to ``dist.reduce``. Required when ``dst_local_rank`` is given.
    :param group: The process group to communicate in; None means the default group.
    :type tensor: torch.Tensor
    :type dtype: torch.dtype, optional
    :type dst_local_rank: int, optional
    :type dst_global_rank: int, optional
    :type group: ProcessGroup, optional
    :return: ``tensor`` itself.
    :raises ValueError: If ``dst_local_rank`` is given without ``dst_global_rank``.

    NOTE: with reduce (not all-reduce) and an unchanged dtype, non-destination
    ranks are left holding their local value divided by the world size.
    """
    # if dst_local_rank is None, all-reduce is used; otherwise reduce
    use_all_reduce = dst_local_rank is None
    # Validate before mutating anything: dist.reduce cannot take dst=None, and the
    # tensor would otherwise already have been divided in place.
    if not use_all_reduce and dst_global_rank is None:
        raise ValueError("dst_global_rank must be provided when dst_local_rank is given")

    # use the original dtype
    if dtype is None:
        dtype = tensor.dtype

    # cast the data to the specified dtype for reduce/all-reduce
    tensor_to_reduce = tensor.to(dtype) if tensor.dtype != dtype else tensor

    world_size = dist.get_world_size(group=group)
    tensor_to_reduce.div_(world_size)

    if use_all_reduce:
        dist.all_reduce(tensor_to_reduce, group=group)
    else:
        dist.reduce(tensor=tensor_to_reduce, dst=dst_global_rank, group=group)

    # recover the original dtype on ranks that actually received the result
    if tensor.dtype != dtype and tensor is not tensor_to_reduce:
        local_rank = dist.get_rank(group=group)
        if use_all_reduce or dst_local_rank == local_rank:
            tensor.copy_(tensor_to_reduce)

    return tensor
|
|
|
|
|
2022-04-11 15:13:02 +00:00
|
|
|
|
2022-03-01 10:17:01 +00:00
|
|
|
def has_inf_or_nan(tensor):
    """Return True if summing ``tensor`` yields inf/-inf/NaN, else False.

    Any inf or NaN element makes the sum non-finite. The sum is taken in
    float32 (``tensor.float()``), so fp16 inputs do not overflow merely from
    accumulation.

    Raises:
        RuntimeError: re-raised unless it is the "value cannot be converted"
            overflow error, which is reported as True.
    """
    try:
        # If tensor is half, .float() incurs an extra copy, but it is necessary
        # because PyTorch's .sum() returns a tensor of the same dtype, which
        # could overflow in fp16.
        tensor_sum = float(tensor.float().sum())
    except RuntimeError as instance:
        # Only treat the overflow-conversion error as "has inf"; any other
        # RuntimeError must propagate. str() is safe even when the exception
        # has no args or a non-string first arg (args[0] was not).
        if "value cannot be converted" not in str(instance):
            raise
        return True
    return not math.isfinite(tensor_sum)
|
|
|
|
|
|
|
|
|
|
|
|
def release_param_grad(tensor_list):
    """Drop the gradient of every tensor in ``tensor_list`` by setting ``.grad`` to None."""
    for param in tensor_list:
        param.grad = None
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_global_norm_from_list(norm_list):
    """Combine a list of partial 2-norms into one total: ``sqrt(sum(n ** 2))``.

    Returns 0.0 for an empty list.
    """
    squared = (norm**2.0 for norm in norm_list)
    return math.sqrt(sum(squared, 0.0))
|
|
|
|
|
|
|
|
|
2023-07-04 09:41:28 +00:00
|
|
|
def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: float = 2) -> float:
    """Compute the total norm of gradients sharded across process groups.

    This is adapted from ``torch.nn.utils.clip_grad.clip_grad_norm_``, but it
    only computes the norm (nothing is clipped). Partial results are reduced
    across the ZeRO data-parallel group and, if given, the tensor-parallel group.

    Args:
        gradients (Iterable[Tensor]): The local gradient tensors to include.
        dp_group (ProcessGroup): The process group of ZeRO Data Parallelism.
        tp_group (ProcessGroup, optional): The process group of Tensor
            Parallelism, or None to skip the tensor-parallel reduction.
        norm_type (float, optional): Order of the p-norm; can be ``inf`` for
            the infinity norm. Defaults to 2.

    Returns:
        float: The total norm of the given gradients, or -1 if it is inf/NaN.

    NOTE(review): partial results are staged in ``torch.cuda.FloatTensor``,
    so this requires CUDA and a backend able to all-reduce CUDA tensors.
    """
    norm_type = float(norm_type)
    if norm_type == inf:
        # Local max |g|. Empty shards contribute 0.0 so this rank still joins
        # the collectives below instead of crashing (and hanging its peers).
        local_max = max((float(g.data.abs().max()) for g in gradients if g.numel() > 0), default=0.0)
        total_norm_cuda = torch.cuda.FloatTensor([local_max])
        # Take max across the ZeRO data-parallel ranks ...
        dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)
        # ... and across tensor-parallel ranks (previously this reduced over the
        # default/world group because `group=` was missing).
        if tp_group is not None:
            dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=tp_group)
        total_norm = total_norm_cuda[0].item()
    else:
        # Accumulate sum(|g|_p ** p) in float64; the p-th root is taken after
        # the cross-rank sums. Previously the exponent was hard-coded to 2,
        # which is wrong for any norm_type other than 2.
        total_norm = 0.0
        for g in gradients:
            param_norm = g.data.double().norm(norm_type)
            total_norm += param_norm.item()**norm_type

        # Sum across all data-parallel and tensor-parallel GPUs.
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
        if tp_group is not None:
            dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group)

        total_norm = total_norm_cuda[0].item()**(1. / norm_type)

    # Signal overflow to the caller with a sentinel instead of inf/NaN.
    if not math.isfinite(total_norm):
        total_norm = -1

    return total_norm
|
|
|
|
|
|
|
|
|
2023-06-30 07:30:50 +00:00
|
|
|
def sync_tensor(flat_tensor, tensor_list):
    """
    Synchronize the flattened tensor and unflattened tensor list.

    Flattening a list of tensors with `torch._utils._flatten_dense_tensors`
    creates a new tensor, so the flat tensor and the original tensor list do
    not share the same memory space. This function repoints the ``.data`` of
    every tensor in the list at the matching slice of ``flat_tensor``, so that
    both refer to the same values.

    :param flat_tensor: A flat tensor obtained by calling `torch._utils._flatten_dense_tensors` on a tensor list
    :param tensor_list: A list of tensors corresponding to the flattened tensor
    :type flat_tensor: torch.Tensor
    :type tensor_list: List[torch.Tensor]
    """
    # Views of flat_tensor, one per entry of tensor_list, with matching shapes.
    updated_params = unflatten(flat_tensor, tensor_list)

    # update the tensor data
    for p, q in zip(tensor_list, updated_params):
        p.data = q.data
|