# ColossalAI/colossalai/device/device_mesh.py
#
# (source listing metadata: 189 lines, 8.3 KiB, Python)

from functools import reduce
import operator
import torch
import torch.distributed as dist
class DeviceMesh:
    """A logical view of a physical mesh. The logical view is used in the
    search process.

    A physical mesh can have multiple logical views (e.g., a 2x8 physical mesh
    can be viewed as a 1x16 or a 4x4 logical mesh). Each mesh dimension has its
    own latency and bandwidth. We use the alpha-beta model to model the
    communication cost.

    Arguments:
        physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
        mesh_shape (torch.Size or tuple): shape of the logical view; ``physical_mesh_id``
            is reshaped to it.
        mesh_alpha (List[float], optional): per-mesh-dimension alpha coefficients used for
            computing communication cost. Defaults to 1 for every dimension. (default: None)
        mesh_beta (List[float], optional): per-mesh-dimension beta coefficients used for
            computing communication cost. Defaults to 1 for every dimension. (default: None)
        init_process_group (bool, optional): initialize the logical process groups while
            constructing the DeviceMesh instance if set to True. Otherwise, users need to call
            create_process_groups_for_logical_mesh manually to init the logical process groups.
            (default: False)
        need_flatten (bool, optional): if True, also build a 1-D flattened copy of this mesh
            and store it in ``flatten_device_mesh``. (default: True)
    """

    def __init__(self,
                 physical_mesh_id,
                 mesh_shape,
                 mesh_alpha=None,
                 mesh_beta=None,
                 init_process_group=False,
                 need_flatten=True):
        self.physical_mesh_id = physical_mesh_id
        self.mesh_shape = mesh_shape
        self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
        # Maps each global rank (int) to its logical coordinate: a list with one
        # index per logical mesh dimension.
        self.convert_map = {}
        self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
        # Coefficients for the alpha-beta communication model, one per mesh dimension.
        if mesh_alpha is None:
            mesh_alpha = [1] * len(self.mesh_shape)
        if mesh_beta is None:
            mesh_beta = [1] * len(self.mesh_shape)
        self.mesh_alpha = tuple(mesh_alpha)
        self.mesh_beta = tuple(mesh_beta)
        self.init_process_group = init_process_group
        self.need_flatten = need_flatten
        if self.init_process_group:
            self.process_groups_dict = self.create_process_groups_for_logical_mesh()
        if self.need_flatten:
            self.flatten_device_mesh = self.flatten()

    @property
    def shape(self):
        """The logical mesh shape, exactly as passed to the constructor."""
        return self.mesh_shape

    @property
    def num_devices(self):
        """Total number of devices, i.e. the element count of ``physical_mesh_id``."""
        return reduce(operator.mul, self.physical_mesh_id.shape, 1)

    @property
    def logical_mesh_id(self):
        """``physical_mesh_id`` reshaped to ``mesh_shape``."""
        return self._logical_mesh_id

    def flatten(self):
        """
        Flatten the logical mesh into an effective 1d logical mesh.

        The returned mesh has shape ``(num_devices,)``. Because it has exactly one
        dimension, it carries exactly one alpha and one beta coefficient. It never
        builds a further flattened copy of itself (``need_flatten=False``).
        """
        return DeviceMesh(self.physical_mesh_id,
                          (self.num_devices,),
                          mesh_alpha=[max(self.mesh_alpha)],
                          # NOTE(review): a pessimistic model would arguably use max() for
                          # beta as well; min() is kept to preserve the existing cost model.
                          mesh_beta=[min(self.mesh_beta)],
                          init_process_group=self.init_process_group,
                          need_flatten=False)

    def _global_rank_to_logical_rank_map(self, tensor, index_list):
        '''
        Recursively fill ``self.convert_map`` with ``global rank -> logical coordinate``.

        ``index_list`` is the coordinate prefix of ``tensor`` inside the logical mesh.
        '''
        for index, inner_tensor in enumerate(tensor):
            coordinate = index_list + [index]
            # Recurse all the way down to 0-d elements. Stopping at numel() == 1 would
            # end one level early on trailing size-1 dimensions (e.g. mesh_shape (4, 1))
            # and record a truncated coordinate.
            if inner_tensor.dim() == 0:
                self.convert_map[int(inner_tensor)] = coordinate
            else:
                self._global_rank_to_logical_rank_map(inner_tensor, coordinate)

    def create_process_groups_for_logical_mesh(self):
        '''
        This method is used to initialize the logical process groups which will be used in communications
        among the logical device mesh.

        Note: if init_process_group is set to False, you have to call this method manually. Otherwise,
        communication-related functions, such as ShapeConsistencyManager.apply, will raise errors.

        Returns:
            dict: axis -> list of ``(global_ranks, process_group_handle)`` tuples. Each distinct
            group is created only once.
        '''
        process_groups_dict = {}
        seen_groups = set()
        # The iteration order depends only on physical_mesh_id, so dist.new_group is called
        # in the same order on every rank (assuming every rank builds the mesh from the
        # same ids).
        for global_rank in self.physical_mesh_id.reshape(-1).tolist():
            process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank)
            for axis, process_group in process_groups.items():
                process_groups_dict.setdefault(axis, [])
                group_key = tuple(process_group)
                if group_key in seen_groups:
                    continue
                seen_groups.add(group_key)
                process_group_handler = dist.new_group(process_group)
                process_groups_dict[axis].append((process_group, process_group_handler))
        return process_groups_dict

    def global_rank_to_logical_rank(self, rank):
        """Return the logical coordinate (list of ints) of global ``rank``; raises KeyError if unknown."""
        return self.convert_map[rank]

    def global_rank_to_process_groups_with_logical_rank(self, rank):
        '''
        Give a global rank and return all logical process groups of this rank.

        For each mesh dimension ``d``, the group holds every logical coordinate that agrees with
        ``rank``'s coordinate except along ``d``.

        for example:
            physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
            mesh_shape = (4, 4)
            # [[0, 1, 2, 3],
            #  [4, 5, 6, 7],
            #  [8, 9, 10,11],
            #  [12,13,14,15]]
            device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
            print(device_mesh.global_rank_to_process_groups_with_logical_rank(0))

        output:
            # key is axis name
            # value is a list of logical ranks in same axis with rank 0
            {0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]}
        '''
        logical_rank = self.convert_map[rank]
        process_groups = {}
        for d, dim_size in enumerate(self.logical_mesh_id.shape):
            # Replace the d-th index with every position along that dimension.
            process_groups[d] = [
                logical_rank[:d] + [replacer] + logical_rank[d + 1:] for replacer in range(dim_size)
            ]
        return process_groups

    def global_rank_to_process_groups_with_global_rank(self, rank):
        '''
        Give a global rank and return all process groups of this rank.

        for example:
            physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
            mesh_shape = (4, 4)
            # [[0, 1, 2, 3],
            #  [4, 5, 6, 7],
            #  [8, 9, 10,11],
            #  [12,13,14,15]]
            device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
            print(device_mesh.global_rank_to_process_groups_with_global_rank(0))

        output:
            # key is axis name
            # value is a list of global ranks in same axis with rank 0
            {0: [0, 4, 8, 12], 1: [0, 1, 2, 3]}
        '''
        # Inverse of convert_map, built once per call. This avoids rescanning convert_map
        # for every logical rank.
        logical_to_global = {tuple(l_rank): g_rank for g_rank, l_rank in self.convert_map.items()}
        logical_process_groups = self.global_rank_to_process_groups_with_logical_rank(rank)
        process_groups = {}
        for dim, logical_ranks in logical_process_groups.items():
            members = []
            for logical_rank in logical_ranks:
                g_rank = logical_to_global.get(tuple(logical_rank))
                if g_rank is not None:
                    members.append(g_rank)
            process_groups[dim] = members
        return process_groups

    # The cost helpers below follow the alpha-beta model:
    # alpha[mesh_dim] + beta[mesh_dim] * (communication-volume factor) * num_bytes, plus a
    # small per-collective constant (presumably a tie-breaker between equal costs -- TODO
    # confirm).

    def all_gather_cost(self, num_bytes, mesh_dim):
        """alpha + beta * (n - 1) / n * num_bytes + 0.1, with n = size of ``mesh_dim``."""
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
                0.1)

    def all_reduce_cost(self, num_bytes, mesh_dim):
        """alpha + beta * 2 * (n - 1) / n * num_bytes + 0.01, with n = size of ``mesh_dim``."""
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes +
                0.01)

    def reduce_scatter_cost(self, num_bytes, mesh_dim):
        """alpha + beta * (n - 1) / n * num_bytes + 0.001, with n = size of ``mesh_dim``."""
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
                0.001)

    def all_to_all_cost(self, num_bytes, mesh_dim):
        """alpha + beta * (n - 1) / n^2 * num_bytes * (n / 2) + 0.001, with n = size of ``mesh_dim``."""
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        penalty_factor = num_devices / 2.0
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
                (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)