diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py
index 2a5f747fb..3e96310e1 100644
--- a/colossalai/device/device_mesh.py
+++ b/colossalai/device/device_mesh.py
@@ -3,11 +3,19 @@ with some changes.
 """
 import operator
+from dataclasses import dataclass
 from functools import reduce
-from typing import List, Tuple
+from typing import Dict, List, Union
 
 import torch
 import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+
+@dataclass
+class ProcessGroupContainer:
+    process_group: ProcessGroup
+    ranks: List[int]
 
 
 # modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py)
@@ -27,9 +35,11 @@ class DeviceMesh:
             during initializing the DeviceMesh instance if the init_process_group set to True.
             Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group.
             (default: False)
-        need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True.
+        device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda')
     """
 
+    _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
+
     def __init__(self,
                  physical_mesh_id: torch.Tensor,
                  mesh_shape: torch.Size = None,
@@ -37,160 +47,442 @@ class DeviceMesh:
                  mesh_alpha: List[float] = None,
                  mesh_beta: List[float] = None,
                  init_process_group: bool = False,
-                 need_flatten: bool = True):
-        self.physical_mesh_id = physical_mesh_id
+                 device: str = 'cuda'):
+        # ============================
+        # Physical & Logical Mesh IDs
+        # ============================
+        self._physical_mesh_id = physical_mesh_id
+        assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor."
+
+        # logical mesh ids can be obtained via two ways
+        # 1. provide physical mesh id and provide mesh shape
+        # 2. directly supply the logical mesh id
+        assert mesh_shape is None or logical_mesh_id is None, \
+            "Only one of mesh_shape and logical_mesh_id can be specified. " \
+            "Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id."
+
         if logical_mesh_id is None:
-            self.mesh_shape = mesh_shape
-            self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
+            self._mesh_shape = mesh_shape
+            self._logical_mesh_id = self._physical_mesh_id.reshape(self._mesh_shape)
         else:
             self._logical_mesh_id = logical_mesh_id
-            self.mesh_shape = self._logical_mesh_id.shape
+            self._mesh_shape = self._logical_mesh_id.shape
 
-        # map global rank into logical rank
-        self.convert_map = {}
-        self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
+        # ensure two things:
+        # 1. logical and physical mesh IDs should contain the same elements
+        # 2. there are no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
+        assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
+            "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
+        assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
+            "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
+        assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
+            "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
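+        # e.g. physical_mesh_id = torch.arange(0, 4) with mesh_shape = (2, 2)
+        # yields the logical mesh [[0, 1], [2, 3]]; supplying both mesh_shape and
+        # logical_mesh_id, or duplicate IDs such as [0, 0, 1, 1], fails the checks above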
+
+        # ===============================================
         # coefficient for alpha-beta communication model
+        # alpha is latency and beta is bandwidth
+        # ===============================================
+        # if the values are not provided, we assume they are 1 for simplicity
         if mesh_alpha is None:
-            mesh_alpha = [1] * len(self.mesh_shape)
+            mesh_alpha = [1] * len(self._mesh_shape)
         if mesh_beta is None:
-            mesh_beta = [1] * len(self.mesh_shape)
+            mesh_beta = [1] * len(self._mesh_shape)
+
         self.mesh_alpha = tuple(mesh_alpha)
         self.mesh_beta = tuple(mesh_beta)
-        self.init_process_group = init_process_group
-        self.need_flatten = need_flatten
-        if self.init_process_group:
-            self.process_groups_dict = self.create_process_groups_for_logical_mesh()
-        if self.need_flatten and self._logical_mesh_id.dim() > 1:
-            self.flatten_device_mesh = self.flatten()
-            # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
-            # self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
-            #                                                self.mesh_beta)
+
+        # ensure the alpha and beta have the same length
+        assert len(self.mesh_alpha) == len(self.mesh_beta), \
+            "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
+
+        # =========================
+        # Device for Process Group
+        # =========================
+        self._device = device
+        self._dist_backend = self._DIST_BACKEND[device]
+
+        # =========================
+        # Process Group Management
+        # =========================
+        # the _global_to_local_rank_mapping is structured as follows
+        # {
+        #    <global-rank>: [<local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...]
+        # }
+        self._global_to_local_rank_mapping = dict()
+        self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping,
+                                                  tensor=self.logical_mesh_id)
+
+        # create process group
+        self._process_group_dict = {}
+        self._ranks_in_the_process_group = {}
+        self._global_rank_of_current_process = None
+        self._is_initialized = False
+
+        # attribute used to indicate whether this object
+        # is created using DeviceMesh.from_process_group
+        # this attribute can be used to do some check in methods
+        # such as get_process_group, as no global rank information
+        # is known if created with from_process_group
+        self._is_init_from_process_group = False
+
+        # initialize process group if specified
+        self._init_ranks_in_the_same_group()
+        self._init_process_group = init_process_group
+        if init_process_group:
+            self.init_logical_process_group()
 
     @property
-    def shape(self):
-        return self.mesh_shape
+    def shape(self) -> torch.Size:
+        """
+        Return the shape of the logical mesh.
+        """
+        return self._mesh_shape
 
     @property
-    def num_devices(self):
-        return reduce(operator.mul, self.physical_mesh_id.shape, 1)
+    def num_devices(self) -> int:
+        """
+        Return the number of devices contained in the device mesh.
+        """
+        return reduce(operator.mul, self._physical_mesh_id.shape, 1)
 
     @property
-    def logical_mesh_id(self):
+    def logical_mesh_id(self) -> torch.Tensor:
+        """
+        Return the logical mesh id.
+        """
         return self._logical_mesh_id
 
-    def __deepcopy__(self, memo):
+    @property
+    def is_initialized(self) -> bool:
+        """
+        Return whether the process group is initialized.
+        """
+        return self._is_initialized
+
+    @staticmethod
+    def from_process_group(process_group: Union[ProcessGroup, List[ProcessGroup]]) -> "DeviceMesh":
+        """
+        Create a DeviceMesh instance from the current process group.
+        Please note that the DeviceMesh object created with this method
+        will not have information about the physical mesh id, and thus will not be able to query for other ranks and perform alpha-beta communication.
+
+        Args:
+            process_group (Union[ProcessGroup, List[ProcessGroup]]): the process group or a list of process groups for the device mesh.
+                If the input is a ProcessGroup object, a 1D DeviceMesh object will be created. If the input is a list of ProcessGroup objects,
+                the ProcessGroup at the ith index will correspond to the process group in the ith axis of the device mesh.
+
+        Returns:
+            DeviceMesh: the device mesh instance.
+        """
+
+        def _get_device_by_backend(process_group):
+            """
+            Get the device type given a process group's backend.
+            """
+            backend = dist.get_backend(process_group)
+            for _device, _backend in DeviceMesh._DIST_BACKEND.items():
+                if _backend == backend:
+                    return _device
+            return None
+
+        if isinstance(process_group, ProcessGroup):
+            process_group = [process_group]
+
+        # get mesh shape
+        mesh_shape = [dist.get_world_size(pg) for pg in process_group]
+
+        # get device
+        device_list = [_get_device_by_backend(pg) for pg in process_group]
+
+        # make sure all devices are the same
+        assert all([device == device_list[0] for device in device_list]), \
+            "All devices should be the same, please check your input process groups are created with the same distributed backend."
+
+        # create a fake physical mesh id
+        # as we only get the process group associated with the current process,
+        # we cannot get the global ranks for all processes in the mesh
+        # therefore, we only use this fake physical mesh id to create the device mesh
+        # and will remove this fake physical mesh id later
+        fake_physical_mesh_id = torch.arange(reduce(operator.mul, mesh_shape, 1))
+
+        # create the device mesh
+        device_mesh = DeviceMesh(physical_mesh_id=fake_physical_mesh_id, mesh_shape=mesh_shape, device=device_list[0])
+
+        # hack the device attribute
+        device_mesh._physical_mesh_id = None
+        device_mesh._logical_mesh_id = None
+        device_mesh._global_rank_of_current_process = dist.get_rank()
+        device_mesh._is_initialized = True
+        device_mesh._is_init_from_process_group = True
+        device_mesh._process_group_dict = {
+            device_mesh._global_rank_of_current_process: {axis: pg for axis, pg in enumerate(process_group)}
+        }
+
+        return device_mesh
+
+    def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup:
+        """
+        Return the process group on the specified axis.
+
+        Args:
+            axis (int): the axis of the process group.
+            global_rank (int, optional): the global rank of the process. If not specified, the current process is used. (default: None)
+        """
+        if global_rank is None:
+            global_rank = self._global_rank_of_current_process
+        elif self._is_init_from_process_group:
+            raise RuntimeError(
+                "The logical device mesh is created with DeviceMesh.from_process_group; this method is not supported for this creation method as no global rank information is known."
+            )
+        return self._process_group_dict[global_rank][axis]
+
+    def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]:
+        """
+        Return the process groups for all axes.
+
+        Args:
+            global_rank (int, optional): the global rank of the process. If not specified, the current process is used. (default: None)
+        """
+        if global_rank is None:
+            global_rank = self._global_rank_of_current_process
+        elif self._is_init_from_process_group:
+            raise RuntimeError(
+                "The logical device mesh is created with DeviceMesh.from_process_group; this method is not supported for this creation method as no global rank information is known."
+            )
+        return self._process_group_dict[global_rank]
+
+    def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]:
+        """
+        Return the ranks in the process group on the specified axis.
+
+        Args:
+            axis (int): the axis of the process group.
+            global_rank (int, optional): the global rank of the process. If not specified, the current process is used. (default: None)
+        """
+        if global_rank is None:
+            global_rank = self._global_rank_of_current_process
+        elif self._is_init_from_process_group:
+            raise RuntimeError(
+                "The logical device mesh is created with DeviceMesh.from_process_group; this method is not supported for this creation method as no global rank information is known."
+            )
+        return self._ranks_in_the_process_group[global_rank][axis]
+
+    def __deepcopy__(self, memo) -> "DeviceMesh":
         cls = self.__class__
         result = cls.__new__(cls)
         memo[id(self)] = result
         for k, v in self.__dict__.items():
-            if k != 'process_groups_dict':
+            if k != '_process_group_dict':
                 setattr(result, k, __import__("copy").deepcopy(v, memo))
             else:
+                # process group cannot be copied
+                # thus, we share them directly
                 setattr(result, k, v)
-
         return result
 
-    def flatten(self):
+    def _init_global_to_logical_rank_mapping(self,
+                                             mapping: Dict,
+                                             tensor: torch.Tensor,
+                                             index_list: List[int] = []) -> Dict[int, List[int]]:
         """
-        Flatten the logical mesh into an effective 1d logical mesh,
-        """
-        flatten_mesh_shape_size = len(self.mesh_shape)
-        flatten_mesh_shape = [self.num_devices]
-        return DeviceMesh(self.physical_mesh_id,
-                          tuple(flatten_mesh_shape),
-                          mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
-                          mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
-                          init_process_group=self.init_process_group,
-                          need_flatten=False)
+        Build a global rank to local rank mapping for each process group in different axes of the logical device mesh.
 
-    def _global_rank_to_logical_rank_map(self, tensor, index_list):
-        '''
-        This method is a helper function to build convert_map recursively.
-        '''
+        Args:
+            mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
+            tensor (torch.Tensor): the tensor that contains the logical mesh ids.
+            index_list (List[int]): a list containing the local ranks indexed so far during the recursion. (default: [])
+
+        Returns:
+            mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
+                The value is a list of integers and each integer represents the local rank in the indexed axis.
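+
+        Example: for a 2 x 2 logical mesh [[0, 1], [2, 3]], the resulting mapping is
+        {0: [0, 0], 1: [0, 1], 2: [1, 0], 3: [1, 1]}.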
+ """ for index, inner_tensor in enumerate(tensor): - if inner_tensor.numel() == 1: - self.convert_map[int(inner_tensor)] = index_list + [index] - else: - self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index]) + # index means the local rank in the current axis + # inner_tensor refers to the processes with the same local rank - def create_process_groups_for_logical_mesh(self): + if inner_tensor.numel() == 1: + # if the inner_tensor only has one element, it means that + # it already reaches the last axis + # we append its local_rank in the last axis to the index_list + # and assign to the mapping + # the value of the mapping is the the local rank at the indexed axis of the device mesh + mapping[int(inner_tensor)] = index_list + [index] + else: + # we recursively go into the function until we reach the last axis + # meanwhile, we should add the local rank in the current axis in the index_list + self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index]) + + def init_logical_process_group(self): ''' This method is used to initialize the logical process groups which will be used in communications among logical device mesh. Note: if init_process_group set to False, you have to call this method manually. Otherwise, the communication related function, such as ShapeConsistencyManager.apply will raise errors. ''' - process_groups_dict = {} - check_duplicate_list = [] - global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist() + # sanity check + assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group" + assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice" + + # update the global rank of the current process + self._global_rank_of_current_process = dist.get_rank() + duplicate_check_list = [] + + # flatten the global ranks to 1D list + global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist() + for global_rank in global_rank_flatten_list: - process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank) - for axis, process_group in process_groups.items(): - if axis not in process_groups_dict: - process_groups_dict[axis] = [] - if process_group not in check_duplicate_list: - check_duplicate_list.append(process_group) - process_group_handler = dist.new_group(process_group) - process_groups_dict[axis].append((process_group, process_group_handler)) + # find the other ranks which are in the same process group as global_rank + ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank) - return process_groups_dict + for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items(): + # skip duplicated process group creation + if ranks_in_same_group in duplicate_check_list: + continue - def global_rank_to_logical_rank(self, rank): - return self.convert_map[rank] + # create the process group + pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend) - def global_rank_to_process_groups_with_logical_rank(self, rank): - ''' - Give a global rank and return all logical process groups of this rank. 
-        for example:
-            physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
-            mesh_shape = (4, 4)
-            # [[0, 1, 2, 3],
-            #  [4, 5, 6, 7],
-            #  [8, 9, 10,11],
-            #  [12,13,14,15]]
-            device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-            print(device_mesh.global_rank_to_process_groups_with_logical_rank(0))
-        output:
-            # key is axis name
-            # value is a list of logical ranks in same axis with rank 0
-            {0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]}
-        '''
-        process_groups = {}
-        for d in range(self.logical_mesh_id.dim()):
-            for replacer in range(self.logical_mesh_id.shape[d]):
-                if d not in process_groups:
-                    process_groups[d] = []
-                process_group_member = self.convert_map[rank].copy()
-                process_group_member[d] = replacer
-                process_groups[d].append(process_group_member)
-        return process_groups
+                # keep this process group in the process_groups_dict
+                for rank in ranks_in_same_group:
+                    if rank not in self._process_group_dict:
+                        self._process_group_dict[rank] = dict()
+                    self._process_group_dict[rank][axis] = pg_handler
 
-    def global_rank_to_process_groups_with_global_rank(self, rank):
+        # update the init flag
+        # we only allow init for once
+        self._is_initialized = True
+
+    def _init_ranks_in_the_same_group(self):
+        """
+        This method is used to initialize the ranks_in_the_same_group dictionary.
+        """
+        # flatten the global ranks to 1D list
+        global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()
+
+        for global_rank in global_rank_flatten_list:
+            # find the other ranks which are in the same process group as global_rank
+            ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)
+
+            for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
+                # create dict for each rank
+                if global_rank not in self._ranks_in_the_process_group:
+                    self._ranks_in_the_process_group[global_rank] = dict()
+
+                # keep this process group in the process_groups_dict
+                self._ranks_in_the_process_group[global_rank][axis] = ranks_in_same_group
+
+    def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]:
+        """
+        Return the local rank of the given global rank in the logical device mesh.
+
+        Args:
+            rank (int): the global rank in the logical device mesh.
+            axis (int): the axis of the logical device mesh.
+        """
+        if self._is_init_from_process_group:
+            raise RuntimeError(
+                "The logical device mesh is created with DeviceMesh.from_process_group; this method is not supported for this creation method as no global rank information is known."
+            )
+
+        local_ranks = self._global_to_local_rank_mapping[rank]
+        if axis is not None:
+            return local_ranks[axis]
+        else:
+            return local_ranks
+
+    def _collate_global_ranks_in_same_process_group(self, global_rank):
         '''
-        Give a global rank and return all process groups of this rank.
-        for example:
-            physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
-            mesh_shape = (4, 4)
-            # [[0, 1, 2, 3],
-            #  [4, 5, 6, 7],
-            #  [8, 9, 10,11],
-            #  [12,13,14,15]]
-            device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-            print(device_mesh.global_rank_to_process_groups_with_global_rank(0))
-        output:
-            # key is axis name
-            # value is a list of global ranks in same axis with rank 0
-            {0: [0, 4, 8, 12], 1: [0, 1, 2, 3]}
+        Give a global rank and return all global ranks involved in its associated process group in each axis.
+
+        Example:
+
+        ```python
+        physical_mesh_id = torch.arange(0, 16)
+        mesh_shape = (4, 4)
+
+        # logical mesh will look like
+        # [[0, 1, 2, 3],
+        #  [4, 5, 6, 7],
+        #  [8, 9, 10,11],
+        #  [12,13,14,15]]
+
+        device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+        print(device_mesh._collate_global_ranks_in_same_process_group(0))
+
+        # key is axis name
+        # value is a list of global ranks in same axis with rank 0
+        # output will look like
+        # {
+        #     0: [0, 4, 8, 12],
+        #     1: [0, 1, 2, 3]
+        # }
+        ```
+        '''
+        # We have initialized the global-rank-to-local-rank mapping by calling
+        # _init_global_to_logical_rank_mapping to fill self._global_to_local_rank_mapping
+        # the key is the global rank
+        # the value is the list of local ranks corresponding to the global rank with respect to different axes
+        # we can see the list of local ranks as the process coordinates for simplicity
+        # the key and value are all unique, therefore,
+        # we can also use the coordinates to find the global rank
+
+        # =========================================================================
+        # Step 1
+        # find all the process_coordinates for processes in the same process group
+        # as the given global rank
+        # =========================================================================
+
+        # each key is an axis of the logical device mesh
+        # each value is a list of process coordinates in the same process group along that axis
+        processes_in_the_same_process_group = {}
+
+        for dim in range(self.logical_mesh_id.dim()):
+            # iterate over the dimension size so that we can include all processes
+            # in the same process group in the given axis
+            # the _local_rank refers to the local rank of the current process
+            for _local_rank in range(self.logical_mesh_id.shape[dim]):
+
+                # if this dimension is not initialized yet,
+                # initialize it with an empty array
+                if dim not in processes_in_the_same_process_group:
+                    processes_in_the_same_process_group[dim] = []
+
+                # get the local rank corresponding to the global rank
+                process_coordinates = self._global_to_local_rank_mapping[global_rank].copy()
+
+                # replace the local rank in the given dimension with the
+                # local rank of the current process iterated
+                process_coordinates[dim] = _local_rank
+                processes_in_the_same_process_group[dim].append(process_coordinates)
+
+        # =================================================================
+        # Step 2
+        # Use local rank combination to find its corresponding global rank
+        # =================================================================
+        # the key of the dict is the axis
+        # the value is the list of global ranks which are in the same process group as the given global rank
+        global_pg_ranks = {}
+        for dim, coordinates_of_all_processes in processes_in_the_same_process_group.items():
+            global_pg_ranks[dim] = []
+            for process_coordinates in coordinates_of_all_processes:
+                # find the global rank by local rank combination
+                for _global_rank, _process_coordinates in self._global_to_local_rank_mapping.items():
+                    if process_coordinates == _process_coordinates:
+                        global_pg_ranks[dim].append(_global_rank)
+        return global_pg_ranks
+
+    def flatten(self):
+        """
+        Flatten the logical mesh into an effective 1d logical mesh.
+        """
+        if self._is_init_from_process_group:
+            raise RuntimeError(
+                "The logical device mesh is created with DeviceMesh.from_process_group; this method is not supported for this creation method as no global rank information is known."
+            )
+
+        flatten_mesh_shape_size = len(self._mesh_shape)
+        flatten_mesh_shape = [self.num_devices]
+        return DeviceMesh(self._physical_mesh_id,
+                          tuple(flatten_mesh_shape),
+                          mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
+                          mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+                          init_process_group=self._init_process_group)
 
     def all_gather_cost(self, num_bytes, mesh_dim):
         num_devices = self.logical_mesh_id.shape[mesh_dim]
@@ -211,39 +503,4 @@ class DeviceMesh:
         num_devices = self.logical_mesh_id.shape[mesh_dim]
         penalty_factor = num_devices / 2.0
         return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
-                (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
-
-
-class FlattenDeviceMesh(DeviceMesh):
-
-    def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
-        super().__init__(physical_mesh_id,
-                         mesh_shape,
-                         mesh_alpha,
-                         mesh_beta,
-                         init_process_group=False,
-                         need_flatten=False)
-        # Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars
-        self.mesh_alpha = max(self.mesh_alpha)
-        self.mesh_beta = min(self.mesh_beta)
-        # Different from original process_groups_dict, rank_list is not stored
-        self.process_number_dict = self.create_process_numbers_for_logical_mesh()
-
-    def create_process_numbers_for_logical_mesh(self):
-        '''
-        Build 1d DeviceMesh in column-major(0) and row-major(1)
-        for example:
-            mesh_shape = (2,4)
-            # [[0, 1, 2, 3],
-            #  [4, 5, 6, 7]]
-            # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
-        '''
-        num_devices = reduce(operator.mul, self.mesh_shape, 1)
-        process_numbers_dict = {}
-        process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist()
-        process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist()
-        return process_numbers_dict
-
-    def mix_gather_cost(self, num_bytes):
-        num_devices = reduce(operator.mul, self.mesh_shape, 1)
-        return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1)
+                (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
\ No newline at end of file
diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py
index 789ce8ab3..e9f0f9477 100644
--- a/tests/test_device/test_device_mesh.py
+++ b/tests/test_device/test_device_mesh.py
@@ -1,6 +1,10 @@
+import pytest
 import torch
+import torch.distributed as dist
 
+import colossalai
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 
 def test_device_mesh():
@@ -18,5 +22,70 @@ def test_device_mesh():
     assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3]
 
 
+def check_1d_device_mesh():
+    # check for 1D device mesh
+    process_group = dist.GroupMember.WORLD
+    device_mesh = DeviceMesh.from_process_group(process_group)
+
+    # checks
+    assert device_mesh.shape == [4]
+    assert len(device_mesh.get_process_group_for_all_axes().keys()) == 1, 'Expected 1 axis for the process group dict'
+    assert device_mesh.get_process_group(axis=0) == process_group, 'Expected world process group'
+    assert device_mesh.is_initialized
+    assert device_mesh.num_devices == 4
+    assert device_mesh.logical_mesh_id is None
+    assert device_mesh._is_init_from_process_group
+
+
+def check_2d_device_mesh():
+    # create process group for 2D device mesh
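+    # the logical 2D mesh formed by these groups is expected to look like
+    # [[0, 1],
+    #  [2, 3]]
+    # where the column groups are passed as axis 0 to from_process_group below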
+    first_row_ranks = [0, 1]
+    second_row_ranks = [2, 3]
+    first_col_ranks = [0, 2]
+    second_col_ranks = [1, 3]
+
+    first_row_pg = dist.new_group(first_row_ranks, backend='nccl')
+    second_row_pg = dist.new_group(second_row_ranks, backend='nccl')
+    first_col_pg = dist.new_group(first_col_ranks, backend='nccl')
+    second_col_pg = dist.new_group(second_col_ranks, backend='nccl')
+
+    # pick the process groups which contain the current rank
+    current_rank = dist.get_rank()
+
+    if current_rank in first_row_ranks:
+        row_pg = first_row_pg
+    else:
+        row_pg = second_row_pg
+
+    if current_rank in first_col_ranks:
+        col_pg = first_col_pg
+    else:
+        col_pg = second_col_pg
+
+    device_mesh = DeviceMesh.from_process_group([col_pg, row_pg])
+
+    # checks
+    assert device_mesh.shape == [2, 2]
+    assert len(device_mesh.get_process_group_for_all_axes().keys()) == 2, 'Expected 2 axes for the process group dict'
+    assert device_mesh.get_process_group(axis=0) == col_pg, 'Expected column process group'
+    assert device_mesh.get_process_group(axis=1) == row_pg, 'Expected row process group'
+    assert device_mesh.num_devices == 4
+    assert device_mesh.is_initialized
+    assert device_mesh.logical_mesh_id is None
+    assert device_mesh._is_init_from_process_group
+
+
+def check_init_from_process_group(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_1d_device_mesh()
+    check_2d_device_mesh()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_device_mesh_from_process_group():
+    spawn(check_init_from_process_group, 4)
+
+
 if __name__ == '__main__':
     test_device_mesh()
+    test_device_mesh_from_process_group()
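For reference, a minimal usage sketch of the upgraded `DeviceMesh` API (not part of the diff). It assumes a 4-process job launched with, e.g., `torchrun --nproc_per_node 4 demo.py` on a node with 4 GPUs; the launcher invocation and script name are illustrative.

```python
# Hedged sketch, not part of this PR: exercises the upgraded DeviceMesh API.
# Assumes 4 processes (e.g. `torchrun --nproc_per_node 4 demo.py`) and 4 GPUs
# for the default 'cuda'/'nccl' backend.
import torch
import torch.distributed as dist

from colossalai.device.device_mesh import DeviceMesh

dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank())

# 4 devices viewed as a 2 x 2 logical mesh: [[0, 1], [2, 3]]
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2), init_process_group=True)

# process group of the current rank along axis 0 (its column)
col_pg = device_mesh.get_process_group(axis=0)

# global ranks sharing a process group with rank 0 on each axis
assert device_mesh.get_ranks_in_process_group(axis=0, global_rank=0) == [0, 2]
assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=0) == [0, 1]

# local coordinate of global rank 3 in the logical mesh
assert device_mesh.global_rank_to_local_rank(3) == [1, 1]
```

Compared with this flow, `DeviceMesh.from_process_group` (exercised by the tests above) reverses the direction: the caller supplies pre-built process groups and gives up global rank queries such as `global_rank_to_local_rank` and `flatten`.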