# Mirror of https://github.com/hpcaitech/ColossalAI (506 lines, 23 KiB, Python)
"""This code is adapted from Alpa
    https://github.com/alpa-projects/alpa/
with some changes. """

import copy
import operator
from dataclasses import dataclass
from functools import reduce
from typing import Dict, List, Optional, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


@dataclass()
class ProcessGroupContainer:
    """Bundle a ``ProcessGroup`` handle together with a list of ranks.

    NOTE(review): presumably ``ranks`` are the global ranks that make up
    ``process_group``; this class is not used in the visible code, so confirm
    against its callers.
    """

    # the torch.distributed process group handle
    process_group: ProcessGroup
    # ranks associated with ``process_group``
    ranks: List[int]


# modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py)
class DeviceMesh:
    """A logical view of a physical cluster. For example, we could view a physical cluster
    with 16 devices as a device mesh with shape (2, 2, 4) or (4, 4).

    Arguments:
        physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
            Must be a 1D tensor without duplicate IDs.
        logical_mesh_id (torch.Tensor, optional): logical view of the devices in global rank.
            Must contain exactly the same IDs as ``physical_mesh_id``.
        mesh_shape (torch.Size, optional): shape of logical view. Exactly one of ``mesh_shape``
            and ``logical_mesh_id`` must be given; with ``mesh_shape`` the logical view is
            ``physical_mesh_id.reshape(mesh_shape)``.
        mesh_alpha (List[float], optional): coefficients used for computing
            communication cost, indexed by mesh axis (default: None, meaning 1 for every axis)
        mesh_beta (List[float], optional): coefficients used for computing
            communication cost, indexed by mesh axis (default: None, meaning 1 for every axis)
        init_process_group (bool, optional): initialize logical process group
            during initializing the DeviceMesh instance if the init_process_group set to True.
            Otherwise, users need to call init_logical_process_group manually to init logical process group.
            (default: False)
        device (str): the device for the process groups used by the DeviceMesh instance.
            Must be a key of ``_DIST_BACKEND``. (default: 'cuda')
    """

    # maps a device type to the torch.distributed backend used for its process groups
    _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}

    def __init__(self,
                 physical_mesh_id: torch.Tensor,
                 mesh_shape: torch.Size = None,
                 logical_mesh_id: torch.Tensor = None,
                 mesh_alpha: List[float] = None,
                 mesh_beta: List[float] = None,
                 init_process_group: bool = False,
                 device: str = 'cuda'):
        # ============================
        # Physical & Logical Mesh IDs
        # ============================
        self._physical_mesh_id = physical_mesh_id
        assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor."

        # logical mesh ids can be obtained via two ways
        # 1. provide physical mesh id and provide mesh shape
        # 2. directly supply the logical mesh id
        assert mesh_shape is None or logical_mesh_id is None, \
            "Only one of mesh_shape and logical_mesh_id can be specified." \
            "Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id"
        # without either argument there is no way to build the logical view
        assert mesh_shape is not None or logical_mesh_id is not None, \
            "Either mesh_shape or logical_mesh_id must be specified."

        if logical_mesh_id is None:
            self._mesh_shape = mesh_shape
            self._logical_mesh_id = self._physical_mesh_id.reshape(self._mesh_shape)
        else:
            self._logical_mesh_id = logical_mesh_id
            self._mesh_shape = self._logical_mesh_id.shape

        # ensure two things:
        # 1. logical and physical mesh IDs should contain the same elements
        # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
        assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
            "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
        assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
            "Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again."
        assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
            "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."

        # ===============================================
        # coefficient for alpha-beta communication model
        # alpha is latency and beta is bandwidth
        # ===============================================
        # if the values are not provided, we assume they are 1 for simplicity
        if mesh_alpha is None:
            mesh_alpha = [1] * len(self._mesh_shape)
        if mesh_beta is None:
            mesh_beta = [1] * len(self._mesh_shape)

        self.mesh_alpha = tuple(mesh_alpha)
        self.mesh_beta = tuple(mesh_beta)

        # ensure the alpha and beta have the same shape
        assert len(self.mesh_alpha) == len(self.mesh_beta), \
            "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."

        # =========================
        # Device for Process Group
        # =========================
        self._device = device
        self._dist_backend = self._DIST_BACKEND[device]

        # =========================
        # Process Group Management
        # =========================
        # the _global_to_local_rank_mapping is structured as follows
        # {
        #    <global-rank>: [ <local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...]
        # }
        self._global_to_local_rank_mapping = dict()
        self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping,
                                                  tensor=self.logical_mesh_id)

        # create process group
        self._process_group_dict = {}
        self._ranks_in_the_process_group = {}
        self._global_rank_of_current_process = None
        self._is_initialized = False

        # attribute used to indicate whether this object
        # is created using DeviceMesh.from_process_group
        # this attribute can be used to do some check in methods
        # such as get_process_group as no global rank information
        # is known if created with from_process_group
        self._is_init_from_process_group = False

        # initialize process group if specified
        self._init_ranks_in_the_same_group()
        self._init_process_group = init_process_group
        if init_process_group:
            self.init_logical_process_group()

    @property
    def shape(self) -> torch.Size:
        """
        Return the shape of the logical mesh.
        """
        return self._mesh_shape

    @property
    def num_devices(self) -> int:
        """
        Return the number of devices contained in the device mesh.
        """
        return reduce(operator.mul, self._physical_mesh_id.shape, 1)

    @property
    def logical_mesh_id(self) -> torch.Tensor:
        """
        Return the logical mesh id.
        """
        return self._logical_mesh_id

    @property
    def is_initialized(self) -> bool:
        """
        Return whether the process group is initialized.
        """
        return self._is_initialized

    @staticmethod
    def from_process_group(process_group: Union[ProcessGroup, List[ProcessGroup]]) -> "DeviceMesh":
        """
        Create a DeviceMesh instance from the current process group. Please note that the DeviceMesh object created with this method
        will not have information about the physical mesh id, and thus will not be able to query for other ranks and perform alpha-beta communication.

        Args:
            process_group (Union[ProcessGroup, List[ProcessGroup]]): the process group or a list of process groups for the device mesh.
                If the input is a ProcessGroup object, a 1D DeviceMesh object will be created. If the input is a list of ProcessGroup objects,
                the ProcessGroup at the ith index will correspond to the process group in the ith axis of the device mesh.

        Returns:
            DeviceMesh: the device mesh instance.
        """

        def _get_device_by_backend(process_group):
            """
            Get the device type given a process group's backend.
            """
            backend = dist.get_backend(process_group)
            for _device, _backend in DeviceMesh._DIST_BACKEND.items():
                if _backend == backend:
                    return _device
            return None

        if isinstance(process_group, ProcessGroup):
            process_group = [process_group]

        # get mesh shape
        mesh_shape = [dist.get_world_size(pg) for pg in process_group]

        # get device
        device_list = [_get_device_by_backend(pg) for pg in process_group]

        # make sure all devices are the same
        assert all(device == device_list[0] for device in device_list), \
            "All devices should be the same, please check your input process groups are created with the same distributed backend."

        # create a fake physical mesh id
        # as we only get the process group associated with the current process,
        # we cannot get the global ranks for all processes in the mesh
        # therefore, we only use this fake physical mesh id to create the device mesh
        # and will remove this fake physical mesh id later
        fake_physical_mesh_id = torch.arange(reduce(operator.mul, mesh_shape, 1))

        # create the device mesh
        device_mesh = DeviceMesh(physical_mesh_id=fake_physical_mesh_id, mesh_shape=mesh_shape, device=device_list[0])

        # hack the device attribute
        device_mesh._physical_mesh_id = None
        device_mesh._logical_mesh_id = None
        device_mesh._global_rank_of_current_process = dist.get_rank()
        device_mesh._is_initialized = False
        device_mesh._process_group_dict = {
            device_mesh._global_rank_of_current_process: {axis: pg for axis, pg in enumerate(process_group)}
        }
        # the global-rank bookkeeping above was built from fake ids, so flag the mesh:
        # methods that need real global rank information raise instead of answering wrongly
        device_mesh._is_init_from_process_group = True

        return device_mesh

    def _raise_if_init_from_process_group(self) -> None:
        """Raise RuntimeError if this mesh was built by ``from_process_group`` (no global rank info)."""
        if self._is_init_from_process_group:
            raise RuntimeError(
                "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
            )

    def _resolve_global_rank(self, global_rank: Optional[int]) -> int:
        """Return ``global_rank``, defaulting to the current process; explicit ranks need global rank info."""
        if global_rank is None:
            return self._global_rank_of_current_process
        self._raise_if_init_from_process_group()
        return global_rank

    def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup:
        """
        Return the process group on the specified axis.

        Args:
            axis (int): the axis of the process group.
            global_rank (int, optional): the global rank of the process group. If not specified, the current process is used. (default: None)

        Raises:
            RuntimeError: if ``global_rank`` is given for a mesh created with ``from_process_group``.
        """
        return self._process_group_dict[self._resolve_global_rank(global_rank)][axis]

    def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]:
        """
        Return the process groups for all axes.

        Args:
            global_rank (int, optional): the global rank of the process. If not specified, the current process is used.

        Raises:
            RuntimeError: if ``global_rank`` is given for a mesh created with ``from_process_group``.
        """
        return self._process_group_dict[self._resolve_global_rank(global_rank)]

    def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]:
        """
        Return the ranks in the process group on the specified axis.

        Args:
            axis (int): the axis of the process group.
            global_rank (int, optional): the global rank of the process. If not specified, the current process is used.

        Raises:
            RuntimeError: if ``global_rank`` is given for a mesh created with ``from_process_group``.
        """
        return self._ranks_in_the_process_group[self._resolve_global_rank(global_rank)][axis]

    def __deepcopy__(self, memo) -> "DeviceMesh":
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k != '_process_group_dict':
                setattr(result, k, copy.deepcopy(v, memo))
            else:
                # process group cannot be copied
                # thus, we share them directly
                setattr(result, k, v)
        return result

    def _init_global_to_logical_rank_mapping(self,
                                             mapping: Dict,
                                             tensor: torch.Tensor,
                                             index_list: Optional[List[int]] = None) -> Dict[int, List[int]]:
        """
        Build a global rank to local rank mapping for each process group in different axis in the logical device mesh.

        Args:
            mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
                It is filled in place.
            tensor (torch.Tensor): the tensor that contains the logical mesh ids.
            index_list (List[int], optional): the coordinates of ``tensor`` inside the full logical mesh;
                used by the recursion, callers should leave it unset.

        Returns:
            mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
                The value is a list of integers and each integer represents the local rank in the indexed axis.
        """
        if index_list is None:
            index_list = []

        for index, inner_tensor in enumerate(tensor):
            # index is the local rank along the current axis;
            # inner_tensor holds the processes sharing this coordinate prefix
            coordinates = index_list + [index]
            if inner_tensor.dim() == 0:
                # a 0-d tensor is a single global rank: every axis has been consumed.
                # Checking dim() rather than numel() == 1 keeps meshes with trailing
                # size-1 axes, e.g. (4, 1), from getting truncated coordinates.
                mapping[int(inner_tensor)] = coordinates
            else:
                # recurse into the next axis, carrying the coordinates gathered so far
                self._init_global_to_logical_rank_mapping(mapping, inner_tensor, coordinates)
        return mapping

    def init_logical_process_group(self):
        '''
        This method is used to initialize the logical process groups which will be used in communications
        among logical device mesh.
        Note: if init_process_group set to False, you have to call this method manually. Otherwise,
        the communication related function, such as ShapeConsistencyManager.apply will raise errors.
        '''
        # sanity check
        assert dist.is_initialized(), "The torch.distributed should be initialized before calling init_logical_process_group"
        assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice"

        # update the global rank of the current process
        self._global_rank_of_current_process = dist.get_rank()

        # one process group per distinct rank list; every process walks the same
        # deterministic loop, so dist.new_group is called in the same order everywhere
        created_groups: Dict[tuple, ProcessGroup] = {}

        # flatten the global ranks to 1D list
        global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()

        for global_rank in global_rank_flatten_list:
            # find the other ranks which are in the same process group as global_rank
            ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)

            for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
                group_key = tuple(ranks_in_same_group)
                if group_key not in created_groups:
                    created_groups[group_key] = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend)
                pg_handler = created_groups[group_key]

                # record the group for every member on this axis (an identical rank list can
                # occur on two axes, e.g. two size-1 axes, so assignment is not skipped)
                for rank in ranks_in_same_group:
                    self._process_group_dict.setdefault(rank, {})[axis] = pg_handler

        # update the init flag
        # we only allow init for once
        self._is_initialized = True

    def _init_ranks_in_the_same_group(self):
        """
        Fill ``self._ranks_in_the_process_group`` so that
        ``self._ranks_in_the_process_group[global_rank][axis]`` lists the global ranks
        that share a process group with ``global_rank`` along ``axis``.
        """
        for global_rank in self._physical_mesh_id.view(-1).tolist():
            ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)
            # keep every axis for this rank in one dict
            self._ranks_in_the_process_group.setdefault(global_rank, {}).update(ranks_in_same_group_by_axis)

    def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]:
        """
        Return the local rank of the given global rank in the logical device mesh.

        Args:
            rank (int): the global rank in the logical device mesh.
            axis (int, optional): the axis of the logical device mesh. If None, the local ranks
                on all axes are returned as a list.

        Raises:
            RuntimeError: if the mesh was created with ``from_process_group``.
        """
        self._raise_if_init_from_process_group()

        local_ranks = self._global_to_local_rank_mapping[rank]
        # compare against None explicitly: axis 0 is a valid axis
        if axis is not None:
            return local_ranks[axis]
        return local_ranks

    def _collate_global_ranks_in_same_process_group(self, global_rank):
        '''
        Give a global rank and return all global ranks involved in its associated process group in each axis.

        Example:

        ```python
        physical_mesh_id = torch.arange(0, 16)
        mesh_shape = (4, 4)

        # logical mesh will look like
        # [[0, 1, 2, 3],
        #  [4, 5, 6, 7],
        #  [8, 9, 10,11],
        #  [12,13,14,15]]

        device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
        print(device_mesh._collate_global_ranks_in_same_process_group(0))

        # key is axis index
        # value is a list of global ranks in same axis with rank 0
        # output will look like
        # {
        #   0: [0, 4, 8, 12],
        #   1: [0, 1, 2, 3]
        # }
        ```
        '''
        # _global_to_local_rank_mapping[global_rank] holds the coordinates of global_rank in
        # logical_mesh_id, so the processes sharing its group along `dim` are exactly the entries
        # of logical_mesh_id that agree with those coordinates on every other axis.
        coordinates = self._global_to_local_rank_mapping[global_rank]

        global_pg_ranks = {}
        for dim in range(self.logical_mesh_id.dim()):
            index = tuple(coordinates[:dim]) + (slice(None),) + tuple(coordinates[dim + 1:])
            global_pg_ranks[dim] = [int(rank) for rank in self.logical_mesh_id[index].tolist()]
        return global_pg_ranks

    def flatten(self):
        """
        Flatten the logical mesh into an effective 1d logical mesh.

        The flattened mesh has a single axis, so it gets a single alpha/beta coefficient (the
        maximum over the original axes) and keeps this mesh's device.

        Raises:
            RuntimeError: if the mesh was created with ``from_process_group``.
        """
        self._raise_if_init_from_process_group()

        return DeviceMesh(self._physical_mesh_id,
                          (self.num_devices,),
                          mesh_alpha=[max(self.mesh_alpha)],
                          mesh_beta=[max(self.mesh_beta)],
                          init_process_group=self._init_process_group,
                          device=self._device)

    def all_gather_cost(self, num_bytes, mesh_dim):
        """Alpha-beta estimate of an all-gather of ``num_bytes`` along ``mesh_dim``."""
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        # the trailing constants (0.1 / 0.01 / 0.001) are fixed per-collective offsets;
        # NOTE(review): presumably tie-breakers inherited from Alpa, confirm
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
                0.1)

    def all_reduce_cost(self, num_bytes, mesh_dim):
        """Alpha-beta estimate of an all-reduce of ``num_bytes`` along ``mesh_dim``."""
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes +
                0.01)

    def reduce_scatter_cost(self, num_bytes, mesh_dim):
        """Alpha-beta estimate of a reduce-scatter of ``num_bytes`` along ``mesh_dim``."""
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
                0.001)

    def all_to_all_cost(self, num_bytes, mesh_dim):
        """Alpha-beta estimate of an all-to-all of ``num_bytes`` along ``mesh_dim``, scaled by ``num_devices / 2``."""
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        penalty_factor = num_devices / 2.0
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
                (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)