mirror of https://github.com/hpcaitech/ColossalAI
[booster] implemented the cluster module (#3191)
* [booster] implemented the cluster module * polish codepull/3199/head
parent
019a847432
commit
e3ad88fb48
|
@ -0,0 +1,5 @@
|
|||
from .device_mesh_manager import DeviceMeshManager
|
||||
from .dist_coordinator import DistCoordinator
|
||||
from .process_group_manager import ProcessGroupManager
|
||||
|
||||
__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager']
|
|
@ -0,0 +1,36 @@
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
|
||||
class DeviceMeshManager:
|
||||
"""
|
||||
Device mesh manager is responsible for creating and managing device meshes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.device_mesh_store = dict()
|
||||
|
||||
def create_device_mesh(self, name, *args, **kwargs) -> DeviceMesh:
|
||||
"""
|
||||
Create a device mesh and store it in the manager.
|
||||
|
||||
Args:
|
||||
name (str): name of the device mesh
|
||||
*args: args for DeviceMesh
|
||||
**kwargs: kwargs for DeviceMesh
|
||||
"""
|
||||
# TODO(Yuliang): replace *args, **kwargs with explicit arguments
|
||||
if name not in self.device_mesh_store:
|
||||
device_mesh = DeviceMesh(*args, **kwargs)
|
||||
self.device_mesh_store[name] = device_mesh
|
||||
return device_mesh
|
||||
else:
|
||||
raise ValueError(f'Device mesh {name} already exists.')
|
||||
|
||||
def get(self, name: str) -> DeviceMesh:
|
||||
pass
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
def destroy_all(self):
|
||||
pass
|
|
@ -0,0 +1,158 @@
|
|||
import os
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
|
||||
|
||||
class DistCoordinator(metaclass=SingletonMeta):
|
||||
"""
|
||||
This class is used to coordinate distributed training. It is a singleton class, which means that there is only one instance of this
|
||||
class in the whole program.
|
||||
|
||||
There are some terms that are used in this class:
|
||||
- rank: the rank of the current process
|
||||
- world size: the total number of processes
|
||||
- local rank: the rank of the current process on the current node
|
||||
- master: the process with rank 0
|
||||
- node master: the process with local rank 0 on the current node
|
||||
|
||||
Example:
|
||||
>>> from colossalai.cluster.dist_coordinator import DistCoordinator
|
||||
>>> coordinator = DistCoordinator()
|
||||
>>>
|
||||
>>> if coordinator.is_master():
|
||||
>>> do_something()
|
||||
>>>
|
||||
>>> coordinator.print_on_master('hello world')
|
||||
|
||||
Attributes:
|
||||
rank (int): the rank of the current process
|
||||
world_size (int): the total number of processes
|
||||
local_rank (int): the rank of the current process on the current node
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
assert dist.is_initialized(
|
||||
), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.'
|
||||
self._rank = dist.get_rank()
|
||||
self._world_size = dist.get_world_size()
|
||||
# this is often passed by launchers such as torchrun
|
||||
self._local_rank = os.environ.get('LOCAL_RANK', -1)
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
return self._rank
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return self._world_size
|
||||
|
||||
@property
|
||||
def local_rank(self) -> int:
|
||||
return self._local_rank
|
||||
|
||||
def _assert_local_rank_set(self):
|
||||
"""
|
||||
Assert that the local rank is set. This is often passed by launchers such as torchrun.
|
||||
"""
|
||||
assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.'
|
||||
|
||||
def is_master(self, process_group: ProcessGroup = None) -> bool:
|
||||
"""
|
||||
Check if the current process is the master process (rank is 0). It can accept a sub process group to check the rank 0 with respect to the process.
|
||||
|
||||
Args:
|
||||
process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.
|
||||
|
||||
Returns:
|
||||
bool: True if the current process is the master process, False otherwise
|
||||
"""
|
||||
rank = dist.get_rank(group=process_group)
|
||||
return rank == 0
|
||||
|
||||
def is_node_master(self) -> bool:
|
||||
"""
|
||||
Check if the current process is the master process on the current node (local rank is 0).
|
||||
|
||||
Returns:
|
||||
bool: True if the current process is the master process on the current node, False otherwise
|
||||
"""
|
||||
self._assert_local_rank_set()
|
||||
return self.local_rank == 0
|
||||
|
||||
def is_last_process(self, process_group: ProcessGroup = None) -> bool:
|
||||
"""
|
||||
Check if the current process is the last process (rank is world size - 1). It can accept a sub process group to check the last rank with respect to the process.
|
||||
|
||||
Args:
|
||||
process_group (ProcessGroup, optional): process group to use for the last rank check. Defaults to None, which refers to the default process group.
|
||||
|
||||
Returns:
|
||||
bool: True if the current process is the last process, False otherwise
|
||||
"""
|
||||
rank = dist.get_rank(group=process_group)
|
||||
world_size = dist.get_world_size(group=process_group)
|
||||
return rank == world_size - 1
|
||||
|
||||
def print_on_master(self, msg: str, process_group: ProcessGroup = None):
|
||||
"""
|
||||
Print message only from rank 0.
|
||||
|
||||
Args:
|
||||
msg (str): message to print
|
||||
process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.
|
||||
"""
|
||||
rank = dist.get_rank(group=process_group)
|
||||
if rank == 0:
|
||||
print(msg)
|
||||
|
||||
def print_on_node_master(self, msg: str):
|
||||
"""
|
||||
Print message only from local rank 0. Local rank 0 refers to the 0th process running the current node.
|
||||
|
||||
Args:
|
||||
msg (str): message to print
|
||||
"""
|
||||
self._assert_local_rank_set()
|
||||
if self.local_rank == 0:
|
||||
print(msg)
|
||||
|
||||
@contextmanager
|
||||
def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup = None):
|
||||
"""
|
||||
This context manager is used to allow one process to execute while blocking all
|
||||
other processes in the same process group. This is often useful when downloading is required
|
||||
as we only want to download in one process to prevent file corruption.
|
||||
|
||||
Example:
|
||||
>>> from colossalai.cluster import DistCoordinator
|
||||
>>> dist_coordinator = DistCoordinator()
|
||||
>>> with dist_coordinator.priority_execution():
|
||||
>>> dataset = CIFAR10(root='./data', download=True)
|
||||
|
||||
Args:
|
||||
executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
|
||||
process_group (ProcessGroup, optional): process group to use for the executor rank check. Defaults to None, which refers to the default process group.
|
||||
"""
|
||||
rank = dist.get_rank(group=process_group)
|
||||
should_block = rank != executor_rank
|
||||
|
||||
if should_block:
|
||||
dist.barrier(group=process_group)
|
||||
|
||||
yield
|
||||
|
||||
if not should_block:
|
||||
dist.barrier(group=process_group)
|
||||
|
||||
def destroy(self, process_group: ProcessGroup = None):
|
||||
"""
|
||||
Destroy the distributed process group.
|
||||
|
||||
Args:
|
||||
process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group.
|
||||
"""
|
||||
dist.destroy_process_group(process_group)
|
|
@ -0,0 +1,75 @@
|
|||
from typing import List
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
class ProcessGroupManager:
|
||||
"""
|
||||
ProcessGroupManager is used to manage the process groups in the cluster.
|
||||
|
||||
There are some terms used in this class:
|
||||
- pg: the short name for process group
|
||||
- pg_name: the name of the process group
|
||||
- pg_size: the world size of the process group
|
||||
- rank: the rank of the current process in the process group
|
||||
- world_size: the total number of processes in the process group
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.pg_store = dict()
|
||||
|
||||
def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
|
||||
"""
|
||||
Get a process group by name. If the process group does not exist, it will be created.
|
||||
|
||||
Args:
|
||||
name (str): name of the process group
|
||||
ranks (List[int]): ranks of the process group
|
||||
backend (str, optional): backend of the process group. Defaults to 'nccl'.
|
||||
|
||||
Returns:
|
||||
ProcessGroup: the process group
|
||||
"""
|
||||
if name not in self.pg_store:
|
||||
pg = dist.new_group(ranks=ranks, backend=backend)
|
||||
self.pg_store[name] = pg
|
||||
return pg
|
||||
else:
|
||||
raise ValueError(f'Process group {name} already exists.')
|
||||
|
||||
def get(self, name: str) -> ProcessGroup:
|
||||
"""
|
||||
Get a process group by name.
|
||||
|
||||
Args:
|
||||
name (str): name of the process group
|
||||
|
||||
Returns:
|
||||
ProcessGroup: the process group
|
||||
"""
|
||||
if name in self.pg_store:
|
||||
return self.pg_store[name]
|
||||
else:
|
||||
raise ValueError(f'Process group {name} does not exist.')
|
||||
|
||||
def destroy(self, name: str) -> None:
|
||||
"""
|
||||
Destroy a process group by name.
|
||||
|
||||
Args:
|
||||
name (str): name of the process group
|
||||
"""
|
||||
if name in self.pg_store:
|
||||
dist.destroy_process_group(self.pg_store[name])
|
||||
del self.pg_store[name]
|
||||
else:
|
||||
raise ValueError(f'Process group {name} does not exist.')
|
||||
|
||||
def destroy_all(self) -> None:
|
||||
"""
|
||||
Destroy all process groups.
|
||||
"""
|
||||
for name in self.pg_store:
|
||||
dist.destroy_process_group(self.pg_store[name])
|
||||
self.pg_store.clear()
|
Loading…
Reference in New Issue