ColossalAI/colossalai/device/device_mesh.py

412 lines
18 KiB
Python
Raw Normal View History

"""This code is adapted from Alpa
https://github.com/alpa-projects/alpa/
with some changes. """
import operator
2023-06-08 02:18:17 +00:00
from dataclasses import dataclass
from functools import reduce
2023-06-08 02:18:17 +00:00
from typing import Dict, List, Union
import torch
import torch.distributed as dist
2023-06-08 02:18:17 +00:00
from torch.distributed import ProcessGroup
@dataclass
class ProcessGroupContainer:
    """Bundle a process group handle together with the ranks that belong to it."""

    # handle of the communication group
    process_group: ProcessGroup
    # ranks that are members of ``process_group``
    # (presumably global ranks -- TODO confirm against the callers)
    ranks: List[int]
# modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py)
class DeviceMesh:
    """A logical view of a physical cluster.

    For example, a physical cluster with 16 devices can be viewed as a device mesh
    with shape (2, 2, 4) or (4, 4).

    Arguments:
        physical_mesh_id (torch.Tensor): physical view of the devices in global rank,
            must be a 1D tensor without duplicates.
        mesh_shape (torch.Size, optional): shape of the logical view. Mutually exclusive
            with ``logical_mesh_id``.
        logical_mesh_id (torch.Tensor, optional): logical view of the devices in global rank.
            Mutually exclusive with ``mesh_shape``.
        mesh_alpha (List[float], optional): coefficients used for computing
            communication cost (default: None, i.e. 1 for every axis)
        mesh_beta (List[float], optional): coefficients used for computing
            communication cost (default: None, i.e. 1 for every axis)
        init_process_group (bool, optional): initialize logical process group
            during initializing the DeviceMesh instance if the init_process_group set to True.
            Otherwise, users need to call init_logical_process_group manually to init
            logical process group. (default: False)
        device (str): the device for the process groups used by the DeviceMesh instance.
            Must be a key of ``_DIST_BACKEND``. (default: 'cuda')
    """

    # maps the device type to the torch.distributed backend used for its process groups
    _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}

    def __init__(self,
                 physical_mesh_id: torch.Tensor,
                 mesh_shape: torch.Size = None,
                 logical_mesh_id: torch.Tensor = None,
                 mesh_alpha: List[float] = None,
                 mesh_beta: List[float] = None,
                 init_process_group: bool = False,
                 device: str = 'cuda'):
        # ============================
        # Physical & Logical Mesh IDs
        # ============================
        self._physical_mesh_id = physical_mesh_id
        assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor."

        # logical mesh ids can be obtained via two ways
        # 1. provide physical mesh id and provide mesh shape
        # 2. directly supply the logical mesh id
        assert mesh_shape is None or logical_mesh_id is None, \
            "Only one of mesh_shape and logical_mesh_id can be specified." \
            "Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id"

        if logical_mesh_id is None:
            # without this check a missing mesh_shape only surfaces as an opaque reshape error
            assert mesh_shape is not None, \
                "Either mesh_shape or logical_mesh_id must be specified."
            self.mesh_shape = mesh_shape
            self._logical_mesh_id = self._physical_mesh_id.reshape(self.mesh_shape)
        else:
            self._logical_mesh_id = logical_mesh_id
            self.mesh_shape = self._logical_mesh_id.shape

        # ensure two things:
        # 1. logical and physical mesh IDs should contain the same elements
        # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
        assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
            "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
        assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
            "Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again."
        assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
            "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."

        # ===============================================
        # coefficient for alpha-beta communication model
        # alpha is latency and beta is bandwidth
        # ===============================================
        # if the values are not provided, we assume they are 1 for simplicity
        if mesh_alpha is None:
            mesh_alpha = [1] * len(self.mesh_shape)
        if mesh_beta is None:
            mesh_beta = [1] * len(self.mesh_shape)

        self.mesh_alpha = tuple(mesh_alpha)
        self.mesh_beta = tuple(mesh_beta)

        # ensure the alpha and beta have the same shape
        assert len(self.mesh_alpha) == len(self.mesh_beta), \
            "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."

        # =========================
        # Device for Process Group
        # =========================
        self._device = device
        self._dist_backend = self._DIST_BACKEND[device]

        # =========================
        # Process Group Management
        # =========================
        # the _global_to_local_rank_mapping is structured as follows
        # {
        #    <global-rank>: [ <local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...]
        # }
        self._global_to_local_rank_mapping = dict()
        self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping,
                                                  tensor=self.logical_mesh_id)

        # process group bookkeeping, filled by init_logical_process_group
        self._process_group_dict = {}
        self._ranks_in_the_process_group = {}
        self._global_rank_of_current_process = None
        self._is_initialized = False

        # the rank lists do not need torch.distributed, so they are always computed
        self._init_ranks_in_the_same_group()

        # initialize process group if specified
        self._init_process_group = init_process_group
        if init_process_group:
            self.init_logical_process_group()

    @property
    def shape(self) -> torch.Size:
        """
        Return the shape of the logical mesh.
        """
        return self.mesh_shape

    @property
    def num_devices(self) -> int:
        """
        Return the number of devices contained in the device mesh.
        """
        return self._physical_mesh_id.numel()

    @property
    def logical_mesh_id(self) -> torch.Tensor:
        """
        Return the logical mesh id.
        """
        return self._logical_mesh_id

    def _resolve_global_rank(self, global_rank: int = None) -> int:
        """Return ``global_rank``, falling back to the rank of the current process when it is None."""
        if global_rank is None:
            return self._global_rank_of_current_process
        return global_rank

    def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup:
        """
        Return the process group on the specified axis.

        Args:
            axis (int): the axis of the process group.
            global_rank (int, optional): the global rank of the process group. If not specified,
                the current process is used. (default: None)
        """
        return self._process_group_dict[self._resolve_global_rank(global_rank)][axis]

    def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]:
        """
        Return the process groups for all axes.

        Args:
            global_rank (int, optional): the global rank of the process. If not specified,
                the current process is used. (default: None)
        """
        return self._process_group_dict[self._resolve_global_rank(global_rank)]

    def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]:
        """
        Return the ranks in the process group on the specified axis.

        Args:
            axis (int): the axis of the process group.
            global_rank (int, optional): the global rank of the process. If not specified,
                the current process is used. (default: None)
        """
        return self._ranks_in_the_process_group[self._resolve_global_rank(global_rank)][axis]

    def __deepcopy__(self, memo) -> "DeviceMesh":
        """Deep-copy every attribute except the process group handles, which are shared."""
        import copy

        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k != '_process_group_dict':
                setattr(result, k, copy.deepcopy(v, memo))
            else:
                # process group cannot be copied
                # thus, we share them directly
                setattr(result, k, v)
        return result

    def _init_global_to_logical_rank_mapping(self,
                                             mapping: Dict,
                                             tensor: torch.Tensor,
                                             index_list: List[int] = None) -> Dict[int, List[int]]:
        """
        Build a global rank to local rank mapping for each process group in different axis in the logical device mesh.

        Args:
            mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
                It is filled in place.
            tensor (torch.Tensor): the tensor that contains the logical mesh ids.
            index_list (List[int], optional): the local ranks on the axes that have already been
                traversed by the recursion. (default: None, i.e. no axis traversed yet)

        Returns:
            mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
                The value is a list of integers and each integer represents the local rank in the indexed axis.
        """
        if index_list is None:
            index_list = []

        for index, inner_tensor in enumerate(tensor):
            # index means the local rank in the current axis
            # inner_tensor refers to the processes with the same local rank
            coordinates = index_list + [index]
            if inner_tensor.dim() == 0:
                # a 0-dim tensor means that we already reached the last axis.
                # dim() is checked rather than numel() so that axes of size 1
                # still contribute a coordinate, e.g. for a (4, 1) mesh.
                mapping[int(inner_tensor)] = coordinates
            else:
                # we recursively go into the function until we reach the last axis
                # meanwhile, we carry the local rank in the current axis along
                self._init_global_to_logical_rank_mapping(mapping, inner_tensor, coordinates)
        return mapping

    def init_logical_process_group(self):
        '''
        This method is used to initialize the logical process groups which will be used in communications
        among logical device mesh.
        Note: if init_process_group set to False, you have to call this method manually. Otherwise,
        the communication related function, such as ShapeConsistencyManager.apply will raise errors.
        '''
        # sanity check
        # NOTE: is_initialized must be called; the bare function object is always truthy
        assert dist.is_initialized(), "The torch.distributed should be initialized before calling init_logical_process_group"
        assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice"

        # update the global rank of the current process
        self._global_rank_of_current_process = dist.get_rank()

        # (axis, ranks) pairs for which a group already exists.
        # The axis is part of the key because two axes may contain the same ranks,
        # e.g. in a (1, 1) mesh, and each axis still needs its own entry.
        created_groups = set()

        # flatten the global ranks to 1D list
        global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()

        for global_rank in global_rank_flatten_list:
            # find the other ranks which are in the same process group as global_rank
            ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)

            for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
                # skip duplicated process group creation
                group_key = (axis, tuple(ranks_in_same_group))
                if group_key in created_groups:
                    continue
                created_groups.add(group_key)

                # create the process group
                pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend)

                # keep this process group in the _process_group_dict for every member rank
                for rank in ranks_in_same_group:
                    self._process_group_dict.setdefault(rank, dict())[axis] = pg_handler

        # update the init flag
        # we only allow init for once
        self._is_initialized = True

    def _init_ranks_in_the_same_group(self):
        """
        This method is used to initialize the _ranks_in_the_process_group dictionary, which maps
        every global rank to ``{axis: [global ranks in the same group on that axis]}``.
        """
        # flatten the global ranks to 1D list
        global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()

        for global_rank in global_rank_flatten_list:
            # the collated dict is freshly built per call and covers all axes,
            # so it can be stored as the entry of this rank directly
            self._ranks_in_the_process_group[global_rank] = \
                self._collate_global_ranks_in_same_process_group(global_rank)

    def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]:
        """
        Return the local rank of the given global rank in the logical device mesh.

        Args:
            rank (int): the global rank in the logical device mesh.
            axis (int, optional): the axis of the logical device mesh. If not specified,
                the local ranks on all axes are returned as a list. (default: None)
        """
        local_ranks = self._global_to_local_rank_mapping[rank]
        # compare against None explicitly, axis 0 is a valid axis
        if axis is not None:
            return local_ranks[axis]
        return local_ranks

    def _collate_global_ranks_in_same_process_group(self, global_rank):
        '''
        Give a global rank and return all global ranks involved in its associated process group in each axis.

        Example:

        ```python
        physical_mesh_id = torch.arange(0, 16)
        mesh_shape = (4, 4)

        # logical mesh will look like
        # [[0, 1, 2, 3],
        #  [4, 5, 6, 7],
        #  [8, 9, 10,11],
        #  [12,13,14,15]]

        device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
        print(device_mesh._collate_global_ranks_in_same_process_group(0))

        # key is axis name
        # value is a list of global ranks in same axis with rank 0
        # output will look like
        # {
        #     0: [0, 4, 8, 12],
        #     1: [0, 1, 2, 3]
        # }
        ```
        '''
        # self._global_to_local_rank_mapping was built from the logical mesh by
        # _init_global_to_logical_rank_mapping:
        #   the key is the global rank
        #   the value is the list of local ranks, i.e. the coordinates of the process in the logical mesh
        process_coordinates = self._global_to_local_rank_mapping[global_rank]

        # the key of the dict is the axis
        # the value is the list of global ranks which are in the same process group as the given global rank
        global_pg_ranks = {}
        for dim in range(self.logical_mesh_id.dim()):
            # keep the coordinates on all other axes fixed and take the whole extent of the current axis;
            # the result is ordered by the local rank on this axis
            index = list(process_coordinates)
            index[dim] = slice(None)
            ranks_on_axis = self.logical_mesh_id[tuple(index)].tolist()
            global_pg_ranks[dim] = [int(rank) for rank in ranks_on_axis]
        return global_pg_ranks

    def flatten(self):
        """
        Flatten the logical mesh into an effective 1d logical mesh.

        The flattened mesh has a single axis, so it gets exactly one alpha and one beta
        coefficient: the largest one among the axes of this mesh. The device and the
        init_process_group setting of this mesh are carried over.
        """
        flatten_mesh_shape = (self.num_devices,)
        return DeviceMesh(self._physical_mesh_id,
                          flatten_mesh_shape,
                          mesh_alpha=[max(self.mesh_alpha)],
                          mesh_beta=[max(self.mesh_beta)],
                          init_process_group=self._init_process_group,
                          device=self._device)

    def all_gather_cost(self, num_bytes, mesh_dim):
        """
        Estimate the all-gather cost of ``num_bytes`` along ``mesh_dim`` as
        ``alpha + beta * (n - 1) / n * num_bytes + 0.1`` where n is the size of that axis.
        """
        # NOTE(review): the constant offsets differ per collective (0.1 / 0.01 / 0.001);
        # presumably they act as tie-breakers between collectives -- confirm.
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
                0.1)

    def all_reduce_cost(self, num_bytes, mesh_dim):
        """
        Estimate the all-reduce cost of ``num_bytes`` along ``mesh_dim`` as
        ``alpha + beta * 2 * (n - 1) / n * num_bytes + 0.01`` where n is the size of that axis.
        """
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes +
                0.01)

    def reduce_scatter_cost(self, num_bytes, mesh_dim):
        """
        Estimate the reduce-scatter cost of ``num_bytes`` along ``mesh_dim`` as
        ``alpha + beta * (n - 1) / n * num_bytes + 0.001`` where n is the size of that axis.
        """
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
                0.001)

    def all_to_all_cost(self, num_bytes, mesh_dim):
        """
        Estimate the all-to-all cost of ``num_bytes`` along ``mesh_dim`` as
        ``alpha + beta * (n - 1) / n / n * num_bytes * penalty_factor + 0.001`` where n is the
        size of that axis and ``penalty_factor`` is ``n / 2``.
        """
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        penalty_factor = num_devices / 2.0
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
                (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)