mirror of https://github.com/hpcaitech/ColossalAI
[device] support init device mesh from process group (#3990)
parent a2f9af810d
commit 611971248c
@@ -3,11 +3,19 @@
with some changes. """
import operator
from dataclasses import dataclass
from functools import reduce
from typing import List, Tuple
from typing import Dict, List, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


@dataclass
class ProcessGroupContainer:
    process_group: ProcessGroup
    ranks: List[int]


# modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py)

@@ -27,9 +35,11 @@ class DeviceMesh:
            during the initialization of the DeviceMesh instance if init_process_group is set to True.
            Otherwise, users need to call create_process_groups_for_logical_mesh manually to initialize the logical process groups.
            (default: False)
        need_flatten (bool, optional): initialize flatten_device_mesh during the initialization of the DeviceMesh instance if need_flatten is set to True.
        device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda')
    """

    _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}

    def __init__(self,
                 physical_mesh_id: torch.Tensor,
                 mesh_shape: torch.Size = None,

@@ -37,160 +47,442 @@ class DeviceMesh:
                 mesh_alpha: List[float] = None,
                 mesh_beta: List[float] = None,
                 init_process_group: bool = False,
                 need_flatten: bool = True):
        self.physical_mesh_id = physical_mesh_id
                 device: str = 'cuda'):
        # ============================
        # Physical & Logical Mesh IDs
        # ============================
        self._physical_mesh_id = physical_mesh_id
        assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor."

        # logical mesh ids can be obtained via two ways
        # 1. provide physical mesh id and provide mesh shape
        # 2. directly supply the logical mesh id
        assert mesh_shape is None or logical_mesh_id is None, \
            "Only one of mesh_shape and logical_mesh_id can be specified." \
            "Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id"

        if logical_mesh_id is None:
            self.mesh_shape = mesh_shape
            self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
            self._mesh_shape = mesh_shape
            self._logical_mesh_id = self._physical_mesh_id.reshape(self._mesh_shape)
        else:
            self._logical_mesh_id = logical_mesh_id
            self.mesh_shape = self._logical_mesh_id.shape
            self._mesh_shape = self._logical_mesh_id.shape

        # map global rank into logical rank
        self.convert_map = {}
        self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
        # ensure two things:
        # 1. logical and physical mesh IDs should contain the same elements
        # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
        assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
            "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
        assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
            "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
        assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
            "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."

        # ===============================================
        # coefficient for alpha-beta communication model
        # alpha is latency and beta is bandwidth
        # ===============================================
        # if the values are not provided, we assume they are 1 for simplicity
        if mesh_alpha is None:
            mesh_alpha = [1] * len(self.mesh_shape)
            mesh_alpha = [1] * len(self._mesh_shape)
        if mesh_beta is None:
            mesh_beta = [1] * len(self.mesh_shape)
            mesh_beta = [1] * len(self._mesh_shape)

        self.mesh_alpha = tuple(mesh_alpha)
        self.mesh_beta = tuple(mesh_beta)
        self.init_process_group = init_process_group
        self.need_flatten = need_flatten
        if self.init_process_group:
            self.process_groups_dict = self.create_process_groups_for_logical_mesh()
        if self.need_flatten and self._logical_mesh_id.dim() > 1:
            self.flatten_device_mesh = self.flatten()
            # Create a new member `flatten_device_meshes` to distinguish from the original flatten methods (because it is unclear whether any functions rely on self.flatten())
            # self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
            #                                                self.mesh_beta)

        # ensure the alpha and beta have the same shape
        assert len(self.mesh_alpha) == len(self.mesh_beta), \
            "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."

        # =========================
        # Device for Process Group
        # =========================
        self._device = device
        self._dist_backend = self._DIST_BACKEND[device]

        # =========================
        # Process Group Management
        # =========================
        # the _global_to_local_rank_mapping is structured as follows
        # {
        #    <global-rank>: [ <local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...]
        # }
        self._global_to_local_rank_mapping = dict()
        self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping,
                                                  tensor=self.logical_mesh_id)

        # create process group
        self._process_group_dict = {}
        self._ranks_in_the_process_group = {}
        self._global_rank_of_current_process = None
        self._is_initialized = False

        # attribute used to indicate whether this object
        # is created using DeviceMesh.from_process_group
        # this attribute can be used to do some check in methods
        # such as get_process_group, as no global rank information
        # is known if created with from_process_group
        self._is_init_from_process_group = False

        # initialize process group if specified
        self._init_ranks_in_the_same_group()
        self._init_process_group = init_process_group
        if init_process_group:
            self.init_logical_process_group()
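# Illustrative usage sketch (editor's example, not taken from the diff above):
# constructing a DeviceMesh without touching torch.distributed. With
# init_process_group left as False, no process groups are created, so this
# runs in a single process; the mesh shape and 'cpu' device are assumptions
# made purely for illustration.
import torch
from colossalai.device.device_mesh import DeviceMesh

physical_mesh_id = torch.arange(4)       # 4 devices, global ranks 0..3
mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2), device='cpu')
print(mesh.shape)              # (2, 2)
print(mesh.num_devices)        # 4
print(mesh.logical_mesh_id)    # tensor([[0, 1], [2, 3]])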

    @property
    def shape(self):
        return self.mesh_shape
    def shape(self) -> torch.Size:
        """
        Return the shape of the logical mesh.
        """
        return self._mesh_shape

    @property
    def num_devices(self):
        return reduce(operator.mul, self.physical_mesh_id.shape, 1)
    def num_devices(self) -> int:
        """
        Return the number of devices contained in the device mesh.
        """
        return reduce(operator.mul, self._physical_mesh_id.shape, 1)

    @property
    def logical_mesh_id(self):
    def logical_mesh_id(self) -> torch.Tensor:
        """
        Return the logical mesh id.
        """
        return self._logical_mesh_id

    def __deepcopy__(self, memo):
    @property
    def is_initialized(self) -> bool:
        """
        Return whether the process group is initialized.
        """
        return self._is_initialized

    @staticmethod
    def from_process_group(process_group: Union[ProcessGroup, List[ProcessGroup]]) -> "DeviceMesh":
        """
        Create a DeviceMesh instance from the current process group. Please note that the DeviceMesh object created with this method
        will not have information about the physical mesh id, and thus will not be able to query for other ranks and perform alpha-beta communication.

        Args:
            process_group (Union[ProcessGroup, List[ProcessGroup]]): the process group or a list of process groups for the device mesh.
                If the input is a ProcessGroup object, a 1D DeviceMesh object will be created. If the input is a list of ProcessGroup objects,
                the ProcessGroup at the ith index will correspond to the process group in the ith axis of the device mesh.

        Returns:
            DeviceMesh: the device mesh instance.
        """

        def _get_device_by_backend(process_group):
            """
            Get the device type given a process group's backend.
            """
            backend = dist.get_backend(process_group)
            for _device, _backend in DeviceMesh._DIST_BACKEND.items():
                if _backend == backend:
                    return _device
            return None

        if isinstance(process_group, ProcessGroup):
            process_group = [process_group]

        # get mesh shape
        mesh_shape = [dist.get_world_size(pg) for pg in process_group]

        # get device
        device_list = [_get_device_by_backend(pg) for pg in process_group]

        # make sure all devices are the same
        assert all([device == device_list[0] for device in device_list]), \
            "All devices should be the same, please check your input process groups are created with the same distributed backend."

        # create a fake physical mesh id
        # as we only get the process group associated with the current process,
        # we cannot get the global ranks for all processes in the mesh
        # therefore, we only use this fake physical mesh id to create the device mesh
        # and will remove this fake physical mesh id later
        fake_physical_mesh_id = torch.arange(reduce(operator.mul, mesh_shape, 1))

        # create the device mesh
        device_mesh = DeviceMesh(physical_mesh_id=fake_physical_mesh_id, mesh_shape=mesh_shape, device=device_list[0])

        # hack the device attribute
        device_mesh._physical_mesh_id = None
        device_mesh._logical_mesh_id = None
        device_mesh._global_rank_of_current_process = dist.get_rank()
        device_mesh._is_initialized = False
        device_mesh._process_group_dict = {
            device_mesh._global_rank_of_current_process: {axis: pg for axis, pg in enumerate(process_group)}
        }

        return device_mesh
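# Illustrative sketch (editor's example): the two accepted input forms of
# from_process_group. It assumes torch.distributed has already been
# initialized (e.g. via torchrun) with a backend listed in _DIST_BACKEND;
# col_pg / row_pg are placeholders for groups created elsewhere with
# dist.new_group.
import torch.distributed as dist
from colossalai.device.device_mesh import DeviceMesh

# 1D mesh: one process group covering all ranks
mesh_1d = DeviceMesh.from_process_group(dist.GroupMember.WORLD)

# 2D mesh: one process group per mesh axis
# mesh_2d = DeviceMesh.from_process_group([col_pg, row_pg])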

    def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup:
        """
        Return the process group on the specified axis.

        Args:
            axis (int): the axis of the process group.
            global_rank (int, optional): the global rank of the process. If not specified, the current process is used. (default: None)
        """
        if global_rank is None:
            global_rank = self._global_rank_of_current_process
        elif self._is_init_from_process_group:
            raise RuntimeError(
                "The logical device mesh is created with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
            )
        return self._process_group_dict[global_rank][axis]

    def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]:
        """
        Return the process groups for all axes.

        Args:
            global_rank (int, optional): the global rank of the process. If not specified, the current process is used. (default: None)
        """
        if global_rank is None:
            global_rank = self._global_rank_of_current_process
        elif self._is_init_from_process_group:
            raise RuntimeError(
                "The logical device mesh is created with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
            )
        return self._process_group_dict[global_rank]

    def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]:
        """
        Return the ranks in the process group on the specified axis.

        Args:
            axis (int): the axis of the process group.
            global_rank (int, optional): the global rank of the process. If not specified, the current process is used. (default: None)
        """
        if global_rank is None:
            global_rank = self._global_rank_of_current_process
        elif self._is_init_from_process_group:
            raise RuntimeError(
                "The logical device mesh is created with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
            )
        return self._ranks_in_the_process_group[global_rank][axis]
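# Illustrative sketch (editor's example): querying the per-axis process
# groups of a 2x2 mesh. It assumes a 4-process torch.distributed launch with
# the default 'cuda'/'nccl' backend; the printed values describe rank 0.
import torch
from colossalai.device.device_mesh import DeviceMesh

mesh = DeviceMesh(torch.arange(4), mesh_shape=(2, 2), init_process_group=True)
col_pg = mesh.get_process_group(axis=0)              # group of ranks in the same column
row_ranks = mesh.get_ranks_in_process_group(axis=1)  # e.g. [0, 1] on rank 0
all_groups = mesh.get_process_group_for_all_axes()   # {0: <ProcessGroup>, 1: <ProcessGroup>}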

    def __deepcopy__(self, memo) -> "DeviceMesh":
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k != 'process_groups_dict':
            if k != '_process_group_dict':
                setattr(result, k, __import__("copy").deepcopy(v, memo))
            else:
                # process group cannot be copied
                # thus, we share them directly
                setattr(result, k, v)

        return result

    def flatten(self):
    def _init_global_to_logical_rank_mapping(self,
                                             mapping: Dict,
                                             tensor: torch.Tensor,
                                             index_list: List[int] = []) -> Dict[int, List[int]]:
        """
        Flatten the logical mesh into an effective 1d logical mesh,
        """
        flatten_mesh_shape_size = len(self.mesh_shape)
        flatten_mesh_shape = [self.num_devices]
        return DeviceMesh(self.physical_mesh_id,
                          tuple(flatten_mesh_shape),
                          mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
                          mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
                          init_process_group=self.init_process_group,
                          need_flatten=False)
        Build a global rank to local rank mapping for each process group in different axes of the logical device mesh.

    def _global_rank_to_logical_rank_map(self, tensor, index_list):
        '''
        This method is a helper function to build convert_map recursively.
        '''
        Args:
            mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
            tensor (torch.Tensor): the tensor that contains the logical mesh ids.
            index_list (List[int]): the local ranks accumulated from the outer axes during recursion.

        Returns:
            mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
                The value is a list of integers and each integer represents the local rank in the indexed axis.
        """
        for index, inner_tensor in enumerate(tensor):
            if inner_tensor.numel() == 1:
                self.convert_map[int(inner_tensor)] = index_list + [index]
            else:
                self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index])
            # index means the local rank in the current axis
            # inner_tensor refers to the processes with the same local rank

    def create_process_groups_for_logical_mesh(self):
            if inner_tensor.numel() == 1:
                # if the inner_tensor only has one element, it means that
                # it already reaches the last axis
                # we append its local_rank in the last axis to the index_list
                # and assign to the mapping
                # the value of the mapping is the local rank at the indexed axis of the device mesh
                mapping[int(inner_tensor)] = index_list + [index]
            else:
                # we recursively go into the function until we reach the last axis
                # meanwhile, we should add the local rank in the current axis in the index_list
                self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index])
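# Illustrative sketch (editor's example): what the recursion above produces
# for a 2x2 logical mesh. It runs in a single process since no process
# groups are touched; the 'cpu' device is an assumption for illustration.
import torch
from colossalai.device.device_mesh import DeviceMesh

mesh = DeviceMesh(torch.arange(4), mesh_shape=(2, 2), device='cpu')
print(mesh._global_to_local_rank_mapping)
# {0: [0, 0], 1: [0, 1], 2: [1, 0], 3: [1, 1]}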

    def init_logical_process_group(self):
        '''
        This method is used to initialize the logical process groups which will be used in communications
        among the logical device mesh.
        Note: if init_process_group is set to False, you have to call this method manually. Otherwise,
        communication-related functions, such as ShapeConsistencyManager.apply, will raise errors.
        '''
        process_groups_dict = {}
        check_duplicate_list = []
        global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist()
        # sanity check
        assert dist.is_initialized(), "torch.distributed should be initialized before calling init_logical_process_group"
        assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice"

        # update the global rank of the current process
        self._global_rank_of_current_process = dist.get_rank()
        duplicate_check_list = []

        # flatten the global ranks to 1D list
        global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()

        for global_rank in global_rank_flatten_list:
            process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank)
            for axis, process_group in process_groups.items():
                if axis not in process_groups_dict:
                    process_groups_dict[axis] = []
                if process_group not in check_duplicate_list:
                    check_duplicate_list.append(process_group)
                    process_group_handler = dist.new_group(process_group)
                    process_groups_dict[axis].append((process_group, process_group_handler))
            # find the other ranks which are in the same process group as global_rank
            ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)

        return process_groups_dict
            for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
                # skip duplicated process group creation
                if ranks_in_same_group in duplicate_check_list:
                    continue

    def global_rank_to_logical_rank(self, rank):
        return self.convert_map[rank]
                # create the process group
                pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend)

    def global_rank_to_process_groups_with_logical_rank(self, rank):
        '''
        Give a global rank and return all logical process groups of this rank.
        for example:
            physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
            mesh_shape = (4, 4)
            # [[0, 1, 2, 3],
            #  [4, 5, 6, 7],
            #  [8, 9, 10, 11],
            #  [12, 13, 14, 15]]
            device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
            print(device_mesh.global_rank_to_process_groups_with_logical_rank(0))
        output:
            # key is axis name
            # value is a list of logical ranks in same axis with rank 0
            {0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]}
        '''
        process_groups = {}
        for d in range(self.logical_mesh_id.dim()):
            for replacer in range(self.logical_mesh_id.shape[d]):
                if d not in process_groups:
                    process_groups[d] = []
                process_group_member = self.convert_map[rank].copy()
                process_group_member[d] = replacer
                process_groups[d].append(process_group_member)
        return process_groups
                # keep this process group in the process_groups_dict
                for rank in ranks_in_same_group:
                    if rank not in self._process_group_dict:
                        self._process_group_dict[rank] = dict()
                    self._process_group_dict[rank][axis] = pg_handler

    def global_rank_to_process_groups_with_global_rank(self, rank):
        # update the init flag
        # we only allow init for once
        self._is_initialized = True
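# Illustrative sketch (editor's example): deferring process group creation.
# It assumes a 4-process distributed launch with the default 'cuda'/'nccl'
# backend; with init_process_group=False the groups are only created when
# init_logical_process_group() is called explicitly.
import torch
from colossalai.device.device_mesh import DeviceMesh

mesh = DeviceMesh(torch.arange(4), mesh_shape=(2, 2))   # no dist.new_group calls yet
assert not mesh.is_initialized
mesh.init_logical_process_group()                       # creates one group per mesh axis
assert mesh.is_initialized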

    def _init_ranks_in_the_same_group(self):
        """
        This method is used to initialize the ranks_in_the_same_group dictionary.
        """
        # flatten the global ranks to 1D list
        global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()

        for global_rank in global_rank_flatten_list:
            # find the other ranks which are in the same process group as global_rank
            ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)

            for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
                # create dict for each rank
                if global_rank not in self._ranks_in_the_process_group:
                    self._ranks_in_the_process_group[global_rank] = dict()

                # keep this process group in the process_groups_dict
                self._ranks_in_the_process_group[global_rank][axis] = ranks_in_same_group

    def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]:
        """
        Return the local rank of the given global rank in the logical device mesh.

        Args:
            rank (int): the global rank in the logical device mesh.
            axis (int): the axis of the logical device mesh.
        """
        if self._is_init_from_process_group:
            raise RuntimeError(
                "The logical device mesh is created with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
            )

        local_ranks = self._global_to_local_rank_mapping[rank]
        if axis is not None:
            return local_ranks[axis]
        else:
            return local_ranks
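# Illustrative sketch (editor's example): translating a global rank into its
# per-axis coordinates on a 2x2 mesh; runs in a single process, and the
# 'cpu' device is assumed for illustration only.
import torch
from colossalai.device.device_mesh import DeviceMesh

mesh = DeviceMesh(torch.arange(4), mesh_shape=(2, 2), device='cpu')
print(mesh.global_rank_to_local_rank(3))          # [1, 1]
print(mesh.global_rank_to_local_rank(3, axis=1))  # 1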

    def _collate_global_ranks_in_same_process_group(self, global_rank):
        '''
        Give a global rank and return all process groups of this rank.
        for example:
            physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
            mesh_shape = (4, 4)
            # [[0, 1, 2, 3],
            #  [4, 5, 6, 7],
            #  [8, 9, 10, 11],
            #  [12, 13, 14, 15]]
            device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
            print(device_mesh.global_rank_to_process_groups_with_global_rank(0))
        output:
            # key is axis name
            # value is a list of global ranks in same axis with rank 0
            {0: [0, 4, 8, 12], 1: [0, 1, 2, 3]}
        Give a global rank and return all global ranks involved in its associated process group in each axis.

        Example:

        ```python
        physical_mesh_id = torch.arange(0, 16)
        mesh_shape = (4, 4)

        # logical mesh will look like
        # [[0, 1, 2, 3],
        #  [4, 5, 6, 7],
        #  [8, 9, 10, 11],
        #  [12, 13, 14, 15]]

        device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
        print(device_mesh._collate_global_ranks_in_same_process_group(0))

        # key is axis name
        # value is a list of global ranks in same axis with rank 0
        # output will look like
        # {
        #   0: [0, 4, 8, 12],
        #   1: [0, 1, 2, 3]
        # }
        '''
        logical_process_groups = self.global_rank_to_process_groups_with_logical_rank(rank)
        process_groups = {}
        for dim, logical_ranks in logical_process_groups.items():
            process_groups[dim] = []
            for logical_rank in logical_ranks:
                for g_rank, l_rank in self.convert_map.items():
                    if l_rank == logical_rank:
                        process_groups[dim].append(g_rank)
        return process_groups
        # We have initialized the global rank to local rank mapping by calling _init_global_to_logical_rank_mapping
        # for self._global_to_local_rank_mapping
        # the key is the global rank
        # the value is the list of local ranks corresponding to the global rank with respect to different axes
        # we can see the list of local ranks as the process coordinates for simplicity
        # the key and value are all unique, therefore,
        # we can also use the coordinates to find the global rank

        # =========================================================================
        # Step 1
        # find all the process_coordinates for processes in the same process group
        # as the given global rank
        # =========================================================================

        # each key is an axis and each value is a list of process coordinates in that axis's process group
        processes_in_the_same_process_group = {}

        for dim in range(self.logical_mesh_id.dim()):
            # iterate over the dimension size so that we can include all processes
            # in the same process group in the given axis
            # the _local_rank refers to the local rank of the current process
            for _local_rank in range(self.logical_mesh_id.shape[dim]):

                # if this dimension is not initialized yet,
                # initialize it with an empty array
                if dim not in processes_in_the_same_process_group:
                    processes_in_the_same_process_group[dim] = []

                # get the local rank corresponding to the global rank
                process_coordinates = self._global_to_local_rank_mapping[global_rank].copy()

                # replace the local rank in the given dimension with the
                # local rank of the current process iterated
                process_coordinates[dim] = _local_rank
                processes_in_the_same_process_group[dim].append(process_coordinates)

        # =================================================================
        # Step 2
        # Use local rank combination to find its corresponding global rank
        # =================================================================
        # the key of the dict is the axis
        # the value is the list of global ranks which are in the same process group as the given global rank
        global_pg_ranks = {}
        for dim, coordinates_of_all_processes in processes_in_the_same_process_group.items():
            global_pg_ranks[dim] = []
            for process_coordinates in coordinates_of_all_processes:
                # find the global rank by local rank combination
                for _global_rank, _process_coordinates in self._global_to_local_rank_mapping.items():
                    if process_coordinates == _process_coordinates:
                        global_pg_ranks[dim].append(_global_rank)
        return global_pg_ranks
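# Illustrative sketch (editor's example): the coordinate trick above, redone
# with plain Python for a 4x4 mesh, to show why rank 0's axis-0 group is the
# first column and its axis-1 group is the first row.
mapping = {g: [g // 4, g % 4] for g in range(16)}   # global rank -> [row, col]

groups = {}
for dim in range(2):
    coords = []
    for local_rank in range(4):
        c = mapping[0].copy()       # coordinates of global rank 0
        c[dim] = local_rank         # vary only the chosen axis
        coords.append(c)
    groups[dim] = [g for g, m in mapping.items() if m in coords]

print(groups)   # {0: [0, 4, 8, 12], 1: [0, 1, 2, 3]}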

    def flatten(self):
        """
        Flatten the logical mesh into an effective 1d logical mesh.
        """
        if self._is_init_from_process_group:
            raise RuntimeError(
                "The logical device mesh is created with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
            )

        flatten_mesh_shape_size = len(self._mesh_shape)
        flatten_mesh_shape = [self.num_devices]
        return DeviceMesh(self._physical_mesh_id,
                          tuple(flatten_mesh_shape),
                          mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
                          mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
                          init_process_group=self._init_process_group)
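# Illustrative sketch (editor's example): flattening a 2x2 mesh into a 1D
# mesh over the same 4 devices; runs in a single process, with 'cpu' chosen
# here purely for illustration.
import torch
from colossalai.device.device_mesh import DeviceMesh

mesh_2d = DeviceMesh(torch.arange(4), mesh_shape=(2, 2), device='cpu')
mesh_1d = mesh_2d.flatten()
print(mesh_1d.shape)        # (4,)
print(mesh_1d.num_devices)  # 4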

    def all_gather_cost(self, num_bytes, mesh_dim):
        num_devices = self.logical_mesh_id.shape[mesh_dim]

@@ -211,39 +503,4 @@ class DeviceMesh:
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        penalty_factor = num_devices / 2.0
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
                (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
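# Illustrative sketch (editor's example): evaluating the alpha-beta cost
# expression above by hand for one mesh axis; alpha, beta and the message
# size are made-up numbers.
alpha, beta = 1e-5, 1e-10      # latency (s) and per-byte time (s/B) for this axis
num_devices = 4
num_bytes = 1 * 1024 * 1024
penalty_factor = num_devices / 2.0
cost = alpha + beta * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001
print(f"{cost:.6f} s")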


class FlattenDeviceMesh(DeviceMesh):

    def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
        super().__init__(physical_mesh_id,
                         mesh_shape,
                         mesh_alpha,
                         mesh_beta,
                         init_process_group=False,
                         need_flatten=False)
        # Different from flatten(), mesh_shape is left unchanged; mesh_alpha and mesh_beta are scalars
        self.mesh_alpha = max(self.mesh_alpha)
        self.mesh_beta = min(self.mesh_beta)
        # Different from the original process_groups_dict, rank_list is not stored
        self.process_number_dict = self.create_process_numbers_for_logical_mesh()

    def create_process_numbers_for_logical_mesh(self):
        '''
        Build 1d DeviceMesh in column-major(0) and row-major(1)
        for example:
            mesh_shape = (2, 4)
            # [[0, 1, 2, 3],
            #  [4, 5, 6, 7]]
            # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
        '''
        num_devices = reduce(operator.mul, self.mesh_shape, 1)
        process_numbers_dict = {}
        process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist()
        process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist()
        return process_numbers_dict
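# Illustrative sketch (editor's example): verifying the column-major /
# row-major orderings documented above for a (2, 4) mesh.
import torch

mesh_shape = (2, 4)
ids = torch.arange(8).reshape(mesh_shape)
print(ids.transpose(1, 0).flatten().tolist())  # [0, 4, 1, 5, 2, 6, 3, 7]
print(ids.flatten().tolist())                  # [0, 1, 2, 3, 4, 5, 6, 7]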

    def mix_gather_cost(self, num_bytes):
        num_devices = reduce(operator.mul, self.mesh_shape, 1)
        return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1)
                (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)

@@ -1,6 +1,10 @@
import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import rerun_if_address_is_in_use, spawn


def test_device_mesh():

@@ -18,5 +22,70 @@ def test_device_mesh():
    assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3]


def check_1d_device_mesh():
    # check for 1D device mesh
    process_group = dist.GroupMember.WORLD
    device_mesh = DeviceMesh.from_process_group(process_group)

    # checks
    assert device_mesh.shape == [4]
    assert len(device_mesh.get_process_group_for_all_axes().keys()) == 1, 'Expected 1 axis for the process group dict'
    assert device_mesh.get_process_group(axis=0) == process_group, 'Expected world process group'
    assert device_mesh.is_initialized
    assert device_mesh.num_devices == 4
    assert device_mesh.is_initialized
    assert device_mesh.logical_mesh_id is None
    assert device_mesh._is_init_from_process_group


def check_2d_device_mesh():
    # create process groups for a 2D device mesh
    first_row_ranks = [0, 1]
    second_row_ranks = [2, 3]
    first_col_ranks = [0, 2]
    second_col_ranks = [1, 3]

    first_row_pg = dist.new_group(first_row_ranks, backend='nccl')
    second_row_pg = dist.new_group(second_row_ranks, backend='nccl')
    first_col_pg = dist.new_group(first_col_ranks, backend='nccl')
    second_col_pg = dist.new_group(second_col_ranks, backend='nccl')

    # select the row and column process groups for the current rank
    current_rank = dist.get_rank()

    if current_rank in first_row_ranks:
        row_pg = first_row_pg
    else:
        row_pg = second_row_pg

    if current_rank in first_col_ranks:
        col_pg = first_col_pg
    else:
        col_pg = second_col_pg

    device_mesh = DeviceMesh.from_process_group([col_pg, row_pg])

    # checks
    assert device_mesh.shape == [2, 2]
    assert len(device_mesh.get_process_group_for_all_axes().keys()) == 2, 'Expected 2 axes for the process group dict'
    assert device_mesh.get_process_group(axis=0) == col_pg, 'Expected column process group'
    assert device_mesh.get_process_group(axis=1) == row_pg, 'Expected row process group'
    assert device_mesh.num_devices == 4
    assert device_mesh.is_initialized
    assert device_mesh.logical_mesh_id is None
    assert device_mesh._is_init_from_process_group


def check_init_from_process_group(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_device_mesh_from_process_group():
    spawn(check_init_from_process_group, 4)


if __name__ == '__main__':
    test_device_mesh()
    test_device_mesh_from_process_group()