from typing import List, Optional

import torch

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.logging import get_dist_logger


class PyTorchProcessGroupDict(metaclass=SingletonMeta):
    """Cache of PyTorch ProcessGroups keyed by backend and rank list.

    A group for a given ``(backend, rank_list)`` pair is created once through
    ``torch.distributed.new_group`` and the same object is returned on every later lookup.
    """

    def __init__(self):
        # distributed settings
        # use this dict to record all Pytorch ProcessGroups: (backend, tuple(rank_list)) -> ProcessGroup
        self.dict = {}
        # set a distributed logger
        self.logger = get_dist_logger("ProcessGroup")

    def log_pg_init(self, rank_list: List[int], backend: str) -> None:
        """Log the creation of a new PyTorch ProcessGroup.

        The message is passed to the distributed logger with ``ranks=[0]``
        (presumably so that only global rank 0 prints it — see ``get_dist_logger``).
        """
        str_list = ["Pytorch ProcessGroup Init:"]
        str_list.append(f"backend: {backend}")
        str_list.append(f"ranks: {rank_list}")
        self.logger.info("\n\t".join(str_list), ranks=[0])

    def get(self, rank_list: List[int], backend: str = "nccl"):
        """Reuse Pytorch ProcessGroup when such a group is initialized, otherwise create it.

        NOTE(review): per the PyTorch docs, ``torch.distributed.new_group`` must be entered by every
        process of the default group, so the first ``get`` for a key presumably has to run on all ranks.

        Args:
            rank_list: ranks that belong to the group. The order is part of the cache key.
            backend: PyTorch distributed backend name, e.g. "nccl" or "gloo".

        Returns:
            The cached (or newly created) PyTorch ProcessGroup.
        """
        # we need to convert the passed list to a tuple
        # since List is unhashable
        processgroup_key = (backend, tuple(rank_list))
        if processgroup_key not in self.dict:
            self.log_pg_init(rank_list=rank_list, backend=backend)
            self.dict[processgroup_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
        return self.dict[processgroup_key]


PYTORCHPGDICT_ = PyTorchProcessGroupDict()


class ProcessGroup:
    """ProcessGroup
    Process Group indicates how processes are organized in groups for parallel execution using Tensor Parallelism
    and Data Parallelism.

    The sorted rank list is viewed as a ``dp_degree x tp_degree`` grid in row-major order.
    Each row of ``tp_degree`` consecutive ranks forms a TP group. Each column (ranks that are
    ``tp_degree`` positions apart) forms a DP group.

    NOTE, the ProcessGroup must be used after `torch.distributed.init_process_group()`. If the default process
    group has not been initialized, the object is created with ``is_init = False`` and carries no other state.

    Args:
        rank: Optional[int], the global rank of the current process. If given, it must equal
            ``torch.distributed.get_rank()``.
        ranks: Optional[List[int]], a list of rank ids belonging to this process group. default None means all
            ranks of the default process group. The caller's list is not modified; a sorted copy is stored.
        tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group.
            default None means 1 when dp_degree is also None, otherwise len(ranks) // dp_degree.
        dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group.
            default None means len(ranks) when tp_degree is also None, otherwise len(ranks) // tp_degree.
    """

    def __init__(
        self,
        rank: Optional[int] = None,
        ranks: Optional[List[int]] = None,
        tp_degree: Optional[int] = None,
        dp_degree: Optional[int] = None,
    ) -> None:
        if not torch.distributed.is_initialized():
            self.is_init = False
            return

        self._rank = torch.distributed.get_rank()
        if rank is not None:
            # make sure that the global rank is correct
            assert self._rank == rank, f"rank {rank} does not match the global rank {self._rank}"

        if ranks is None:
            self._rank_list = list(range(torch.distributed.get_world_size()))
        else:
            # sorted copy: keeps the list in order without mutating the caller's argument
            self._rank_list = sorted(ranks)

        self._world_size = len(self._rank_list)

        if dp_degree is None and tp_degree is None:
            self._dp_degree = self._world_size
            self._tp_degree = 1
        elif dp_degree and not tp_degree:
            self._dp_degree = dp_degree
            assert (
                self._world_size % self._dp_degree == 0
            ), f"The world size {self._world_size} should be divisible by DP degree {dp_degree} when TP degree is None"
            self._tp_degree = self._world_size // dp_degree
        elif not dp_degree and tp_degree:
            self._tp_degree = tp_degree
            assert (
                self._world_size % self._tp_degree == 0
            ), f"The world size {self._world_size} should be divisible by TP degree {tp_degree} when DP degree is None"
            self._dp_degree = self._world_size // tp_degree
        else:
            self._dp_degree = dp_degree
            self._tp_degree = tp_degree
            assert self._dp_degree * self._tp_degree == self._world_size, (
                f"the world size {self._world_size} should equal the product of DP degree {self._dp_degree} "
                f"and TP degree {self._tp_degree}"
            )

        # The rank lists of the groups containing the current rank; they stay None if this rank is not a member.
        self._tp_rank_list = None
        self._dp_rank_list = None
        # Create (or reuse) every NCCL group in a fixed order: all TP groups first, then all DP groups.
        for i_tp_list in self._all_tp_rank_lists():
            PYTORCHPGDICT_.get(i_tp_list, "nccl")
            if self._rank in i_tp_list:
                self._tp_rank_list = i_tp_list
        for j_dp_list in self._all_dp_rank_lists():
            PYTORCHPGDICT_.get(j_dp_list, "nccl")
            if self._rank in j_dp_list:
                self._dp_rank_list = j_dp_list

        self._has_cpu_groups = False
        self.is_init = True

    def _all_tp_rank_lists(self) -> List[List[int]]:
        """Rank lists of all TP groups: row ``i`` of the dp_degree x tp_degree grid, for i in range(dp_degree)."""
        return [
            [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)] for i in range(self._dp_degree)
        ]

    def _all_dp_rank_lists(self) -> List[List[int]]:
        """Rank lists of all DP groups: column ``j`` of the dp_degree x tp_degree grid, for j in range(tp_degree)."""
        return [
            [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)] for j in range(self._tp_degree)
        ]

    def _local_index(self) -> int:
        """Position of the current rank inside the sorted rank list of this group.

        Raises:
            ValueError: if the current rank is not a member of this ProcessGroup.
        """
        try:
            return self._rank_list.index(self._rank)
        except ValueError:
            raise ValueError(f"rank {self._rank} is not in this ProcessGroup's ranks {self._rank_list}") from None

    def set_cpu_groups(self):
        """set_cpu_groups
        Initialize Pytorch process groups for cpu communications.

        Creates (or reuses) a "gloo" group for every TP and DP group, in the same order as the NCCL
        groups built in ``__init__``. Calling it again is a no-op.
        """
        if self.has_cpu_groups:
            return
        for i_tp_list in self._all_tp_rank_lists():
            PYTORCHPGDICT_.get(i_tp_list, "gloo")
        for j_dp_list in self._all_dp_rank_lists():
            PYTORCHPGDICT_.get(j_dp_list, "gloo")
        self._has_cpu_groups = True

    @property
    def has_cpu_groups(self) -> bool:
        """has_cpu_groups
        If cpu groups have been initialized.

        Returns:
            bool: cpu process groups have been initialized or not.
        """
        return self._has_cpu_groups

    def __repr__(self):
        if self.is_init:
            ranks_str = f"ProcessGroup(ranks={self._rank_list},\n"
            personal_str = f"             rank={self._rank}, dp={self._dp_degree}, tp={self._tp_degree})"
            return ranks_str + personal_str
        else:
            return "ProcessGroup not initialized"

    # NOTE: defining __eq__ without __hash__ makes instances unhashable (Python sets __hash__ to None).
    def __eq__(self, obj: "ProcessGroup") -> bool:
        """Two ProcessGroups are equal when their ranks, group layout and degrees all match.

        Uninitialized groups hold no layout attributes. They compare equal only to other
        uninitialized groups and never to an initialized one.
        """
        if not isinstance(obj, ProcessGroup):
            return False
        if not self.is_init or not obj.is_init:
            return self.is_init == obj.is_init
        return (
            self._rank == obj._rank
            and self._rank_list == obj._rank_list
            and self._tp_rank_list == obj._tp_rank_list
            and self._dp_rank_list == obj._dp_rank_list
            and self._tp_degree == obj._tp_degree
            and self._dp_degree == obj._dp_degree
        )

    def rank(self) -> int:
        """rank

        The current rank in the global process group.

        Returns:
            int: the rank number
        """
        return self._rank

    def ranks_in_group(self) -> List[int]:
        """ranks_in_group

        The sorted list of rank numbers belonging to this ProcessGroup.

        Returns:
            List[int]: a list of rank number.
        """
        return self._rank_list

    def world_size(self) -> int:
        """world_size

        The number of ranks in this ProcessGroup (equal to dp_degree * tp_degree).

        Returns:
            int: world size
        """
        return self._world_size

    def tp_rank_list(self) -> List[int]:
        """tp_rank_list

        the rank list in the TP process group containing the current rank.

        Returns:
            List[int]: the list of rank number.
        """
        return self._tp_rank_list

    def dp_rank_list(self) -> List[int]:
        """dp_rank_list

        the rank list in the DP process group containing the current rank.

        Returns:
            List[int]: the list of rank number.
        """
        return self._dp_rank_list

    def tp_local_rank(self) -> int:
        """tp_local_rank

        The local rank number in the current TP process group, i.e. the position of the current
        rank inside ``tp_rank_list()``.

        Returns:
            int: tp rank number.

        Raises:
            ValueError: if the current rank is not a member of this ProcessGroup.
        """
        # column index of the current rank in the dp_degree x tp_degree grid
        return self._local_index() % self._tp_degree

    def dp_local_rank(self) -> int:
        """dp_local_rank

        The local rank number in the current DP process group, i.e. the position of the current
        rank inside ``dp_rank_list()``.

        Returns:
            int: dp rank number.

        Raises:
            ValueError: if the current rank is not a member of this ProcessGroup.
        """
        # row index of the current rank in the dp_degree x tp_degree grid
        return self._local_index() // self._tp_degree

    def dp_world_size(self) -> int:
        """dp_world_size

        The world size of the current DP process group.

        Returns:
            int: dp world size
        """
        return len(self._dp_rank_list)

    def tp_world_size(self) -> int:
        """tp_world_size

        The world size of the current TP process group.

        Returns:
            int: tp world size
        """
        return len(self._tp_rank_list)

    def dp_process_group(self):
        """dp_process_group

        the pytorch DP process group containing the current rank.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
        """
        return PYTORCHPGDICT_.get(self._dp_rank_list, "nccl")

    def tp_process_group(self):
        """tp_process_group

        the pytorch TP process group containing the current rank.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
        """
        return PYTORCHPGDICT_.get(self._tp_rank_list, "nccl")

    def cpu_dp_process_group(self):
        """cpu_dp_process_group

        the pytorch CPU DP process group containing the current rank.

        assert failed if cpu process group is not initialized.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
        """
        assert self._has_cpu_groups
        return PYTORCHPGDICT_.get(self._dp_rank_list, "gloo")

    def cpu_tp_process_group(self):
        """cpu_tp_process_group

        the pytorch CPU TP process group containing the current rank.

        assert failed if cpu process group is not initialized.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
        """
        assert self._has_cpu_groups
        return PYTORCHPGDICT_.get(self._tp_rank_list, "gloo")

    def get_ranks_in_dp(self) -> List[int]:
        """get_ranks_in_dp

        ranks in current dp process group.

        Returns:
            List[int]: a list of rank number.
        """
        return self._dp_rank_list

    def get_ranks_in_tp(self) -> List[int]:
        """get_ranks_in_tp

        ranks in current tp process group.

        Returns:
            List[int]: a list of rank number.
        """
        return self._tp_rank_list