From e3ad88fb482fdd95241a1f74866559b83ab4f56b Mon Sep 17 00:00:00 2001
From: Frank Lee <somerlee.9@gmail.com>
Date: Wed, 22 Mar 2023 14:11:54 +0800
Subject: [PATCH] [booster] implemented the cluster module (#3191)

* [booster] implemented the cluster module

* polish code
---
 colossalai/cluster/__init__.py              |   5 +
 colossalai/cluster/device_mesh_manager.py   |  36 +++++
 colossalai/cluster/dist_coordinator.py      | 158 ++++++++++++++++++++
 colossalai/cluster/process_group_manager.py |  75 ++++++++++
 4 files changed, 274 insertions(+)
 create mode 100644 colossalai/cluster/__init__.py
 create mode 100644 colossalai/cluster/device_mesh_manager.py
 create mode 100644 colossalai/cluster/dist_coordinator.py
 create mode 100644 colossalai/cluster/process_group_manager.py

diff --git a/colossalai/cluster/__init__.py b/colossalai/cluster/__init__.py
new file mode 100644
index 000000000..2fbdfd3cc
--- /dev/null
+++ b/colossalai/cluster/__init__.py
@@ -0,0 +1,5 @@
+from .device_mesh_manager import DeviceMeshManager
+from .dist_coordinator import DistCoordinator
+from .process_group_manager import ProcessGroupManager
+
+__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager']
diff --git a/colossalai/cluster/device_mesh_manager.py b/colossalai/cluster/device_mesh_manager.py
new file mode 100644
index 000000000..744799182
--- /dev/null
+++ b/colossalai/cluster/device_mesh_manager.py
@@ -0,0 +1,36 @@
+from colossalai.device.device_mesh import DeviceMesh
+
+
+class DeviceMeshManager:
+    """
+    Device mesh manager is responsible for creating and managing device meshes.
+    """
+
+    def __init__(self):
+        self.device_mesh_store = dict()
+
+    def create_device_mesh(self, name, *args, **kwargs) -> DeviceMesh:
+        """
+        Create a device mesh and store it in the manager.
+
+        Args:
+            name (str): name of the device mesh
+            *args: args for DeviceMesh
+            **kwargs: kwargs for DeviceMesh
+        """
+        # TODO(Yuliang): replace *args, **kwargs with explicit arguments
+        if name not in self.device_mesh_store:
+            device_mesh = DeviceMesh(*args, **kwargs)
+            self.device_mesh_store[name] = device_mesh
+            return device_mesh
+        else:
+            raise ValueError(f'Device mesh {name} already exists.')
+
+    def get(self, name: str) -> DeviceMesh:
+        pass    # TODO: not yet implemented
+
+    def destroy(self):
+        pass    # TODO: not yet implemented
+
+    def destroy_all(self):
+        pass    # TODO: not yet implemented
diff --git a/colossalai/cluster/dist_coordinator.py b/colossalai/cluster/dist_coordinator.py
new file mode 100644
index 000000000..6b48faf5b
--- /dev/null
+++ b/colossalai/cluster/dist_coordinator.py
@@ -0,0 +1,158 @@
+import os
+from contextlib import contextmanager
+
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from colossalai.context.singleton_meta import SingletonMeta
+
+
+class DistCoordinator(metaclass=SingletonMeta):
+    """
+    This class is used to coordinate distributed training. It is a singleton class, which means that there is only one instance of this
+    class in the whole program.
+
+    There are some terms that are used in this class:
+        - rank: the rank of the current process
+        - world size: the total number of processes
+        - local rank: the rank of the current process on the current node
+        - master: the process with rank 0
+        - node master: the process with local rank 0 on the current node
+
+    Example:
+        >>> from colossalai.cluster.dist_coordinator import DistCoordinator
+        >>> coordinator = DistCoordinator()
+        >>>
+        >>> if coordinator.is_master():
+        >>>     do_something()
+        >>>
+        >>> coordinator.print_on_master('hello world')
+
+    Attributes:
+        rank (int): the rank of the current process
+        world_size (int): the total number of processes
+        local_rank (int): the rank of the current process on the current node
+    """
+
+    def __init__(self):
+        assert dist.is_initialized(
+        ), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.'
+        self._rank = dist.get_rank()
+        self._world_size = dist.get_world_size()
+        # this is often passed by launchers such as torchrun
+        self._local_rank = int(os.environ.get('LOCAL_RANK', -1))
+
+    @property
+    def rank(self) -> int:
+        return self._rank
+
+    @property
+    def world_size(self) -> int:
+        return self._world_size
+
+    @property
+    def local_rank(self) -> int:
+        return self._local_rank
+
+    def _assert_local_rank_set(self):
+        """
+        Assert that the local rank is set. The LOCAL_RANK environment variable is often set by launchers such as torchrun.
+        """
+        assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.'
+
+    def is_master(self, process_group: ProcessGroup = None) -> bool:
+        """
+        Check if the current process is the master process (rank is 0). It can accept a sub process group to check rank 0 with respect to that group.
+
+        Args:
+            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.
+
+        Returns:
+            bool: True if the current process is the master process, False otherwise
+        """
+        rank = dist.get_rank(group=process_group)
+        return rank == 0
+
+    def is_node_master(self) -> bool:
+        """
+        Check if the current process is the master process on the current node (local rank is 0).
+
+        Returns:
+            bool: True if the current process is the master process on the current node, False otherwise
+        """
+        self._assert_local_rank_set()
+        return self.local_rank == 0
+
+    def is_last_process(self, process_group: ProcessGroup = None) -> bool:
+        """
+        Check if the current process is the last process (rank is world size - 1). It can accept a sub process group to check the last rank with respect to that group.
+
+        Args:
+            process_group (ProcessGroup, optional): process group to use for the last rank check. Defaults to None, which refers to the default process group.
+
+        Returns:
+            bool: True if the current process is the last process, False otherwise
+        """
+        rank = dist.get_rank(group=process_group)
+        world_size = dist.get_world_size(group=process_group)
+        return rank == world_size - 1
+
+    def print_on_master(self, msg: str, process_group: ProcessGroup = None):
+        """
+        Print message only from rank 0.
+
+        Args:
+            msg (str): message to print
+            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.
+        """
+        rank = dist.get_rank(group=process_group)
+        if rank == 0:
+            print(msg)
+
+    def print_on_node_master(self, msg: str):
+        """
+        Print message only from local rank 0. Local rank 0 refers to the 0th process running on the current node.
+
+        Args:
+            msg (str): message to print
+        """
+        self._assert_local_rank_set()
+        if self.local_rank == 0:
+            print(msg)
+
+    @contextmanager
+    def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup = None):
+        """
+        This context manager is used to allow one process to execute while blocking all
+        other processes in the same process group. This is often useful when downloading is required
+        as we only want to download in one process to prevent file corruption.
+
+        Example:
+            >>> from colossalai.cluster import DistCoordinator
+            >>> dist_coordinator = DistCoordinator()
+            >>> with dist_coordinator.priority_execution():
+            >>>     dataset = CIFAR10(root='./data', download=True)
+
+        Args:
+            executor_rank (int): the rank of the process that executes without blocking; all other processes will be blocked
+            process_group (ProcessGroup, optional): process group to use for the executor rank check. Defaults to None, which refers to the default process group.
+        """
+        rank = dist.get_rank(group=process_group)
+        should_block = rank != executor_rank
+
+        if should_block:
+            dist.barrier(group=process_group)
+
+        yield
+
+        if not should_block:
+            dist.barrier(group=process_group)
+
+    def destroy(self, process_group: ProcessGroup = None):
+        """
+        Destroy the distributed process group.
+
+        Args:
+            process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group.
+        """
+        dist.destroy_process_group(process_group)
diff --git a/colossalai/cluster/process_group_manager.py b/colossalai/cluster/process_group_manager.py
new file mode 100644
index 000000000..e52661846
--- /dev/null
+++ b/colossalai/cluster/process_group_manager.py
@@ -0,0 +1,75 @@
+from typing import List
+
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+
+class ProcessGroupManager:
+    """
+    ProcessGroupManager is used to manage the process groups in the cluster.
+
+    There are some terms used in this class:
+        - pg: the short name for process group
+        - pg_name: the name of the process group
+        - pg_size: the world size of the process group
+        - rank: the rank of the current process in the process group
+        - world_size: the total number of processes in the process group
+    """
+
+    def __init__(self):
+        self.pg_store = dict()
+
+    def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
+        """
+        Create a process group with the given name. An error is raised if a process group with this name already exists.
+
+        Args:
+            name (str): name of the process group
+            ranks (List[int]): ranks of the process group
+            backend (str, optional): backend of the process group. Defaults to 'nccl'.
+
+        Returns:
+            ProcessGroup: the process group
+        """
+        if name not in self.pg_store:
+            pg = dist.new_group(ranks=ranks, backend=backend)
+            self.pg_store[name] = pg
+            return pg
+        else:
+            raise ValueError(f'Process group {name} already exists.')
+
+    def get(self, name: str) -> ProcessGroup:
+        """
+        Get a process group by name.
+
+        Args:
+            name (str): name of the process group
+
+        Returns:
+            ProcessGroup: the process group
+        """
+        if name in self.pg_store:
+            return self.pg_store[name]
+        else:
+            raise ValueError(f'Process group {name} does not exist.')
+
+    def destroy(self, name: str) -> None:
+        """
+        Destroy a process group by name.
+
+        Args:
+            name (str): name of the process group
+        """
+        if name in self.pg_store:
+            dist.destroy_process_group(self.pg_store[name])
+            del self.pg_store[name]
+        else:
+            raise ValueError(f'Process group {name} does not exist.')
+
+    def destroy_all(self) -> None:
+        """
+        Destroy all process groups.
+        """
+        for name in self.pg_store:
+            dist.destroy_process_group(self.pg_store[name])
+        self.pg_store.clear()
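
Usage sketch: the snippet below is a minimal example of how the new cluster module could be driven end to end. It assumes the patch is applied, that the script is launched with `torchrun` (so the rendezvous environment variables are set on every process), and that the CPU-friendly `gloo` backend is acceptable; the group name 'all', the file name `demo.py` and the printed messages are illustrative only.

    import torch.distributed as dist

    from colossalai.cluster import DistCoordinator, ProcessGroupManager

    # Launch with e.g. `torchrun --nproc_per_node=2 demo.py`; the coordinator
    # requires an initialized default process group, created here with gloo so
    # the sketch also runs on CPU-only machines.
    dist.init_process_group(backend='gloo')

    coordinator = DistCoordinator()
    coordinator.print_on_master(f'world size: {coordinator.world_size}')

    # The executor rank (rank 0 by default) runs the body first while the other
    # ranks wait at a barrier, which prevents e.g. concurrent downloads from
    # corrupting files.
    with coordinator.priority_execution():
        pass  # e.g. dataset = CIFAR10(root='./data', download=True)

    # Register a named process group covering all ranks and look it up by name.
    pg_manager = ProcessGroupManager()
    pg = pg_manager.create_process_group(name='all',
                                         ranks=list(range(coordinator.world_size)),
                                         backend='gloo')
    assert pg_manager.get('all') is pg

    if coordinator.is_master(process_group=pg):
        print('rank 0 of the "all" process group')

    pg_manager.destroy_all()
    coordinator.destroy()

Note that priority_execution() only orders execution: every rank still runs the body, so the guarded work should be idempotent (for example, a download that is skipped when the file already exists).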