[device] support init device mesh from process group (#3990)

pull/4157/head
Frank Lee 2023-06-15 11:49:16 +08:00
parent a2f9af810d
commit 611971248c
2 changed files with 475 additions and 149 deletions

View File

@ -3,11 +3,19 @@
with some changes. """
import operator
from dataclasses import dataclass
from functools import reduce
from typing import List, Tuple
from typing import Dict, List, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@dataclass
class ProcessGroupContainer:
process_group: ProcessGroup
ranks: List[int]
# modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py)
@ -27,9 +35,11 @@ class DeviceMesh:
during initializing the DeviceMesh instance if the init_process_group set to True.
Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group.
(default: False)
need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True.
device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda')
"""
_DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
def __init__(self,
physical_mesh_id: torch.Tensor,
mesh_shape: torch.Size = None,
@ -37,160 +47,442 @@ class DeviceMesh:
mesh_alpha: List[float] = None,
mesh_beta: List[float] = None,
init_process_group: bool = False,
need_flatten: bool = True):
self.physical_mesh_id = physical_mesh_id
device: str = 'cuda'):
# ============================
# Physical & Logical Mesh IDs
# ============================
self._physical_mesh_id = physical_mesh_id
assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor."
# logical mesh ids can be obtained via two ways
# 1. provide physical mesh id and provide mesh shape
# 2. directly supply the logical mesh id
assert mesh_shape is None or logical_mesh_id is None, \
"Only one of mesh_shape and logical_mesh_id can be specified." \
"Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id"
if logical_mesh_id is None:
self.mesh_shape = mesh_shape
self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
self._mesh_shape = mesh_shape
self._logical_mesh_id = self._physical_mesh_id.reshape(self._mesh_shape)
else:
self._logical_mesh_id = logical_mesh_id
self.mesh_shape = self._logical_mesh_id.shape
self._mesh_shape = self._logical_mesh_id.shape
# map global rank into logical rank
self.convert_map = {}
self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
# ensure two things:
# 1. logical and physical mesh IDs should contain the same elements
# 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
"physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
"Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again."
assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
"Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
# ===============================================
# coefficient for alpha-beta communication model
# alpha is latency and beta is bandwidth
# ===============================================
# if the values are not provided, we assume they are 1 for simplicity
if mesh_alpha is None:
mesh_alpha = [1] * len(self.mesh_shape)
mesh_alpha = [1] * len(self._mesh_shape)
if mesh_beta is None:
mesh_beta = [1] * len(self.mesh_shape)
mesh_beta = [1] * len(self._mesh_shape)
self.mesh_alpha = tuple(mesh_alpha)
self.mesh_beta = tuple(mesh_beta)
self.init_process_group = init_process_group
self.need_flatten = need_flatten
if self.init_process_group:
self.process_groups_dict = self.create_process_groups_for_logical_mesh()
if self.need_flatten and self._logical_mesh_id.dim() > 1:
self.flatten_device_mesh = self.flatten()
# Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
# self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
# self.mesh_beta)
# ensure the alpha and beta have the same shape
assert len(self.mesh_alpha) == len(self.mesh_beta), \
"mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
# =========================
# Device for Process Group
# =========================
self._device = device
self._dist_backend = self._DIST_BACKEND[device]
# =========================
# Process Group Management
# =========================
# the _global_to_local_rank_mapping is structured as follows
# {
# <global-rank>: [ <local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...]
# }
self._global_to_local_rank_mapping = dict()
self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping,
tensor=self.logical_mesh_id)
# create process group
self._process_group_dict = {}
self._ranks_in_the_process_group = {}
self._global_rank_of_current_process = None
self._is_initialized = False
# attribute used to inidicate whether this objectd
# is created using DeviceMesh.from_process_group
# this attribute can be used to do some check in methods
# such get_process_group as no global rank information
# is known if created with from_process_group
self._is_init_from_process_group = False
# initialize process group if specified
self._init_ranks_in_the_same_group()
self._init_process_group = init_process_group
if init_process_group:
self.init_logical_process_group()
@property
def shape(self):
return self.mesh_shape
def shape(self) -> torch.Size:
"""
Return the shape of the logical mesh.
"""
return self._mesh_shape
@property
def num_devices(self):
return reduce(operator.mul, self.physical_mesh_id.shape, 1)
def num_devices(self) -> int:
"""
Return the number of devices contained in the device mesh.
"""
return reduce(operator.mul, self._physical_mesh_id.shape, 1)
@property
def logical_mesh_id(self):
def logical_mesh_id(self) -> torch.Tensor:
"""
Return the logical mesh id.
"""
return self._logical_mesh_id
def __deepcopy__(self, memo):
@property
def is_initialized(self) -> bool:
"""
Return whether the process group is initialized.
"""
return self._is_initialized
@staticmethod
def from_process_group(process_group: Union[ProcessGroup, List[ProcessGroup]]) -> "DeviceMesh":
"""
Create a DeviceMesh instance from the current process group. Please note that the DeviceMesh object created with this method
will not have information about the physical mesh id, and thus will not be able to query for other ranks and perform alpha-beta communication.
Args:
process_group (Union[ProcessGroup, List[ProcessGroup]]): the process group or a list of process groups for the device mesh.
If the input is a ProcessGroup object, a 1D DeviceMesh object will be created. If the input is a list of ProcessGroup objects,
the ProcessGroup at the ith index will correspond to the process group in the ith axis of the device mesh.
Returns:
DeviceMesh: the device mesh instance.
"""
def _get_device_by_backend(process_group):
"""
Get the device type given a process group's backend.
"""
backend = dist.get_backend(process_group)
for _device, _backend in DeviceMesh._DIST_BACKEND.items():
if _backend == backend:
return _device
return None
if isinstance(process_group, ProcessGroup):
process_group = [process_group]
# get mesh shape
mesh_shape = [dist.get_world_size(pg) for pg in process_group]
# get device
device_list = [_get_device_by_backend(pg) for pg in process_group]
# make sure all devices are the same
assert all([device == device_list[0] for device in device_list]), \
"All devices should be the same, please check your input process groups are created with the same distributed backend."
# create a fake physical mesh id
# as we only get the process group associated with the current process,
# we cannot get the global ranks for all processes in the mesh
# therefore, we only use this fake physical mesh id to create the device mesh
# and will remove this fake physical mesh id later
fake_physical_mesh_id = torch.arange(reduce(operator.mul, mesh_shape, 1))
# create the device mesh
device_mesh = DeviceMesh(physical_mesh_id=fake_physical_mesh_id, mesh_shape=mesh_shape, device=device_list[0])
# hack the device attribute
device_mesh._physical_mesh_id = None
device_mesh._logical_mesh_id = None
device_mesh._global_rank_of_current_process = dist.get_rank()
device_mesh._is_initialized = False
device_mesh._process_group_dict = {
device_mesh._global_rank_of_current_process: {axis: pg for axis, pg in enumerate(process_group)}
}
return device_mesh
def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup:
"""
Return the process group on the specified axis.
Args:
axis (int): the axis of the process group.
global_rank (int, optional): the global rank of the process group. If not specified, the current process is used. (default: None)
"""
if global_rank is None:
global_rank = self._global_rank_of_current_process
elif self._is_init_from_process_group:
raise RuntimeError(
"The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
)
return self._process_group_dict[global_rank][axis]
def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]:
"""
Return the process groups for all axes.
Args:
global_rank (int, optional): the global rank of the process
"""
if global_rank is None:
global_rank = self._global_rank_of_current_process
elif self._is_init_from_process_group:
raise RuntimeError(
"The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
)
return self._process_group_dict[global_rank]
def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]:
"""
Return the ranks in the process group on the specified axis.
Args:
axis (int): the axis of the process group.
global_rank (int, optional): the global rank of the process
"""
if global_rank is None:
global_rank = self._global_rank_of_current_process
elif self._is_init_from_process_group:
raise RuntimeError(
"The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
)
return self._ranks_in_the_process_group[global_rank][axis]
def __deepcopy__(self, memo) -> "DeviceMesh":
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k != 'process_groups_dict':
if k != '_process_group_dict':
setattr(result, k, __import__("copy").deepcopy(v, memo))
else:
# process group cannot be copied
# thus, we share them directly
setattr(result, k, v)
return result
def flatten(self):
def _init_global_to_logical_rank_mapping(self,
mapping: Dict,
tensor: torch.Tensor,
index_list: List[int] = []) -> Dict[int, List[int]]:
"""
Flatten the logical mesh into an effective 1d logical mesh,
"""
flatten_mesh_shape_size = len(self.mesh_shape)
flatten_mesh_shape = [self.num_devices]
return DeviceMesh(self.physical_mesh_id,
tuple(flatten_mesh_shape),
mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
init_process_group=self.init_process_group,
need_flatten=False)
Build a global rank to local rank mapping for each process group in different axis in the logical device mesh.
def _global_rank_to_logical_rank_map(self, tensor, index_list):
'''
This method is a helper function to build convert_map recursively.
'''
Args:
mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
tensor (torch.Tensor): the tensor that contains the logical mesh ids.
index_list (List[int])
Returns:
mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
The value is a list of integers and each integer represents the local rank in the indexed axis.
"""
for index, inner_tensor in enumerate(tensor):
if inner_tensor.numel() == 1:
self.convert_map[int(inner_tensor)] = index_list + [index]
else:
self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index])
# index means the local rank in the current axis
# inner_tensor refers to the processes with the same local rank
def create_process_groups_for_logical_mesh(self):
if inner_tensor.numel() == 1:
# if the inner_tensor only has one element, it means that
# it already reaches the last axis
# we append its local_rank in the last axis to the index_list
# and assign to the mapping
# the value of the mapping is the the local rank at the indexed axis of the device mesh
mapping[int(inner_tensor)] = index_list + [index]
else:
# we recursively go into the function until we reach the last axis
# meanwhile, we should add the local rank in the current axis in the index_list
self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index])
def init_logical_process_group(self):
'''
This method is used to initialize the logical process groups which will be used in communications
among logical device mesh.
Note: if init_process_group set to False, you have to call this method manually. Otherwise,
the communication related function, such as ShapeConsistencyManager.apply will raise errors.
'''
process_groups_dict = {}
check_duplicate_list = []
global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist()
# sanity check
assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group"
assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice"
# update the global rank of the current process
self._global_rank_of_current_process = dist.get_rank()
duplicate_check_list = []
# flatten the global ranks to 1D list
global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()
for global_rank in global_rank_flatten_list:
process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank)
for axis, process_group in process_groups.items():
if axis not in process_groups_dict:
process_groups_dict[axis] = []
if process_group not in check_duplicate_list:
check_duplicate_list.append(process_group)
process_group_handler = dist.new_group(process_group)
process_groups_dict[axis].append((process_group, process_group_handler))
# find the other ranks which are in the same process group as global_rank
ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)
return process_groups_dict
for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
# skip duplicated process group creation
if ranks_in_same_group in duplicate_check_list:
continue
def global_rank_to_logical_rank(self, rank):
return self.convert_map[rank]
# create the process group
pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend)
def global_rank_to_process_groups_with_logical_rank(self, rank):
'''
Give a global rank and return all logical process groups of this rank.
for example:
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
print(device_mesh.global_rank_to_process_groups_with_logical_rank(0))
output:
# key is axis name
# value is a list of logical ranks in same axis with rank 0
{0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]}
'''
process_groups = {}
for d in range(self.logical_mesh_id.dim()):
for replacer in range(self.logical_mesh_id.shape[d]):
if d not in process_groups:
process_groups[d] = []
process_group_member = self.convert_map[rank].copy()
process_group_member[d] = replacer
process_groups[d].append(process_group_member)
return process_groups
# keep this process group in the process_groups_dict
for rank in ranks_in_same_group:
if rank not in self._process_group_dict:
self._process_group_dict[rank] = dict()
self._process_group_dict[rank][axis] = pg_handler
def global_rank_to_process_groups_with_global_rank(self, rank):
# update the init flag
# we only allow init for once
self._is_initialized = True
def _init_ranks_in_the_same_group(self):
"""
This method is used to initialize the ranks_in_the_same_group dictionary.
"""
# flatten the global ranks to 1D list
global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()
for global_rank in global_rank_flatten_list:
# find the other ranks which are in the same process group as global_rank
ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)
for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
# create dict for each rank
if global_rank not in self._process_group_dict:
self._ranks_in_the_process_group[global_rank] = dict()
# keep this process group in the process_groups_dict
self._ranks_in_the_process_group[global_rank][axis] = ranks_in_same_group
def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]:
"""
Return the local rank of the given global rank in the logical device mesh.
Args:
rank (int): the global rank in the logical device mesh.
axis (int): the axis of the logical device mesh.
"""
if self._is_init_from_process_group:
raise RuntimeError(
"The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
)
local_ranks = self._global_to_local_rank_mapping[rank]
if axis:
return local_ranks[axis]
else:
return local_ranks
def _collate_global_ranks_in_same_process_group(self, global_rank):
'''
Give a global rank and return all process groups of this rank.
for example:
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
print(device_mesh.global_rank_to_process_groups_with_global_rank(0))
output:
# key is axis name
# value is a list of global ranks in same axis with rank 0
{0: [0, 4, 8, 12], 1: [0, 1, 2, 3]}
Give a global rank and return all global ranks involved in its associated process group in each axis.
Example:
```python
sphysical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
# logical mesh will look like
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
print(device_mesh.collate_global_ranks_in_same_process_group(0))
# key is axis name
# value is a list of global ranks in same axis with rank 0
# output will look like
# {
0: [0, 4, 8, 12],
1: [0, 1, 2, 3]
# }
'''
logical_process_groups = self.global_rank_to_process_groups_with_logical_rank(rank)
process_groups = {}
for dim, logical_ranks in logical_process_groups.items():
process_groups[dim] = []
for logical_rank in logical_ranks:
for g_rank, l_rank in self.convert_map.items():
if l_rank == logical_rank:
process_groups[dim].append(g_rank)
return process_groups
# We have init the global rank to local rank by calling _init_global_to_logical_rank_mapping
# for self._global_to_local_rank_mapping
# the key is the global rank
# the value is the list of local ranks corresponding to the global rank with respect of different axes
# we can see the list of local ranks as the process coordinates for simplicity
# the key and value are all unique, therefore,
# we can also to use the coordinates to find the global rank
# =========================================================================
# Step 1
# find all the process_coordinates for processes in the same process group
# as the given global rank
# =========================================================================
# each
processes_in_the_same_process_group = {}
for dim in range(self.logical_mesh_id.dim()):
# iterate over the dimension size so that we can include all processes
# in the same process group in the given axis
# the _local_rank refers to the local rank of the current process
for _local_rank in range(self.logical_mesh_id.shape[dim]):
# if this dimension is not initailized yet,
# initialize it with an empty array
if dim not in processes_in_the_same_process_group:
processes_in_the_same_process_group[dim] = []
# get the local rank corresponding to the global rank
process_coordinates = self._global_to_local_rank_mapping[global_rank].copy()
# replace the local rank in the given dimension with the
# lcoal rank of the current process iterated
process_coordinates[dim] = _local_rank
processes_in_the_same_process_group[dim].append(process_coordinates)
# =================================================================
# Step 2
# Use local rank combination to find its corresponding global rank
# =================================================================
# the key of the dict is the axis
# the value is the list of global ranks which are in the same process group as the given global rank
global_pg_ranks = {}
for dim, coordinates_of_all_processes in processes_in_the_same_process_group.items():
global_pg_ranks[dim] = []
for process_coordinates in coordinates_of_all_processes:
# find the global rank by local rank combination
for _global_rank, _process_coordinates in self._global_to_local_rank_mapping.items():
if process_coordinates == _process_coordinates:
global_pg_ranks[dim].append(_global_rank)
return global_pg_ranks
def flatten(self):
"""
Flatten the logical mesh into an effective 1d logical mesh,
"""
if self._is_init_from_process_group:
raise RuntimeError(
"The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
)
flatten_mesh_shape_size = len(self._mesh_shape)
flatten_mesh_shape = [self.num_devices]
return DeviceMesh(self._physical_mesh_id,
tuple(flatten_mesh_shape),
mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
init_process_group=self._init_process_group)
def all_gather_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
@ -211,39 +503,4 @@ class DeviceMesh:
num_devices = self.logical_mesh_id.shape[mesh_dim]
penalty_factor = num_devices / 2.0
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
(num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
class FlattenDeviceMesh(DeviceMesh):
def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
super().__init__(physical_mesh_id,
mesh_shape,
mesh_alpha,
mesh_beta,
init_process_group=False,
need_flatten=False)
# Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars
self.mesh_alpha = max(self.mesh_alpha)
self.mesh_beta = min(self.mesh_beta)
# Different from original process_groups_dict, rank_list is not stored
self.process_number_dict = self.create_process_numbers_for_logical_mesh()
def create_process_numbers_for_logical_mesh(self):
'''
Build 1d DeviceMesh in column-major(0) and row-major(1)
for example:
mesh_shape = (2,4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7]]
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
'''
num_devices = reduce(operator.mul, self.mesh_shape, 1)
process_numbers_dict = {}
process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist()
process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist()
return process_numbers_dict
def mix_gather_cost(self, num_bytes):
num_devices = reduce(operator.mul, self.mesh_shape, 1)
return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1)
(num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)

View File

@ -1,6 +1,10 @@
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import rerun_if_address_is_in_use, spawn
def test_device_mesh():
@ -18,5 +22,70 @@ def test_device_mesh():
assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3]
def check_1d_device_mesh():
# check for 1D device mesh
process_group = dist.GroupMember.WORLD
device_mesh = DeviceMesh.from_process_group(process_group)
# checks
assert device_mesh.shape == [4]
assert len(device_mesh.get_process_group_for_all_axes().keys()) == 1, 'Expected 1 axis for the process group dict'
assert device_mesh.get_process_group(axis=0) == process_group, 'Expected world process group'
assert device_mesh.is_initialized
assert device_mesh.num_devices == 4
assert device_mesh.is_initialized
assert device_mesh.logical_mesh_id is None
assert device_mesh._is_init_from_process_group
def check_2d_device_mesh():
# create process group for 2D device mesh
first_row_ranks = [0, 1]
second_row_ranks = [2, 3]
first_col_ranks = [0, 2]
second_col_ranks = [1, 3]
first_row_pg = dist.new_group(first_row_ranks, backend='nccl')
second_row_pg = dist.new_group(second_row_ranks, backend='nccl')
first_col_pg = dist.new_group(first_col_ranks, backend='nccl')
second_col_pg = dist.new_group(second_col_ranks, backend='nccl')
# check for
current_rank = dist.get_rank()
if current_rank in first_row_ranks:
row_pg = first_row_pg
else:
row_pg = second_row_pg
if current_rank in first_col_ranks:
col_pg = first_col_pg
else:
col_pg = second_col_pg
device_mesh = DeviceMesh.from_process_group([col_pg, row_pg])
# checks
assert device_mesh.shape == [2, 2]
assert len(device_mesh.get_process_group_for_all_axes().keys()) == 2, 'Expected 2 axes for the process group dict'
assert device_mesh.get_process_group(axis=0) == col_pg, 'Expected column process group'
assert device_mesh.get_process_group(axis=1) == row_pg, 'Expected row process group'
assert device_mesh.num_devices == 4
assert device_mesh.is_initialized
assert device_mesh.logical_mesh_id is None
assert device_mesh._is_init_from_process_group
def check_init_from_process_group(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_device_mesh_from_process_group():
spawn(check_init_from_process_group, 4)
if __name__ == '__main__':
test_device_mesh()
test_device_mesh_from_process_group()