[booster] implemented the cluster module (#3191)

* [booster] implemented the cluster module

* polish code
pull/3199/head
Frank Lee 2023-03-22 14:11:54 +08:00 committed by GitHub
parent 019a847432
commit e3ad88fb48
4 changed files with 274 additions and 0 deletions

colossalai/cluster/__init__.py View File

@@ -0,0 +1,5 @@
from .device_mesh_manager import DeviceMeshManager
from .dist_coordinator import DistCoordinator
from .process_group_manager import ProcessGroupManager

__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager']
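
With this `__init__.py`, the three managers become importable directly from the package root, e.g.:

from colossalai.cluster import DeviceMeshManager, DistCoordinator, ProcessGroupManager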

colossalai/cluster/device_mesh_manager.py View File

@@ -0,0 +1,36 @@
from colossalai.device.device_mesh import DeviceMesh


class DeviceMeshManager:
    """
    Device mesh manager is responsible for creating and managing device meshes.
    """

    def __init__(self):
        self.device_mesh_store = dict()

    def create_device_mesh(self, name, *args, **kwargs) -> DeviceMesh:
        """
        Create a device mesh and store it in the manager.

        Args:
            name (str): name of the device mesh
            *args: positional arguments for DeviceMesh
            **kwargs: keyword arguments for DeviceMesh

        Returns:
            DeviceMesh: the newly created device mesh
        """
        # TODO(Yuliang): replace *args, **kwargs with explicit arguments
        if name not in self.device_mesh_store:
            device_mesh = DeviceMesh(*args, **kwargs)
            self.device_mesh_store[name] = device_mesh
            return device_mesh
        else:
            raise ValueError(f'Device mesh {name} already exists.')
    def get(self, name: str) -> DeviceMesh:
        """
        Get a device mesh by name, raising a ValueError if it does not exist.
        """
        if name not in self.device_mesh_store:
            raise ValueError(f'Device mesh {name} does not exist.')
        return self.device_mesh_store[name]

    def destroy(self, name: str) -> None:
        """
        Remove a device mesh from the manager by name.
        """
        # minimal implementation: drop the stored mesh so the name can be reused
        if name not in self.device_mesh_store:
            raise ValueError(f'Device mesh {name} does not exist.')
        del self.device_mesh_store[name]

    def destroy_all(self) -> None:
        """
        Remove all device meshes from the manager.
        """
        self.device_mesh_store.clear()
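
A minimal usage sketch for the manager follows. The `DeviceMesh(physical_mesh_id, mesh_shape)` arguments are an assumption about the `colossalai.device.device_mesh` constructor; only the manager calls themselves come from this diff.

import torch

from colossalai.cluster import DeviceMeshManager

manager = DeviceMeshManager()
# assumed DeviceMesh signature: a tensor of physical device ids plus a logical mesh shape
mesh = manager.create_device_mesh('2d_mesh', torch.arange(4), (2, 2))
assert manager.get('2d_mesh') is mesh
manager.destroy('2d_mesh')    # the name can now be reused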

colossalai/cluster/dist_coordinator.py View File

@@ -0,0 +1,158 @@
import os
from contextlib import contextmanager

import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.context.singleton_meta import SingletonMeta


class DistCoordinator(metaclass=SingletonMeta):
    """
    This class is used to coordinate distributed training. It is a singleton class, which means
    that there is only one instance of this class in the whole program.

    There are some terms that are used in this class:
        - rank: the rank of the current process
        - world size: the total number of processes
        - local rank: the rank of the current process on the current node
        - master: the process with rank 0
        - node master: the process with local rank 0 on the current node

    Example:
        >>> from colossalai.cluster.dist_coordinator import DistCoordinator
        >>> coordinator = DistCoordinator()
        >>>
        >>> if coordinator.is_master():
        >>>     do_something()
        >>>
        >>> coordinator.print_on_master('hello world')

    Attributes:
        rank (int): the rank of the current process
        world_size (int): the total number of processes
        local_rank (int): the rank of the current process on the current node
    """
    def __init__(self):
        assert dist.is_initialized(), \
            'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.'
        self._rank = dist.get_rank()
        self._world_size = dist.get_world_size()
        # LOCAL_RANK is often set by launchers such as torchrun; environment
        # variables are strings, so cast to int (-1 means it was not set)
        self._local_rank = int(os.environ.get('LOCAL_RANK', -1))

    @property
    def rank(self) -> int:
        return self._rank

    @property
    def world_size(self) -> int:
        return self._world_size

    @property
    def local_rank(self) -> int:
        return self._local_rank

    def _assert_local_rank_set(self):
        """
        Assert that the local rank is set. This is often passed by launchers such as torchrun.
        """
        assert self.local_rank >= 0, \
            'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.'
    def is_master(self, process_group: ProcessGroup = None) -> bool:
        """
        Check if the current process is the master process (rank is 0). It can accept a
        sub process group to check rank 0 with respect to that group.

        Args:
            process_group (ProcessGroup, optional): process group to use for the rank 0 check.
                Defaults to None, which refers to the default process group.

        Returns:
            bool: True if the current process is the master process, False otherwise
        """
        rank = dist.get_rank(group=process_group)
        return rank == 0

    def is_node_master(self) -> bool:
        """
        Check if the current process is the master process on the current node (local rank is 0).

        Returns:
            bool: True if the current process is the master process on the current node, False otherwise
        """
        self._assert_local_rank_set()
        return self.local_rank == 0

    def is_last_process(self, process_group: ProcessGroup = None) -> bool:
        """
        Check if the current process is the last process (rank is world size - 1). It can accept a
        sub process group to check the last rank with respect to that group.

        Args:
            process_group (ProcessGroup, optional): process group to use for the last rank check.
                Defaults to None, which refers to the default process group.

        Returns:
            bool: True if the current process is the last process, False otherwise
        """
        rank = dist.get_rank(group=process_group)
        world_size = dist.get_world_size(group=process_group)
        return rank == world_size - 1
    def print_on_master(self, msg: str, process_group: ProcessGroup = None):
        """
        Print a message only from rank 0.

        Args:
            msg (str): message to print
            process_group (ProcessGroup, optional): process group to use for the rank 0 check.
                Defaults to None, which refers to the default process group.
        """
        rank = dist.get_rank(group=process_group)
        if rank == 0:
            print(msg)

    def print_on_node_master(self, msg: str):
        """
        Print a message only from local rank 0, i.e. the 0th process running on the current node.

        Args:
            msg (str): message to print
        """
        self._assert_local_rank_set()
        if self.local_rank == 0:
            print(msg)
    @contextmanager
    def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup = None):
        """
        This context manager is used to allow one process to execute while blocking all
        other processes in the same process group. This is often useful when downloading is required
        as we only want to download in one process to prevent file corruption.

        Example:
            >>> from colossalai.cluster import DistCoordinator
            >>> dist_coordinator = DistCoordinator()
            >>> with dist_coordinator.priority_execution():
            >>>     dataset = CIFAR10(root='./data', download=True)

        Args:
            executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
            process_group (ProcessGroup, optional): process group to use for the executor rank check.
                Defaults to None, which refers to the default process group.
        """
        rank = dist.get_rank(group=process_group)
        should_block = rank != executor_rank

        # non-executor ranks wait here; the barrier only releases once the
        # executor finishes the block and enters its own barrier below
        if should_block:
            dist.barrier(group=process_group)

        yield

        # the executor enters the barrier after finishing, releasing the others
        if not should_block:
            dist.barrier(group=process_group)
    def destroy(self, process_group: ProcessGroup = None):
        """
        Destroy the distributed process group.

        Args:
            process_group (ProcessGroup, optional): process group to destroy.
                Defaults to None, which refers to the default process group.
        """
        dist.destroy_process_group(process_group)
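
A sketch of how the coordinator might be used in a script launched with `torchrun` (the process group must be initialized first; `prepare_dataset` is a hypothetical helper standing in for any download step):

import torch.distributed as dist

from colossalai.cluster import DistCoordinator

dist.init_process_group(backend='nccl')
coordinator = DistCoordinator()

coordinator.print_on_master(f'world size: {coordinator.world_size}')

# rank 0 runs the body first; all other ranks wait at a barrier until it is done
with coordinator.priority_execution():
    prepare_dataset()    # hypothetical helper, e.g. downloading files

if coordinator.is_node_master():
    print(f'node master ready on local rank {coordinator.local_rank}')

coordinator.destroy()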

colossalai/cluster/process_group_manager.py View File

@@ -0,0 +1,75 @@
from typing import List

import torch.distributed as dist
from torch.distributed import ProcessGroup


class ProcessGroupManager:
    """
    ProcessGroupManager is used to manage the process groups in the cluster.

    There are some terms used in this class:
        - pg: the short name for process group
        - pg_name: the name of the process group
        - pg_size: the world size of the process group
        - rank: the rank of the current process in the process group
        - world_size: the total number of processes in the process group
    """

    def __init__(self):
        self.pg_store = dict()
    def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
        """
        Create a new process group with the given name and store it in the manager.
        A ValueError is raised if a process group with the same name already exists.

        Args:
            name (str): name of the process group
            ranks (List[int]): ranks of the process group
            backend (str, optional): backend of the process group. Defaults to 'nccl'.

        Returns:
            ProcessGroup: the newly created process group
        """
        if name not in self.pg_store:
            # new_group is a collective call: every process in the default group
            # must call it, even processes that are not in `ranks`
            pg = dist.new_group(ranks=ranks, backend=backend)
            self.pg_store[name] = pg
            return pg
        else:
            raise ValueError(f'Process group {name} already exists.')
    def get(self, name: str) -> ProcessGroup:
        """
        Get a process group by name.

        Args:
            name (str): name of the process group

        Returns:
            ProcessGroup: the process group
        """
        if name in self.pg_store:
            return self.pg_store[name]
        else:
            raise ValueError(f'Process group {name} does not exist.')
    def destroy(self, name: str) -> None:
        """
        Destroy a process group by name.

        Args:
            name (str): name of the process group
        """
        if name in self.pg_store:
            dist.destroy_process_group(self.pg_store[name])
            del self.pg_store[name]
        else:
            raise ValueError(f'Process group {name} does not exist.')

    def destroy_all(self) -> None:
        """
        Destroy all process groups.
        """
        for pg in self.pg_store.values():
            dist.destroy_process_group(pg)
        self.pg_store.clear()
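
A usage sketch for the manager, assuming a 4-process job in which every rank calls `create_process_group` collectively (as `torch.distributed.new_group` requires):

import torch
import torch.distributed as dist

from colossalai.cluster import ProcessGroupManager

dist.init_process_group(backend='nccl')
pg_manager = ProcessGroupManager()

# every rank participates in the call, even ranks outside `ranks`
tp_group = pg_manager.create_process_group(name='tp', ranks=[0, 1])

if dist.get_rank() in (0, 1):
    # only member ranks may communicate over the new group
    dist.all_reduce(torch.ones(1).cuda(), group=tp_group)

pg_manager.destroy_all()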