from typing import Dict, List

import torch.distributed as dist
from torch.distributed import ProcessGroup


class ProcessGroupManager:
    """
    ProcessGroupManager is used to manage the process groups in the cluster.

    Process groups are registered under a unique name in ``pg_store`` so they
    can later be looked up or destroyed by that name.

    There are some terms used in this class:
        - pg: the short name for process group
        - pg_name: the name of the process group
        - pg_size: the world size of the process group
        - rank: the rank of the current process in the process group
        - world_size: the total number of processes in the process group
    """

    def __init__(self):
        # Maps pg_name -> process group handle returned by dist.new_group.
        self.pg_store: Dict[str, ProcessGroup] = {}

    def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
        """
        Create a new process group and register it under the given name.

        Unlike a get-or-create helper, this method refuses to reuse an
        existing name; use :meth:`get` to fetch a group that was already
        created.

        NOTE: the group is created with ``torch.distributed.new_group``; see
        the PyTorch documentation for its collective-call requirements.

        Args:
            name (str): name of the process group
            ranks (List[int]): ranks of the process group
            backend (str, optional): backend of the process group. Defaults to 'nccl'.

        Returns:
            ProcessGroup: the newly created process group

        Raises:
            ValueError: if a process group with this name already exists.
        """
        # Guard first so no group is created when the name is already taken.
        if name in self.pg_store:
            raise ValueError(f'Process group {name} already exists.')

        pg = dist.new_group(ranks=ranks, backend=backend)
        self.pg_store[name] = pg
        return pg

    def get(self, name: str) -> ProcessGroup:
        """
        Get a process group by name.

        Args:
            name (str): name of the process group

        Returns:
            ProcessGroup: the process group

        Raises:
            ValueError: if no process group is registered under this name.
        """
        if name not in self.pg_store:
            raise ValueError(f'Process group {name} does not exist.')
        return self.pg_store[name]

    def destroy(self, name: str) -> None:
        """
        Destroy a process group by name and remove it from the registry.

        Args:
            name (str): name of the process group

        Raises:
            ValueError: if no process group is registered under this name.
        """
        if name not in self.pg_store:
            raise ValueError(f'Process group {name} does not exist.')

        # Destroy before deleting: if destruction raises, the entry is kept
        # so the caller can still see (and retry on) the group.
        dist.destroy_process_group(self.pg_store[name])
        del self.pg_store[name]

    def destroy_all(self) -> None:
        """
        Destroy all process groups and empty the registry.

        Each entry is removed right after its group is destroyed, so if a
        destruction raises part-way through, the registry only contains the
        groups that have not been destroyed yet.
        """
        # Iterate over a snapshot of the names because the dict is mutated
        # inside the loop.
        for name in list(self.pg_store):
            dist.destroy_process_group(self.pg_store[name])
            del self.pg_store[name]