"""This code is adapted from Alpa https://github.com/alpa-projects/alpa/ with some changes. """ import operator from dataclasses import dataclass from functools import reduce from typing import Dict, List, Union import torch import torch.distributed as dist from torch.distributed import ProcessGroup @dataclass class ProcessGroupContainer: process_group: ProcessGroup ranks: List[int] # modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py) class DeviceMesh: """A logical view of a physical cluster. For example, we could view a physical cluster with 16 devices as a device mesh with shape (2, 2, 4) or (4, 4). Arguments: physical_mesh_id (torch.Tensor): physical view of the devices in global rank. logical_mesh_id (torch.Tensor): logical view of the devices in global rank. mesh_shape (torch.Size, optional): shape of logical view. mesh_alpha (List[float], optional): coefficients used for computing communication cost (default: None) mesh_beta (List[float], optional): coefficients used for computing communication cost (default: None) init_process_group (bool, optional): initialize logical process group during initializing the DeviceMesh instance if the init_process_group set to True. Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group. (default: False) device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda') """ _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"} def __init__(self, physical_mesh_id: torch.Tensor, mesh_shape: torch.Size = None, logical_mesh_id: torch.Tensor = None, mesh_alpha: List[float] = None, mesh_beta: List[float] = None, init_process_group: bool = False, device: str = 'cuda'): # ============================ # Physical & Logical Mesh IDs # ============================ self._physical_mesh_id = physical_mesh_id assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor." # logical mesh ids can be obtained via two ways # 1. provide physical mesh id and provide mesh shape # 2. directly supply the logical mesh id assert mesh_shape is None or logical_mesh_id is None, \ "Only one of mesh_shape and logical_mesh_id can be specified." \ "Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id" if logical_mesh_id is None: self.mesh_shape = mesh_shape self._logical_mesh_id = self._physical_mesh_id.reshape(self.mesh_shape) else: self._logical_mesh_id = logical_mesh_id self.mesh_shape = self._logical_mesh_id.shape # ensure two things: # 1. logical and physical mesh IDs should contain the same elements # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \ "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id." assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \ "Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again." assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \ "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again." 
        # ===============================================
        # coefficient for alpha-beta communication model
        # alpha is latency and beta is bandwidth
        # ===============================================
        # if the values are not provided, we assume they are 1 for simplicity
        if mesh_alpha is None:
            mesh_alpha = [1] * len(self.mesh_shape)
        if mesh_beta is None:
            mesh_beta = [1] * len(self.mesh_shape)

        self.mesh_alpha = tuple(mesh_alpha)
        self.mesh_beta = tuple(mesh_beta)

        # ensure the alpha and beta have the same length
        assert len(self.mesh_alpha) == len(self.mesh_beta), \
            "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."

        # =========================
        # Device for Process Group
        # =========================
        self._device = device
        self._dist_backend = self._DIST_BACKEND[device]

        # =========================
        # Process Group Management
        # =========================
        # the _global_to_local_rank_mapping is structured as follows
        # {
        #    <global rank>: [<local rank on axis 0>, <local rank on axis 1>, <local rank on axis 2>, ...]
        # }
        self._global_to_local_rank_mapping = dict()
        self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping,
                                                  tensor=self.logical_mesh_id)

        # create process group
        self._process_group_dict = {}
        self._ranks_in_the_process_group = {}
        self._global_rank_of_current_process = None
        self._is_initialized = False

        # initialize process group if specified
        self._init_ranks_in_the_same_group()
        self._init_process_group = init_process_group
        if init_process_group:
            self.init_logical_process_group()

    @property
    def shape(self) -> torch.Size:
        """
        Return the shape of the logical mesh.
        """
        return self.mesh_shape

    @property
    def num_devices(self) -> int:
        """
        Return the number of devices contained in the device mesh.
        """
        return reduce(operator.mul, self._physical_mesh_id.shape, 1)

    @property
    def logical_mesh_id(self) -> torch.Tensor:
        """
        Return the logical mesh id.
        """
        return self._logical_mesh_id

    def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup:
        """
        Return the process group on the specified axis.

        Args:
            axis (int): the axis of the process group.
            global_rank (int, optional): the global rank of the process.
                If not specified, the current process is used. (default: None)
        """
        if global_rank is None:
            global_rank = self._global_rank_of_current_process
        return self._process_group_dict[global_rank][axis]

    def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]:
        """
        Return the process groups for all axes.

        Args:
            global_rank (int, optional): the global rank of the process.
                If not specified, the current process is used. (default: None)
        """
        if global_rank is None:
            global_rank = self._global_rank_of_current_process
        return self._process_group_dict[global_rank]
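    # Layout of the two lookup tables used by the getters above and below (sketch for a
    # 2 x 2 logical mesh [[0, 1], [2, 3]]):
    #   _ranks_in_the_process_group[0] == {0: [0, 2], 1: [0, 1]}
    #   _process_group_dict[0]         == {0: <ProcessGroup over ranks 0, 2>, 1: <ProcessGroup over ranks 0, 1>}
    # Both dicts are keyed first by global rank and then by mesh axis; the process groups
    # themselves only exist after init_logical_process_group() has been called.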
    def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]:
        """
        Return the ranks in the process group on the specified axis.

        Args:
            axis (int): the axis of the process group.
            global_rank (int, optional): the global rank of the process.
                If not specified, the current process is used. (default: None)
        """
        if global_rank is None:
            global_rank = self._global_rank_of_current_process
        return self._ranks_in_the_process_group[global_rank][axis]

    def __deepcopy__(self, memo) -> "DeviceMesh":
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k != '_process_group_dict':
                setattr(result, k, __import__("copy").deepcopy(v, memo))
            else:
                # process groups cannot be copied
                # thus, we share them directly
                setattr(result, k, v)
        return result

    def _init_global_to_logical_rank_mapping(self,
                                             mapping: Dict,
                                             tensor: torch.Tensor,
                                             index_list: List[int] = []) -> Dict[int, List[int]]:
        """
        Build a global rank to local rank mapping for each process group in different axes of the logical device mesh.

        Args:
            mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
            tensor (torch.Tensor): the tensor that contains the logical mesh ids.
            index_list (List[int]): the local ranks accumulated along the axes visited so far in the recursion.

        Returns:
            mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
                The value is a list of integers and each integer represents the local rank in the indexed axis.
        """
        for index, inner_tensor in enumerate(tensor):
            # index means the local rank in the current axis
            # inner_tensor refers to the processes with the same local rank

            if inner_tensor.numel() == 1:
                # if the inner_tensor only has one element, it means that
                # it already reaches the last axis
                # we append its local_rank in the last axis to the index_list
                # and assign it to the mapping
                # the value of the mapping is the local rank at the indexed axis of the device mesh
                mapping[int(inner_tensor)] = index_list + [index]
            else:
                # we recursively go into the function until we reach the last axis
                # meanwhile, we should add the local rank in the current axis to the index_list
                self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index])
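    # Worked example of the mapping built above (sketch): for a 2 x 2 logical mesh
    # [[0, 1], [2, 3]], the recursion produces
    #   _global_to_local_rank_mapping == {0: [0, 0], 1: [0, 1], 2: [1, 0], 3: [1, 1]}
    # so the value of each entry can be read as the process coordinates of that global rank
    # inside the logical mesh.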
    def init_logical_process_group(self):
        '''
        This method is used to initialize the logical process groups which will be used in communications
        among the logical device mesh.

        Note: if init_process_group is set to False, you have to call this method manually. Otherwise, the
        communication-related functions, such as ShapeConsistencyManager.apply, will raise errors.
        '''
        # sanity check
        assert dist.is_initialized(), "The torch.distributed should be initialized before calling init_logical_process_group"
        assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice"

        # update the global rank of the current process
        self._global_rank_of_current_process = dist.get_rank()

        duplicate_check_list = []

        # flatten the global ranks to a 1D list
        global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()
        for global_rank in global_rank_flatten_list:
            # find the other ranks which are in the same process group as global_rank
            ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)

            for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
                # skip duplicated process group creation
                if ranks_in_same_group in duplicate_check_list:
                    continue

                # create the process group and record the ranks so that the same group is not created twice
                pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend)
                duplicate_check_list.append(ranks_in_same_group)

                # keep this process group in the process_group_dict
                for rank in ranks_in_same_group:
                    if rank not in self._process_group_dict:
                        self._process_group_dict[rank] = dict()
                    self._process_group_dict[rank][axis] = pg_handler

        # update the init flag
        # we only allow init for once
        self._is_initialized = True

    def _init_ranks_in_the_same_group(self):
        """
        This method is used to initialize the ranks_in_the_process_group dictionary.
        """
        # flatten the global ranks to a 1D list
        global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()
        for global_rank in global_rank_flatten_list:
            # find the other ranks which are in the same process group as global_rank
            ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)

            for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
                # create a dict for each rank
                if global_rank not in self._ranks_in_the_process_group:
                    self._ranks_in_the_process_group[global_rank] = dict()

                # keep the ranks of this process group in the ranks_in_the_process_group dict
                self._ranks_in_the_process_group[global_rank][axis] = ranks_in_same_group

    def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]:
        """
        Return the local rank of the given global rank in the logical device mesh.

        Args:
            rank (int): the global rank in the logical device mesh.
            axis (int): the axis of the logical device mesh.
        """
        local_ranks = self._global_to_local_rank_mapping[rank]
        if axis is not None:
            return local_ranks[axis]
        else:
            return local_ranks
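    # Example (sketch): on a 4 x 4 logical mesh built from torch.arange(16), global rank 5 sits at
    # row 1, column 1, so global_rank_to_local_rank(5) returns [1, 1] and
    # global_rank_to_local_rank(5, axis=0) returns 1.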
    def _collate_global_ranks_in_same_process_group(self, global_rank):
        '''
        Give a global rank and return all global ranks involved in its associated process group in each axis.

        Example:

        ```python
        physical_mesh_id = torch.arange(0, 16)
        mesh_shape = (4, 4)

        # logical mesh will look like
        # [[0,  1,  2,  3],
        #  [4,  5,  6,  7],
        #  [8,  9,  10, 11],
        #  [12, 13, 14, 15]]

        device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
        print(device_mesh._collate_global_ranks_in_same_process_group(0))

        # key is the axis
        # value is a list of global ranks in the same axis as rank 0
        # output will look like
        # {
        #   0: [0, 4, 8, 12],
        #   1: [0, 1, 2, 3]
        # }
        ```
        '''
        # We have built the global rank to local rank mapping by calling _init_global_to_logical_rank_mapping
        # for self._global_to_local_rank_mapping
        # the key is the global rank
        # the value is the list of local ranks corresponding to the global rank with respect to different axes
        # we can see the list of local ranks as the process coordinates for simplicity
        # the keys and values are all unique, therefore,
        # we can also use the coordinates to find the global rank

        # =========================================================================
        # Step 1
        # find all the process_coordinates for processes in the same process group
        # as the given global rank
        # =========================================================================

        # each key is an axis and each value is the list of process coordinates which share
        # the process group with the given global rank along that axis
        processes_in_the_same_process_group = {}

        for dim in range(self.logical_mesh_id.dim()):
            # iterate over the dimension size so that we can include all processes
            # in the same process group in the given axis
            # the _local_rank refers to the local rank of the currently iterated process
            for _local_rank in range(self.logical_mesh_id.shape[dim]):
                # if this dimension is not initialized yet,
                # initialize it with an empty list
                if dim not in processes_in_the_same_process_group:
                    processes_in_the_same_process_group[dim] = []

                # get the local rank corresponding to the global rank
                process_coordinates = self._global_to_local_rank_mapping[global_rank].copy()

                # replace the local rank in the given dimension with the
                # local rank of the currently iterated process
                process_coordinates[dim] = _local_rank

                processes_in_the_same_process_group[dim].append(process_coordinates)

        # =================================================================
        # Step 2
        # Use local rank combination to find its corresponding global rank
        # =================================================================
        # the key of the dict is the axis
        # the value is the list of global ranks which are in the same process group as the given global rank
        global_pg_ranks = {}
        for dim, coordinates_of_all_processes in processes_in_the_same_process_group.items():
            global_pg_ranks[dim] = []
            for process_coordinates in coordinates_of_all_processes:
                # find the global rank by local rank combination
                for _global_rank, _process_coordinates in self._global_to_local_rank_mapping.items():
                    if process_coordinates == _process_coordinates:
                        global_pg_ranks[dim].append(_global_rank)
        return global_pg_ranks

    def flatten(self):
        """
        Flatten the logical mesh into an effective 1d logical mesh.
        """
        flatten_mesh_shape_size = len(self.mesh_shape)
        flatten_mesh_shape = [self.num_devices]
        return DeviceMesh(self._physical_mesh_id,
                          tuple(flatten_mesh_shape),
                          mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
                          mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
                          init_process_group=self._init_process_group)
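    # ==============================================================================
    # Alpha-beta collective cost estimates
    # ==============================================================================
    # The methods below estimate the cost of a collective along one mesh axis from that
    # axis's latency coefficient (alpha) and bandwidth coefficient (beta), plus a small
    # constant penalty. For example, with the default alpha = beta = 1, an all-gather of
    # num_bytes across a dimension of 4 devices is estimated as 1 + (3 / 4) * num_bytes + 0.1.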
    def all_gather_cost(self, num_bytes, mesh_dim):
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
                (num_devices - 1) / num_devices * num_bytes + 0.1)

    def all_reduce_cost(self, num_bytes, mesh_dim):
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 *
                (num_devices - 1) / num_devices * num_bytes + 0.01)

    def reduce_scatter_cost(self, num_bytes, mesh_dim):
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
                (num_devices - 1) / num_devices * num_bytes + 0.001)

    def all_to_all_cost(self, num_bytes, mesh_dim):
        num_devices = self.logical_mesh_id.shape[mesh_dim]
        penalty_factor = num_devices / 2.0
        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
                (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
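

# Minimal usage sketch: build a small logical mesh without touching torch.distributed
# (init_process_group=False, device='cpu'), then query the rank mappings and the cost model
# with the default alpha = beta = 1. No distributed backend is required to run this.
if __name__ == "__main__":
    physical_mesh_id = torch.arange(0, 4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2), init_process_group=False, device='cpu')

    # logical mesh is [[0, 1], [2, 3]]
    print(device_mesh.shape)                                                # (2, 2)
    print(device_mesh.global_rank_to_local_rank(3))                         # [1, 1]
    print(device_mesh.get_ranks_in_process_group(axis=0, global_rank=0))    # [0, 2]

    # cost estimates along each axis
    print(device_mesh.all_gather_cost(1024, mesh_dim=0))                    # 1 + 0.5 * 1024 + 0.1
    print(device_mesh.all_reduce_cost(1024, mesh_dim=1))                    # 1 + 2 * 0.5 * 1024 + 0.01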