Making large AI models cheaper, faster and more accessible
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

525 lines
23 KiB

"""This code is adapted from Alpa
https://github.com/alpa-projects/alpa/
with some changes. """
import operator
from dataclasses import dataclass
from functools import reduce
from typing import Dict, List, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@dataclass
class ProcessGroupContainer:
    """Bundle a process group handle with a list of ranks.

    NOTE(review): presumably ``ranks`` holds the global ranks that belong to
    ``process_group`` -- nothing visible in this file constructs the container,
    so confirm against its callers.
    """

    # handle of the torch.distributed process group
    process_group: ProcessGroup
    # ranks associated with the process group
    ranks: List[int]
# modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py)
class DeviceMesh:
    """A logical view of a physical cluster.

    For example, a physical cluster with 16 devices can be viewed as a device
    mesh with shape (2, 2, 4) or (4, 4).

    Arguments:
        physical_mesh_id (torch.Tensor): 1D tensor with the global ranks of the devices.
        mesh_shape (torch.Size, optional): shape of the logical view. Mutually
            exclusive with ``logical_mesh_id``.
        logical_mesh_id (torch.Tensor, optional): logical view of the devices in
            global rank. Mutually exclusive with ``mesh_shape``.
        mesh_alpha (List[float], optional): per-axis latency coefficients used for
            computing communication cost (default: None, i.e. 1 for every axis)
        mesh_beta (List[float], optional): per-axis bandwidth coefficients used for
            computing communication cost (default: None, i.e. 1 for every axis)
        init_process_group (bool, optional): create the logical process groups while
            constructing the DeviceMesh instance. Otherwise, users need to call
            ``init_logical_process_group`` manually. (default: False)
        device (str): the device for the process groups used by the DeviceMesh
            instance; must be a key of ``_DIST_BACKEND``. (default: 'cuda')
    """

    # device type -> torch.distributed backend used for its process groups
    _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo", "npu": "hccl"}

    # message raised by methods that need global rank information which a mesh
    # created by from_process_group does not have
    _FROM_PROCESS_GROUP_ERROR = "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."

    def __init__(
        self,
        physical_mesh_id: torch.Tensor,
        mesh_shape: torch.Size = None,
        logical_mesh_id: torch.Tensor = None,
        mesh_alpha: List[float] = None,
        mesh_beta: List[float] = None,
        init_process_group: bool = False,
        device: str = "cuda",
    ):
        # ============================
        # Physical & Logical Mesh IDs
        # ============================
        self._physical_mesh_id = physical_mesh_id
        assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor."

        # logical mesh ids can be obtained via two ways
        # 1. provide physical mesh id and provide mesh shape
        # 2. directly supply the logical mesh id
        assert mesh_shape is None or logical_mesh_id is None, (
            "Only one of mesh_shape and logical_mesh_id can be specified. "
            "Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id"
        )

        if logical_mesh_id is None:
            self._mesh_shape = mesh_shape
            self._logical_mesh_id = self._physical_mesh_id.reshape(self._mesh_shape)
        else:
            self._logical_mesh_id = logical_mesh_id
            self._mesh_shape = self._logical_mesh_id.shape

        # ensure two things:
        # 1. logical and physical mesh IDs should contain the same elements
        # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
        assert torch.equal(
            torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)
        ), "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
        assert (
            torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel()
        ), "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
        assert (
            torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel()
        ), "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."

        # ===============================================
        # coefficient for alpha-beta communication model
        # alpha is latency and beta is bandwidth
        # ===============================================
        # if the values are not provided, we assume they are 1 for simplicity
        if mesh_alpha is None:
            mesh_alpha = [1] * len(self._mesh_shape)
        if mesh_beta is None:
            mesh_beta = [1] * len(self._mesh_shape)

        self.mesh_alpha = tuple(mesh_alpha)
        self.mesh_beta = tuple(mesh_beta)

        # ensure the alpha and beta have the same shape
        assert len(self.mesh_alpha) == len(
            self.mesh_beta
        ), "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."

        # =========================
        # Device for Process Group
        # =========================
        self._device = device
        self._dist_backend = self._DIST_BACKEND[device]

        # =========================
        # Process Group Management
        # =========================
        # the _global_to_local_rank_mapping is structured as follows
        # {
        #    <global-rank>: [ <local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...]
        # }
        self._global_to_local_rank_mapping = dict()
        self._init_global_to_logical_rank_mapping(
            mapping=self._global_to_local_rank_mapping, tensor=self.logical_mesh_id
        )

        # process group bookkeeping
        # _process_group_dict:         {<global-rank>: {<axis>: <process group>}}
        # _ranks_in_the_process_group: {<global-rank>: {<axis>: [<global ranks in the group>]}}
        self._process_group_dict = {}
        self._ranks_in_the_process_group = {}
        self._global_rank_of_current_process = None
        self._is_initialized = False

        # attribute used to indicate whether this object
        # is created using DeviceMesh.from_process_group
        # this attribute can be used to do some check in methods
        # such get_process_group as no global rank information
        # is known if created with from_process_group
        self._is_init_from_process_group = False

        # the rank lists only depend on the mesh layout, so they are always built;
        # the actual process groups are only created on request
        self._init_ranks_in_the_same_group()
        self._init_process_group = init_process_group
        if init_process_group:
            self.init_logical_process_group()

    @property
    def shape(self) -> torch.Size:
        """
        Return the shape of the logical mesh.
        """
        return self._mesh_shape

    @property
    def num_devices(self) -> int:
        """
        Return the number of devices contained in the device mesh.
        """
        if self._physical_mesh_id is None:
            # meshes built by from_process_group drop their (fake) physical ids,
            # the mesh shape is the only size information left
            return reduce(operator.mul, self._mesh_shape, 1)
        return reduce(operator.mul, self._physical_mesh_id.shape, 1)

    @property
    def logical_mesh_id(self) -> torch.Tensor:
        """
        Return the logical mesh id.
        """
        return self._logical_mesh_id

    @property
    def is_initialized(self) -> bool:
        """
        Return whether the process group is initialized.
        """
        return self._is_initialized

    @staticmethod
    def from_process_group(process_group: Union[ProcessGroup, List[ProcessGroup]]) -> "DeviceMesh":
        """
        Create a DeviceMesh instance from the current process group. Please note that the DeviceMesh object created with this method
        will not have information about the physical mesh id, and thus will not be able to query for other ranks and perform alpha-beta communication.

        Args:
            process_group (Union[ProcessGroup, List[ProcessGroup]]): the process group or a list of process groups for the device mesh.
                If the input is a ProcessGroup object, a 1D DeviceMesh object will be created. If the input is a list of ProcessGroup objects,
                the ProcessGroup at the ith index will correspond to the process group in the ith axis of the device mesh.

        Returns:
            DeviceMesh: the device mesh instance.
        """

        def _get_device_by_backend(process_group):
            """
            Get the device type given a process group's backend.
            Returns None when the backend is not listed in DeviceMesh._DIST_BACKEND.
            """
            backend = dist.get_backend(process_group)
            for _device, _backend in DeviceMesh._DIST_BACKEND.items():
                if _backend == backend:
                    return _device
            return None

        if isinstance(process_group, ProcessGroup):
            process_group = [process_group]

        # get mesh shape
        mesh_shape = [dist.get_world_size(pg) for pg in process_group]

        # get device
        device_list = [_get_device_by_backend(pg) for pg in process_group]

        # make sure all devices are the same
        assert all(
            device == device_list[0] for device in device_list
        ), "All devices should be the same, please check your input process groups are created with the same distributed backend."

        # create a fake physical mesh id
        # as we only get the process group associated with the current process,
        # we cannot get the global ranks for all processes in the mesh
        # therefore, we only use this fake physical mesh id to create the device mesh
        # and will remove this fake physical mesh id later
        fake_physical_mesh_id = torch.arange(reduce(operator.mul, mesh_shape, 1))

        # create the device mesh
        device_mesh = DeviceMesh(physical_mesh_id=fake_physical_mesh_id, mesh_shape=mesh_shape, device=device_list[0])

        # hack the device attribute
        device_mesh._physical_mesh_id = None
        device_mesh._logical_mesh_id = None
        device_mesh._global_rank_of_current_process = dist.get_rank()
        device_mesh._is_initialized = False
        # mark the creation method so that the methods which need real global
        # rank information refuse to work on the fake mesh
        device_mesh._is_init_from_process_group = True
        device_mesh._process_group_dict = {
            device_mesh._global_rank_of_current_process: {axis: pg for axis, pg in enumerate(process_group)}
        }

        return device_mesh

    def _resolve_global_rank(self, global_rank: int = None) -> int:
        """
        Return the global rank a query refers to: the rank of the current process
        when ``global_rank`` is None, otherwise ``global_rank`` itself.

        Raises:
            RuntimeError: if an explicit rank is given but the mesh was created with
                DeviceMesh.from_process_group and thus knows nothing about other ranks.
        """
        if global_rank is None:
            return self._global_rank_of_current_process
        if self._is_init_from_process_group:
            raise RuntimeError(self._FROM_PROCESS_GROUP_ERROR)
        return global_rank

    def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup:
        """
        Return the process group on the specified axis.

        Args:
            axis (int): the axis of the process group.
            global_rank (int, optional): the global rank of the process group. If not specified, the current process is used. (default: None)
        """
        return self._process_group_dict[self._resolve_global_rank(global_rank)][axis]

    def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]:
        """
        Return the process groups for all axes.

        Args:
            global_rank (int, optional): the global rank of the process. If not specified, the current process is used. (default: None)
        """
        return self._process_group_dict[self._resolve_global_rank(global_rank)]

    def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]:
        """
        Return the ranks in the process group on the specified axis.

        Args:
            axis (int): the axis of the process group.
            global_rank (int, optional): the global rank of the process. If not specified, the current process is used. (default: None)
        """
        return self._ranks_in_the_process_group[self._resolve_global_rank(global_rank)][axis]

    def __deepcopy__(self, memo) -> "DeviceMesh":
        """
        Deep-copy every attribute except the process group dictionary, which is
        shared between the original and the copy.
        """
        import copy

        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k != "_process_group_dict":
                setattr(result, k, copy.deepcopy(v, memo))
            else:
                # process group cannot be copied
                # thus, we share them directly
                setattr(result, k, v)
        return result

    def _init_global_to_logical_rank_mapping(
        self, mapping: Dict, tensor: torch.Tensor, index_list: List[int] = None
    ) -> Dict[int, List[int]]:
        """
        Build a global rank to local rank mapping for each process group in different axis in the logical device mesh.

        Args:
            mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh. It is filled in place.
            tensor (torch.Tensor): the tensor that contains the logical mesh ids.
            index_list (List[int], optional): the local ranks on the axes which have already been traversed. (default: None, i.e. no axis traversed yet)

        Returns:
            mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
                The value is a list of integers and each integer represents the local rank in the indexed axis.
        """
        if index_list is None:
            index_list = []

        for index, inner_tensor in enumerate(tensor):
            # index means the local rank in the current axis
            # inner_tensor refers to the processes with the same local rank
            coordinates = index_list + [index]

            if inner_tensor.dim() == 0:
                # a 0-dim tensor means the last axis has been reached.
                # Testing the dimension instead of the number of elements keeps
                # one coordinate per axis even when trailing axes have size 1
                mapping[int(inner_tensor)] = coordinates
            else:
                # we recursively go into the function until we reach the last axis
                # meanwhile, we carry the local rank in the current axis along
                self._init_global_to_logical_rank_mapping(mapping, inner_tensor, coordinates)
        return mapping

    def init_logical_process_group(self):
        """
        This method is used to initialize the logical process groups which will be used in communications
        among logical device mesh.
        Note: if init_process_group set to False, you have to call this method manually. Otherwise,
        the communication related function, such as ShapeConsistencyManager.apply will raise errors.

        Raises:
            RuntimeError: if the mesh was created with DeviceMesh.from_process_group.
        """
        if self._is_init_from_process_group:
            raise RuntimeError(self._FROM_PROCESS_GROUP_ERROR)

        # sanity check
        assert (
            dist.is_initialized()
        ), "The torch.distributed should be initialized before calling init_logical_process_group"
        assert (
            not self._is_initialized
        ), "The logical process group has been initialized, do not call init_logical_process_group twice"

        # update the global rank of the current process
        self._global_rank_of_current_process = dist.get_rank()

        # process groups created so far, keyed by the ranks they contain, so that
        # each distinct set of ranks gets exactly one process group
        created_process_groups = {}

        # flatten the global ranks to 1D list
        global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()

        for global_rank in global_rank_flatten_list:
            # find the other ranks which are in the same process group as global_rank
            ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)

            for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
                group_key = tuple(ranks_in_same_group)

                # skip duplicated process group creation
                if group_key not in created_process_groups:
                    created_process_groups[group_key] = dist.new_group(
                        ranks=ranks_in_same_group, backend=self._dist_backend
                    )

                # keep this process group in the process_groups_dict
                self._process_group_dict.setdefault(global_rank, {})[axis] = created_process_groups[group_key]

        # update the init flag
        # we only allow init for once
        self._is_initialized = True

    def _init_ranks_in_the_same_group(self):
        """
        This method is used to initialize the ranks_in_the_same_group dictionary.
        """
        # flatten the global ranks to 1D list
        global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()

        for global_rank in global_rank_flatten_list:
            # find the other ranks which are in the same process group as global_rank
            # and keep the rank lists of every axis for this rank
            self._ranks_in_the_process_group[global_rank] = dict(
                self._collate_global_ranks_in_same_process_group(global_rank)
            )

    def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]:
        """
        Return the local rank of the given global rank in the logical device mesh.

        Args:
            rank (int): the global rank in the logical device mesh.
            axis (int, optional): the axis of the logical device mesh. If not specified,
                the local ranks on all axes are returned as a list. (default: None)
        """
        if self._is_init_from_process_group:
            raise RuntimeError(self._FROM_PROCESS_GROUP_ERROR)

        local_ranks = self._global_to_local_rank_mapping[rank]
        # compare against None explicitly, axis 0 is a valid axis
        if axis is not None:
            return local_ranks[axis]
        return local_ranks

    def _collate_global_ranks_in_same_process_group(self, global_rank):
        """
        Give a global rank and return all global ranks involved in its associated process group in each axis.

        Example:

        ```python
        physical_mesh_id = torch.arange(0, 16)
        mesh_shape = (4, 4)

        # logical mesh will look like
        # [[0, 1, 2, 3],
        #  [4, 5, 6, 7],
        #  [8, 9, 10,11],
        #  [12,13,14,15]]

        device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
        print(device_mesh._collate_global_ranks_in_same_process_group(0))

        # key is axis name
        # value is a list of global ranks in same axis with rank 0
        # output will look like
        # {
        #     0: [0, 4, 8, 12],
        #     1: [0, 1, 2, 3]
        # }
        ```
        """
        # self._global_to_local_rank_mapping gives the coordinates of the global rank
        # in the logical mesh, i.e. its local rank on every axis.
        # The processes which share a process group with it on a given axis are the ones
        # whose coordinates only differ on that axis, which is exactly the slice of the
        # logical mesh taken along that axis with all other coordinates fixed.
        process_coordinates = self._global_to_local_rank_mapping[global_rank]

        # the key of the dict is the axis
        # the value is the list of global ranks which are in the same process group as the given global rank
        global_pg_ranks = {}
        for dim in range(self.logical_mesh_id.dim()):
            index = tuple(
                slice(None) if axis == dim else local_rank for axis, local_rank in enumerate(process_coordinates)
            )
            global_pg_ranks[dim] = [int(rank) for rank in self.logical_mesh_id[index].tolist()]
        return global_pg_ranks

    def flatten(self):
        """
        Flatten the logical mesh into an effective 1d logical mesh,
        whose single axis uses the largest alpha and beta of the original axes.

        Raises:
            RuntimeError: if the mesh was created with DeviceMesh.from_process_group.
        """
        if self._is_init_from_process_group:
            raise RuntimeError(self._FROM_PROCESS_GROUP_ERROR)

        # the flattened mesh has a single axis, so it takes exactly one
        # coefficient of each kind; it stays on the device of this mesh
        return DeviceMesh(
            self._physical_mesh_id,
            (self.num_devices,),
            mesh_alpha=[max(self.mesh_alpha)],
            mesh_beta=[max(self.mesh_beta)],
            init_process_group=self._init_process_group,
            device=self._device,
        )

    def all_gather_cost(self, num_bytes, mesh_dim):
        """
        Return the alpha-beta cost estimate of an all-gather of ``num_bytes`` along axis ``mesh_dim``.
        """
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        # NOTE(review): the trailing constant differs per collective; presumably a
        # tie-breaker inherited from Alpa -- confirm before changing it
        return self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.1

    def all_reduce_cost(self, num_bytes, mesh_dim):
        """
        Return the alpha-beta cost estimate of an all-reduce of ``num_bytes`` along axis ``mesh_dim``.
        """
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        return (
            self.mesh_alpha[mesh_dim]
            + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes
            + 0.01
        )

    def reduce_scatter_cost(self, num_bytes, mesh_dim):
        """
        Return the alpha-beta cost estimate of a reduce-scatter of ``num_bytes`` along axis ``mesh_dim``.
        """
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        return (
            self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.001
        )

    def all_to_all_cost(self, num_bytes, mesh_dim):
        """
        Return the alpha-beta cost estimate of an all-to-all of ``num_bytes`` along axis ``mesh_dim``.
        """
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        # the bandwidth term is scaled by half the number of devices on the axis
        penalty_factor = num_devices / 2.0
        return (
            self.mesh_alpha[mesh_dim]
            + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor
            + 0.001
        )