import functools
import os
from contextlib import contextmanager

import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.context.singleton_meta import SingletonMeta


class DistCoordinator(metaclass=SingletonMeta):
    """
    This class is used to coordinate distributed training. It is a singleton class,
    which means that there is only one instance of this class in the whole program.

    There are some terms that are used in this class:
        - rank: the rank of the current process
        - world size: the total number of processes
        - local rank: the rank of the current process on the current node
        - master: the process with rank 0
        - node master: the process with local rank 0 on the current node

    ```python
    from colossalai.cluster.dist_coordinator import DistCoordinator
    coordinator = DistCoordinator()

    if coordinator.is_master():
        do_something()

    coordinator.print_on_master('hello world')
    ```

    Attributes:
        rank (int): the rank of the current process
        world_size (int): the total number of processes
        local_rank (int): the rank of the current process on the current node
    """

    def __init__(self):
        assert (
            dist.is_initialized()
        ), "Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first."
        self._rank = dist.get_rank()
        self._world_size = dist.get_world_size()
        # this is often passed by launchers such as torchrun
        self._local_rank = int(os.environ.get("LOCAL_RANK", -1))

    @property
    def rank(self) -> int:
        return self._rank

    @property
    def world_size(self) -> int:
        return self._world_size

    @property
    def local_rank(self) -> int:
        return self._local_rank

    def _assert_local_rank_set(self):
        """
        Assert that the local rank is set. This is often passed by launchers such as torchrun.
        """
        assert (
            self.local_rank >= 0
        ), "The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process."

    def is_master(self, process_group: ProcessGroup = None) -> bool:
        """
        Check if the current process is the master process (rank 0). A sub process group
        can be passed in to check for rank 0 with respect to that group.

        Args:
            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.

        Returns:
            bool: True if the current process is the master process, False otherwise
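
        A minimal usage sketch (assumes the default process group has already been initialized, e.g. via `colossalai.launch`):

        ```python
        from colossalai.cluster import DistCoordinator

        coordinator = DistCoordinator()
        if coordinator.is_master():
            print("this only runs on global rank 0")
        ```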
        """
        rank = dist.get_rank(group=process_group)
        return rank == 0

    def is_node_master(self) -> bool:
        """
        Check if the current process is the master process on the current node (local rank is 0).

        Returns:
            bool: True if the current process is the master process on the current node, False otherwise
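
        A minimal usage sketch (assumes LOCAL_RANK is set by the launcher, e.g. torchrun):

        ```python
        from colossalai.cluster import DistCoordinator

        coordinator = DistCoordinator()
        if coordinator.is_node_master():
            print("this runs once per node, on local rank 0")
        ```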
        """
        self._assert_local_rank_set()
        return self.local_rank == 0

    def is_last_process(self, process_group: ProcessGroup = None) -> bool:
        """
        Check if the current process is the last process (rank is world size - 1). A sub process group
        can be passed in to check for the last rank with respect to that group.

        Args:
            process_group (ProcessGroup, optional): process group to use for the last rank check. Defaults to None, which refers to the default process group.

        Returns:
            bool: True if the current process is the last process, False otherwise
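
        A minimal usage sketch (assumes the default process group has already been initialized):

        ```python
        from colossalai.cluster import DistCoordinator

        coordinator = DistCoordinator()
        if coordinator.is_last_process():
            print("this only runs on the highest-ranked process")
        ```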
        """
        rank = dist.get_rank(group=process_group)
        world_size = dist.get_world_size(group=process_group)
        return rank == world_size - 1

    def print_on_master(self, msg: str, process_group: ProcessGroup = None):
        """
        Print a message only from rank 0 of the given process group.

        Args:
            msg (str): message to print
            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.
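
        A minimal usage sketch (assumes the default process group has already been initialized):

        ```python
        from colossalai.cluster import DistCoordinator

        coordinator = DistCoordinator()
        coordinator.print_on_master("printed exactly once, by global rank 0")
        ```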
        """
        rank = dist.get_rank(group=process_group)
        if rank == 0:
            print(msg)

    def print_on_node_master(self, msg: str):
        """
        Print a message only from local rank 0. Local rank 0 refers to the 0th process running on the current node.

        Args:
            msg (str): message to print
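
        A minimal usage sketch (assumes LOCAL_RANK is set by the launcher, e.g. torchrun):

        ```python
        from colossalai.cluster import DistCoordinator

        coordinator = DistCoordinator()
        coordinator.print_on_node_master("printed once per node, by local rank 0")
        ```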
        """
        self._assert_local_rank_set()
        if self.local_rank == 0:
            print(msg)

    @contextmanager
    def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup = None):
        """
        This context manager allows one process to execute while blocking all other processes
        in the same process group. This is often useful for downloads, as we only want to
        download in one process to prevent file corruption.

        ```python
        from colossalai.cluster import DistCoordinator
        dist_coordinator = DistCoordinator()
        with dist_coordinator.priority_execution():
            dataset = CIFAR10(root='./data', download=True)
        ```

        Args:
            executor_rank (int): the rank of the process that executes without blocking; all other processes are blocked
            process_group (ProcessGroup, optional): process group to use for the executor rank check. Defaults to None, which refers to the default process group.
        """
        rank = dist.get_rank(group=process_group)
        should_block = rank != executor_rank

        # non-executor ranks wait at a barrier until the executor finishes the body
        if should_block:
            self.block_all(process_group)

        yield

        # the executor enters the barrier last, releasing the waiting ranks
        if not should_block:
            self.block_all(process_group)

    def destroy(self, process_group: ProcessGroup = None):
        """
        Destroy the distributed process group.

        Args:
            process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group.
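
        A minimal usage sketch (typically called once at the very end of the program):

        ```python
        from colossalai.cluster import DistCoordinator

        coordinator = DistCoordinator()
        # ... training and evaluation ...
        coordinator.destroy()  # tears down the default process group
        ```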
        """
        dist.destroy_process_group(process_group)

    def block_all(self, process_group: ProcessGroup = None):
        """
        Block all processes in the process group at a barrier until every process has reached it.

        Args:
            process_group (ProcessGroup, optional): process group to block. Defaults to None, which refers to the default process group.
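
        A minimal usage sketch (assumes the default process group has already been initialized):

        ```python
        from colossalai.cluster import DistCoordinator

        coordinator = DistCoordinator()
        coordinator.block_all()  # every rank waits here until all ranks have arrived
        ```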
        """
        dist.barrier(group=process_group)

    def on_master_only(self, process_group: ProcessGroup = None):
        """
        A decorator that executes the wrapped function only on the master process (rank 0).

        ```python
        from colossalai.cluster import DistCoordinator
        dist_coordinator = DistCoordinator()

        @dist_coordinator.on_master_only()
        def print_on_master(msg):
            print(msg)
        ```

        Args:
            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.
        """
        is_master = self.is_master(process_group)

        # define an inner decorator that calls the wrapped function only on the master process
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if is_master:
                    return func(*args, **kwargs)

            return wrapper

        return decorator