You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/cluster/process_group_manager.py

76 lines
2.3 KiB

from typing import List
import torch.distributed as dist
from torch.distributed import ProcessGroup
class ProcessGroupManager:
    """
    ProcessGroupManager is used to manage the process groups in the cluster.

    Process groups are registered under a unique name; a name can be used at most
    once until the corresponding group is destroyed.

    There are some terms used in this class:
        - pg: the short name for process group
        - pg_name: the name of the process group
        - pg_size: the world size of the process group
        - rank: the rank of the current process in the process group
        - world_size: the total number of processes in the process group
    """

    def __init__(self):
        # Maps a process-group name to the object returned by ``dist.new_group``.
        # Entries are only added by ``create_process_group``, so iteration order
        # is creation order.
        self.pg_store = dict()

    def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
        """
        Create a new process group and register it under ``name``.

        This never returns an already-registered group: registering the same name
        twice is an error. Use ``get`` to look up an existing group.

        Args:
            name (str): name of the process group
            ranks (List[int]): ranks of the process group
            backend (str, optional): backend of the process group. Defaults to 'nccl'.

        Returns:
            ProcessGroup: the newly created process group

        Raises:
            ValueError: if a process group named ``name`` already exists.
        """
        # Reject duplicates before calling dist.new_group so no group is created
        # (and then left unregistered) when the name is already taken.
        if name in self.pg_store:
            raise ValueError(f'Process group {name} already exists.')
        # NOTE(review): per the torch.distributed docs, new_group is meant to be
        # entered by every process, including those not listed in ``ranks``;
        # confirm callers invoke this on all ranks.
        pg = dist.new_group(ranks=ranks, backend=backend)
        self.pg_store[name] = pg
        return pg

    def get(self, name: str) -> ProcessGroup:
        """
        Get a registered process group by name.

        Args:
            name (str): name of the process group

        Returns:
            ProcessGroup: the process group

        Raises:
            ValueError: if no process group named ``name`` is registered.
        """
        if name not in self.pg_store:
            raise ValueError(f'Process group {name} does not exist.')
        return self.pg_store[name]

    def destroy(self, name: str) -> None:
        """
        Destroy a process group by name and unregister it.

        Args:
            name (str): name of the process group

        Raises:
            ValueError: if no process group named ``name`` is registered.
        """
        if name not in self.pg_store:
            raise ValueError(f'Process group {name} does not exist.')
        dist.destroy_process_group(self.pg_store[name])
        # Unregister only after destruction succeeded, so a failure in
        # dist.destroy_process_group leaves the entry in place.
        del self.pg_store[name]

    def destroy_all(self) -> None:
        """
        Destroy all registered process groups, in the order they were created.

        Each group is unregistered right after it is destroyed, so if destroying
        one of them raises, ``pg_store`` still holds exactly the groups that were
        not destroyed (instead of also listing already-destroyed ones).
        """
        # Iterate over a snapshot of the names: ``destroy`` removes entries.
        for name in list(self.pg_store):
            self.destroy(name)