From e3ad88fb482fdd95241a1f74866559b83ab4f56b Mon Sep 17 00:00:00 2001
From: Frank Lee <somerlee.9@gmail.com>
Date: Wed, 22 Mar 2023 14:11:54 +0800
Subject: [PATCH] [booster] implemented the cluster module (#3191)

* [booster] implemented the cluster module

* polish code
---
 colossalai/cluster/__init__.py              |   5 +
 colossalai/cluster/device_mesh_manager.py   |  36 +++++
 colossalai/cluster/dist_coordinator.py      | 158 ++++++++++++++++++++
 colossalai/cluster/process_group_manager.py |  75 ++++++++++
 4 files changed, 274 insertions(+)
 create mode 100644 colossalai/cluster/__init__.py
 create mode 100644 colossalai/cluster/device_mesh_manager.py
 create mode 100644 colossalai/cluster/dist_coordinator.py
 create mode 100644 colossalai/cluster/process_group_manager.py

diff --git a/colossalai/cluster/__init__.py b/colossalai/cluster/__init__.py
new file mode 100644
index 000000000..2fbdfd3cc
--- /dev/null
+++ b/colossalai/cluster/__init__.py
@@ -0,0 +1,5 @@
+from .device_mesh_manager import DeviceMeshManager
+from .dist_coordinator import DistCoordinator
+from .process_group_manager import ProcessGroupManager
+
+__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager']
diff --git a/colossalai/cluster/device_mesh_manager.py b/colossalai/cluster/device_mesh_manager.py
new file mode 100644
index 000000000..744799182
--- /dev/null
+++ b/colossalai/cluster/device_mesh_manager.py
@@ -0,0 +1,36 @@
+from colossalai.device.device_mesh import DeviceMesh
+
+
+class DeviceMeshManager:
+    """
+    Device mesh manager is responsible for creating and managing device meshes.
+    """
+
+    def __init__(self):
+        self.device_mesh_store = dict()
+
+    def create_device_mesh(self, name, *args, **kwargs) -> DeviceMesh:
+        """
+        Create a device mesh and store it in the manager.
+
+        Args:
+            name (str): name of the device mesh
+            *args: args for DeviceMesh
+            **kwargs: kwargs for DeviceMesh
+        """
+        # TODO(Yuliang): replace *args, **kwargs with explicit arguments
+        if name not in self.device_mesh_store:
+            device_mesh = DeviceMesh(*args, **kwargs)
+            self.device_mesh_store[name] = device_mesh
+            return device_mesh
+        else:
+            raise ValueError(f'Device mesh {name} already exists.')
+
+    def get(self, name: str) -> DeviceMesh:
+        pass
+
+    def destroy(self):
+        pass
+
+    def destroy_all(self):
+        pass
diff --git a/colossalai/cluster/dist_coordinator.py b/colossalai/cluster/dist_coordinator.py
new file mode 100644
index 000000000..6b48faf5b
--- /dev/null
+++ b/colossalai/cluster/dist_coordinator.py
@@ -0,0 +1,158 @@
+import os
+from contextlib import contextmanager
+
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from colossalai.context.singleton_meta import SingletonMeta
+
+
+class DistCoordinator(metaclass=SingletonMeta):
+    """
+    This class is used to coordinate distributed training. It is a singleton class, which means that there is only one instance of this
+    class in the whole program.
+
+    There are some terms that are used in this class:
+        - rank: the rank of the current process
+        - world size: the total number of processes
+        - local rank: the rank of the current process on the current node
+        - master: the process with rank 0
+        - node master: the process with local rank 0 on the current node
+
+    Example:
+        >>> from colossalai.cluster.dist_coordinator import DistCoordinator
+        >>> coordinator = DistCoordinator()
+        >>>
+        >>> if coordinator.is_master():
+        >>>     do_something()
+        >>>
+        >>> coordinator.print_on_master('hello world')
+
+    Attributes:
+        rank (int): the rank of the current process
+        world_size (int): the total number of processes
+        local_rank (int): the rank of the current process on the current node
+    """
+
+    def __init__(self):
+        assert dist.is_initialized(
+        ), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.'
+        self._rank = dist.get_rank()
+        self._world_size = dist.get_world_size()
+        # this is often passed by launchers such as torchrun
+        self._local_rank = os.environ.get('LOCAL_RANK', -1)
+
+    @property
+    def rank(self) -> int:
+        return self._rank
+
+    @property
+    def world_size(self) -> int:
+        return self._world_size
+
+    @property
+    def local_rank(self) -> int:
+        return self._local_rank
+
+    def _assert_local_rank_set(self):
+        """
+        Assert that the local rank is set. This is often passed by launchers such as torchrun.
+        """
+        assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.'
+
+    def is_master(self, process_group: ProcessGroup = None) -> bool:
+        """
+        Check if the current process is the master process (rank is 0). It can accept a sub process group to check the rank 0 with respect to the process.
+
+        Args:
+            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.
+
+        Returns:
+            bool: True if the current process is the master process, False otherwise
+        """
+        rank = dist.get_rank(group=process_group)
+        return rank == 0
+
+    def is_node_master(self) -> bool:
+        """
+        Check if the current process is the master process on the current node (local rank is 0).
+
+        Returns:
+            bool: True if the current process is the master process on the current node, False otherwise
+        """
+        self._assert_local_rank_set()
+        return self.local_rank == 0
+
+    def is_last_process(self, process_group: ProcessGroup = None) -> bool:
+        """
+        Check if the current process is the last process (rank is world size - 1). It can accept a sub process group to check the last rank with respect to the process.
+
+        Args:
+            process_group (ProcessGroup, optional): process group to use for the last rank check. Defaults to None, which refers to the default process group.
+
+        Returns:
+            bool: True if the current process is the last process, False otherwise
+        """
+        rank = dist.get_rank(group=process_group)
+        world_size = dist.get_world_size(group=process_group)
+        return rank == world_size - 1
+
+    def print_on_master(self, msg: str, process_group: ProcessGroup = None):
+        """
+        Print message only from rank 0.
+
+        Args:
+            msg (str): message to print
+            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.
+        """
+        rank = dist.get_rank(group=process_group)
+        if rank == 0:
+            print(msg)
+
+    def print_on_node_master(self, msg: str):
+        """
+        Print message only from local rank 0. Local rank 0 refers to the 0th process running the current node.
+
+        Args:
+            msg (str): message to print
+        """
+        self._assert_local_rank_set()
+        if self.local_rank == 0:
+            print(msg)
+
+    @contextmanager
+    def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup = None):
+        """
+        This context manager is used to allow one process to execute while blocking all
+        other processes in the same process group. This is often useful when downloading is required
+        as we only want to download in one process to prevent file corruption.
+
+        Example:
+            >>> from colossalai.cluster import DistCoordinator
+            >>> dist_coordinator = DistCoordinator()
+            >>> with dist_coordinator.priority_execution():
+            >>>     dataset = CIFAR10(root='./data', download=True)
+
+        Args:
+            executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
+            process_group (ProcessGroup, optional): process group to use for the executor rank check. Defaults to None, which refers to the default process group.
+        """
+        rank = dist.get_rank(group=process_group)
+        should_block = rank != executor_rank
+
+        if should_block:
+            dist.barrier(group=process_group)
+
+        yield
+
+        if not should_block:
+            dist.barrier(group=process_group)
+
+    def destroy(self, process_group: ProcessGroup = None):
+        """
+        Destroy the distributed process group.
+
+        Args:
+            process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group.
+        """
+        dist.destroy_process_group(process_group)
diff --git a/colossalai/cluster/process_group_manager.py b/colossalai/cluster/process_group_manager.py
new file mode 100644
index 000000000..e52661846
--- /dev/null
+++ b/colossalai/cluster/process_group_manager.py
@@ -0,0 +1,75 @@
+from typing import List
+
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+
+class ProcessGroupManager:
+    """
+    ProcessGroupManager is used to manage the process groups in the cluster.
+
+    There are some terms used in this class:
+        - pg: the short name for process group
+        - pg_name: the name of the process group
+        - pg_size: the world size of the process group
+        - rank: the rank of the current process in the process group
+        - world_size: the total number of processes in the process group
+    """
+
+    def __init__(self):
+        self.pg_store = dict()
+
+    def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
+        """
+        Get a process group by name. If the process group does not exist, it will be created.
+
+        Args:
+            name (str): name of the process group
+            ranks (List[int]): ranks of the process group
+            backend (str, optional): backend of the process group. Defaults to 'nccl'.
+
+        Returns:
+            ProcessGroup: the process group
+        """
+        if name not in self.pg_store:
+            pg = dist.new_group(ranks=ranks, backend=backend)
+            self.pg_store[name] = pg
+            return pg
+        else:
+            raise ValueError(f'Process group {name} already exists.')
+
+    def get(self, name: str) -> ProcessGroup:
+        """
+        Get a process group by name.
+
+        Args:
+            name (str): name of the process group
+
+        Returns:
+            ProcessGroup: the process group
+        """
+        if name in self.pg_store:
+            return self.pg_store[name]
+        else:
+            raise ValueError(f'Process group {name} does not exist.')
+
+    def destroy(self, name: str) -> None:
+        """
+        Destroy a process group by name.
+
+        Args:
+            name (str): name of the process group
+        """
+        if name in self.pg_store:
+            dist.destroy_process_group(self.pg_store[name])
+            del self.pg_store[name]
+        else:
+            raise ValueError(f'Process group {name} does not exist.')
+
+    def destroy_all(self) -> None:
+        """
+        Destroy all process groups.
+        """
+        for name in self.pg_store:
+            dist.destroy_process_group(self.pg_store[name])
+        self.pg_store.clear()