mirror of https://github.com/hpcaitech/ColossalAI
aibig-modeldata-parallelismdeep-learningdistributed-computingfoundation-modelsheterogeneous-traininghpcinferencelarge-scalemodel-parallelismpipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
2.3 KiB
75 lines
2.3 KiB
from typing import List |
|
|
|
import torch.distributed as dist |
|
from torch.distributed import ProcessGroup |
|
|
|
|
|
class ProcessGroupManager: |
|
""" |
|
ProcessGroupManager is used to manage the process groups in the cluster. |
|
|
|
There are some terms used in this class: |
|
- pg: the short name for process group |
|
- pg_name: the name of the process group |
|
- pg_size: the world size of the process group |
|
- rank: the rank of the current process in the process group |
|
- world_size: the total number of processes in the process group |
|
""" |
|
|
|
def __init__(self): |
|
self.pg_store = dict() |
|
|
|
def create_process_group(self, name: str, ranks: List[int], backend: str = "nccl") -> ProcessGroup: |
|
""" |
|
Get a process group by name. If the process group does not exist, it will be created. |
|
|
|
Args: |
|
name (str): name of the process group |
|
ranks (List[int]): ranks of the process group |
|
backend (str, optional): backend of the process group. Defaults to 'nccl'. |
|
|
|
Returns: |
|
ProcessGroup: the process group |
|
""" |
|
if name not in self.pg_store: |
|
pg = dist.new_group(ranks=ranks, backend=backend) |
|
self.pg_store[name] = pg |
|
return pg |
|
else: |
|
raise ValueError(f"Process group {name} already exists.") |
|
|
|
def get(self, name: str) -> ProcessGroup: |
|
""" |
|
Get a process group by name. |
|
|
|
Args: |
|
name (str): name of the process group |
|
|
|
Returns: |
|
ProcessGroup: the process group |
|
""" |
|
if name in self.pg_store: |
|
return self.pg_store[name] |
|
else: |
|
raise ValueError(f"Process group {name} does not exist.") |
|
|
|
def destroy(self, name: str) -> None: |
|
""" |
|
Destroy a process group by name. |
|
|
|
Args: |
|
name (str): name of the process group |
|
""" |
|
if name in self.pg_store: |
|
dist.destroy_process_group(self.pg_store[name]) |
|
del self.pg_store[name] |
|
else: |
|
raise ValueError(f"Process group {name} does not exist.") |
|
|
|
def destroy_all(self) -> None: |
|
""" |
|
Destroy all process groups. |
|
""" |
|
for name in self.pg_store: |
|
dist.destroy_process_group(self.pg_store[name]) |
|
self.pg_store.clear()
|
|
|