2022-03-01 10:17:01 +00:00
import math
2023-01-13 02:05:58 +00:00
from typing import Optional
2022-11-29 05:00:30 +00:00
2022-03-01 10:17:01 +00:00
import torch
2022-11-29 05:00:30 +00:00
import torch.distributed as dist
2023-07-04 09:41:28 +00:00
from torch import Tensor, inf
2022-03-01 10:17:01 +00:00
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
2023-07-04 09:41:28 +00:00
from torch.distributed import ProcessGroup
2022-11-29 05:00:30 +00:00
2022-03-01 10:17:01 +00:00
def flatten(input_):
    """Flatten a sequence of dense tensors into a single flat tensor.

    Thin wrapper around ``torch._utils._flatten_dense_tensors``. The result is a
    new tensor; it does not share memory with the inputs.

    :param input_: A sequence of dense tensors.
    :return: One flat tensor holding the elements of ``input_`` in order.
    """
    return _flatten_dense_tensors(input_)


def unflatten(flat, tensors):
    """Split a flat tensor back into pieces shaped like ``tensors``.

    Thin wrapper around ``torch._utils._unflatten_dense_tensors``; it is the
    inverse of :func:`flatten`.

    :param flat: A flat tensor, e.g. produced by :func:`flatten`.
    :param tensors: The tensors whose sizes/shapes define how ``flat`` is split.
    :return: A sequence of tensors, one per entry in ``tensors``.
    """
    return _unflatten_dense_tensors(flat, tensors)


def count_numel(tensor_list):
    """Return the total number of elements across all tensors in ``tensor_list``.

    :param tensor_list: An iterable of tensors (anything exposing ``numel()``).
    :return: The sum of ``numel()`` over the iterable; ``0`` when it is empty.
    """
    return sum(tensor.numel() for tensor in tensor_list)


def calculate_padding(numel, unit_size):
    """Return how many elements must be appended to make ``numel`` a multiple of ``unit_size``.

    :param numel: The current number of elements.
    :param unit_size: The alignment unit.
    :return: A value in ``[0, unit_size)``; ``0`` when ``numel`` is already aligned.
    """
    remainder = numel % unit_size
    return unit_size - remainder if remainder else 0


def shuffle_by_round_robin(tensor_list, num_partitions):
    """Reorder ``tensor_list`` so that round-robin partitions become contiguous.

    Tensor ``i`` is assigned to partition ``i % num_partitions``. The returned list
    holds all of partition 0 (in original order), then all of partition 1, and so on.

    :param tensor_list: The tensors to reorder.
    :param num_partitions: The number of round-robin partitions.
    :return: ``(new_tensor_list, tensor_index_mapping)``, where
        ``tensor_index_mapping[original_index]`` is that tensor's position in
        ``new_tensor_list``.
    """
    partitions = [[] for _ in range(num_partitions)]
    for tensor_idx, tensor in enumerate(tensor_list):
        partitions[tensor_idx % num_partitions].append((tensor_idx, tensor))

    new_tensor_list = []
    tensor_index_mapping = dict()
    # Empty partitions (fewer tensors than partitions) simply contribute nothing.
    for partition_tensors in partitions:
        for original_idx, tensor in partition_tensors:
            tensor_index_mapping[original_idx] = len(new_tensor_list)
            new_tensor_list.append(tensor)

    return new_tensor_list, tensor_index_mapping


# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_with_padding(tensor_list, unit_size):
    """Flatten ``tensor_list`` into one tensor whose length is a multiple of ``unit_size``.

    When the total element count is not already aligned, a zero tensor of the
    missing length is appended before flattening. It uses the first tensor's
    device and dtype.

    :param tensor_list: A sequence of dense tensors.
    :param unit_size: The alignment unit for the flattened length.
    :return: The flattened (and possibly zero-padded) tensor.
    """
    num_elements = count_numel(tensor_list)
    padding = calculate_padding(num_elements, unit_size=unit_size)

    if padding > 0:
        first = tensor_list[0]
        pad_tensor = torch.zeros(padding, device=first.device, dtype=first.dtype)
        # list(...) so tuples and other sequences are accepted as well as lists
        padded_tensor_list = list(tensor_list) + [pad_tensor]
    else:
        padded_tensor_list = tensor_list

    return flatten(padded_tensor_list)


def is_nccl_aligned(tensor):
    """Return True if the tensor's data pointer is 4-byte aligned.

    :param tensor: The tensor to check.
    :return: ``tensor.data_ptr() % 4 == 0``.
    """
    return tensor.data_ptr() % 4 == 0


def get_grad_accumulate_object(tensor):
    """
    Return the AccumulateGrad node of the input tensor.

    :param tensor: A tensor that requires grad.
    :return: The autograd node that accumulates gradients into ``tensor``.
    """

    # grad_fn reference:
    # https://discuss.pytorch.org/t/in-the-grad-fn-i-find-a-next-functions-but-i-dont-understand-the-meaning-of-the-attribute/24463
    # expand_as reference: https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html#torch.Tensor.expand
    #
    # `next_functions` returns the backward graph, where the first element is the
    # AccumulateGrad of the leaf nodes. We want the AccumulateGrad of the input
    # tensor rather than of a leaf somewhere in the whole computation graph.
    # So we call expand_as to build a one-step dummy graph in which tensor_tmp
    # and tensor share the same storage.
    # You can check this with print(tensor.data_ptr() == tensor_tmp.data_ptr()).
    tensor_tmp = tensor.expand_as(tensor)
    grad_acc_obj = tensor_tmp.grad_fn.next_functions[0][0]
    return grad_acc_obj


def split_by_dtype(tensor_list):
    """
    Splits a list of PyTorch tensors into sublists based on their data type.

    Only CUDA half/float/double/bfloat16 tensors are grouped (matched on
    ``tensor.type()``). Tensors of any other type, including CPU tensors, are
    not included in the result. Empty groups are omitted. Groups appear in the
    order half, float, double, bfloat16.

    :param tensor_list: A list of PyTorch tensors.
    :type tensor_list: list[torch.Tensor]
    :return: A list of sublists, where each sublist contains tensors of a specific data type.
    :rtype: list[list[torch.Tensor]]
    """
    dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]

    buckets = []
    for dtype in dtypes:
        bucket = [t for t in tensor_list if t.type() == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets


def reduce_tensor_dp_group(
    tensor: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    dst_local_rank: Optional[int] = None,
    dst_global_rank: Optional[int] = None,
    group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:
    """
    Reduce the tensor in the data parallel process group.

    The communication buffer is divided by the group's world size and then
    summed across ``group``. If ``dst_local_rank`` is None, an all-reduce is used
    and every rank receives the result. Otherwise a reduce to ``dst_global_rank``
    is used, and only the destination rank is guaranteed to hold the result.

    When ``dtype`` differs from ``tensor.dtype``, communication happens in a cast
    copy. The result is copied back into ``tensor`` on every rank (all-reduce) or
    only on the destination rank (reduce). When no cast is needed, ``tensor``
    itself is the buffer and is modified in place.

    :param tensor: A tensor object to reduce/all-reduce
    :param dtype: The data type used in communication. Defaults to ``tensor.dtype``.
    :param dst_local_rank: The destination rank within ``group``. If None,
        all-reduce is used instead of reduce. Default is None.
    :param dst_global_rank: The destination rank in the global group, passed to
        ``dist.reduce``. Only used when ``dst_local_rank`` is not None.
    :param group: The process group to communicate in. Default is None.
    :type tensor: torch.Tensor
    :type dtype: torch.dtype, optional
    :type dst_local_rank: int, optional
    :type dst_global_rank: int, optional
    :type group: ProcessGroup, optional
    :return: ``tensor`` (the same object that was passed in)
    """
    # use the original dtype
    if dtype is None:
        dtype = tensor.dtype

    # cast the data to specified dtype for reduce/all-reduce
    if tensor.dtype != dtype:
        tensor_to_reduce = tensor.to(dtype)
    else:
        tensor_to_reduce = tensor

    world_size = dist.get_world_size(group=group)
    # NOTE(review): `world_size` was computed but never used; dividing before the
    # SUM reduction restores mean-reduction, which appears to be the intent.
    # Confirm against callers.
    tensor_to_reduce.div_(world_size)

    # if rank is None, all reduce will be used
    # else, reduce is used
    use_all_reduce = dst_local_rank is None

    if use_all_reduce:
        dist.all_reduce(tensor_to_reduce, group=group)
    else:
        dist.reduce(tensor=tensor_to_reduce, dst=dst_global_rank, group=group)

    # recover the original dtype: copy the reduced values back into the caller's
    # tensor on every rank that actually holds the result
    if tensor.dtype != dtype and tensor is not tensor_to_reduce:
        local_rank = dist.get_rank(group=group)
        if use_all_reduce or dst_local_rank == local_rank:
            tensor.copy_(tensor_to_reduce)

    return tensor


def has_inf_or_nan(tensor):
    """Return True if ``tensor`` contains inf/NaN or its float32 sum is not finite.

    The tensor is upcast to float32 and summed. The result is flagged if that sum
    is +inf, -inf or NaN, or if converting it to a Python float fails with a
    "value cannot be converted" RuntimeError.

    :param tensor: The tensor to inspect.
    :return: True if an overflow/NaN was detected, otherwise False.
    :raises RuntimeError: Re-raised for any RuntimeError other than the
        "value cannot be converted" failure.
    """
    try:
        # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
        # Pytorch's .sum() creates a one-element tensor of the same type as tensor
        # (which is true for some recent version of pytorch).
        tensor_sum = float(tensor.float().sum())
        # More efficient version that can be used if .sum() returns a Python scalar
        # tensor_sum = float(tensor.sum())
    except RuntimeError as instance:
        # We want to check if the error is actually an overflow exception.
        # A RuntimeError could come from a different error; if so, propagate it.
        # str(instance) is used because args[0] may be missing or not a string.
        if "value cannot be converted" not in str(instance):
            raise
        return True

    return math.isinf(tensor_sum) or math.isnan(tensor_sum)


def release_param_grad(tensor_list):
    """Drop the ``.grad`` of every tensor in ``tensor_list`` by setting it to None.

    :param tensor_list: An iterable of tensors whose gradients are released.
    """
    for tensor in tensor_list:
        tensor.grad = None


def calculate_global_norm_from_list(norm_list):
    """Compute the total L2 norm from a list of partial L2 norms.

    :param norm_list: An iterable of per-part L2 norms.
    :return: ``sqrt(sum(norm ** 2))``; ``0.0`` for an empty list.
    """
    total_norm = sum((norm**2.0 for norm in norm_list), 0.0)
    return math.sqrt(total_norm)


def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> float:
    """Compute the global p-norm of gradients spread over data/tensor parallel ranks.

    Adapted from ``torch.nn.utils.clip_grad.clip_grad_norm_`` with added handling
    of model-parallel parameters. It only computes the norm and does not clip.

    Args:
        gradients (Tensor): An iterable of local gradient tensors.
        dp_group (ProcessGroup): The process group of ZeRO Data Parallelism.
        tp_group (ProcessGroup): The process group of Tensor Parallelism; pass
            None to skip the tensor-parallel reduction.
        norm_type (int, optional): Type of the p-norm. Can be ``inf`` for the
            infinity norm. Defaults to 2.

    Returns:
        float: The total norm of the given gradients, or -1 if it is inf/NaN.
    """
    norm_type = float(norm_type)
    if norm_type == inf:
        # Local max |g|; an empty gradient list contributes 0.
        local_max = max((float(g.data.abs().max()) for g in gradients), default=0.0)
        total_norm_cuda = torch.tensor([local_max], dtype=torch.float32, device="cuda")
        # Take max across all GPUs of the DP group, then of the TP group.
        dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=dp_group)
        if tp_group is not None:
            dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=tp_group)
        total_norm = total_norm_cuda[0].item()
    else:
        local_sum = 0.0
        for g in gradients:
            # accumulate |g|_p ** p in double precision on the host
            param_norm = g.data.double().norm(norm_type)
            local_sum += param_norm.item() ** norm_type

        # Sum across all DP ranks, then across all model parallel GPUs.
        total_norm_cuda = torch.tensor([local_sum], dtype=torch.float32, device="cuda")
        dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=dp_group)
        if tp_group is not None:
            dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM, group=tp_group)

        total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type)

    # -1 signals an overflowed/invalid norm to callers
    if not math.isfinite(total_norm):
        total_norm = -1

    return total_norm


def sync_tensor(flat_tensor, tensor_list):
    """
    Synchronize the flattened tensor and unflattened tensor list. When
    a list of tensors is flattened with `torch._utils._flatten_dense_tensors`,
    a new tensor is created, so the flat tensor and the original tensor list do
    not share the same memory space. This function rebinds the ``.data`` of every
    tensor in ``tensor_list`` to the matching slice of ``flat_tensor`` so that
    they point to the same values.

    :param flat_tensor: A flat tensor obtained by calling `torch._utils._flatten_dense_tensors` on a tensor list
    :param tensor_list: A list of tensors corresponding to the flattened tensor
    :type flat_tensor: torch.Tensor
    :type tensor_list: List[torch.Tensor]
    """
    updated_params = unflatten(flat_tensor, tensor_list)

    # update the tensor data
    for p, q in zip(tensor_list, updated_params):
        p.data = q.data