Making large AI models cheaper, faster and more accessible
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

201 lines
7.1 KiB

import functools
import os
from contextlib import contextmanager
from typing import Optional

import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.context.singleton_meta import SingletonMeta
class DistCoordinator(metaclass=SingletonMeta):
    """
    This class is used to coordinate distributed training. It is a singleton class, which means that there is only one
    instance of this class in the whole program.

    There are some terms that are used in this class:
        - rank: the rank of the current process
        - world size: the total number of processes
        - local rank: the rank of the current process on the current node
        - master: the process with rank 0
        - node master: the process with local rank 0 on the current node

    ```python
    from colossalai.cluster.dist_coordinator import DistCoordinator
    coordinator = DistCoordinator()

    if coordinator.is_master():
        do_something()

    coordinator.print_on_master('hello world')
    ```

    Attributes:
        rank (int): the rank of the current process in the default process group
        world_size (int): the total number of processes in the default process group
        local_rank (int): the rank of the current process on the current node, or -1 when the
            ``LOCAL_RANK`` environment variable is not set
    """

    def __init__(self):
        # NOTE: `assert` is stripped under `python -O`; it is kept so callers still see an AssertionError.
        assert (
            dist.is_initialized()
        ), "Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first."
        # Snapshot of the default process group's rank / world size, taken once at construction.
        self._rank = dist.get_rank()
        self._world_size = dist.get_world_size()
        # this is often passed by launchers such as torchrun; -1 means "local rank unknown"
        self._local_rank = int(os.environ.get("LOCAL_RANK", -1))

    @property
    def rank(self) -> int:
        """Rank of the current process in the default process group."""
        return self._rank

    @property
    def world_size(self) -> int:
        """Total number of processes in the default process group."""
        return self._world_size

    @property
    def local_rank(self) -> int:
        """Rank of the current process on its node, or -1 if ``LOCAL_RANK`` was not set."""
        return self._local_rank

    def _assert_local_rank_set(self):
        """
        Assert that the local rank is set. This is often passed by launchers such as torchrun.

        Raises:
            AssertionError: if ``LOCAL_RANK`` was not present in the environment at construction time.
        """
        assert (
            self.local_rank >= 0
        ), "The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process."

    def is_master(self, process_group: Optional[ProcessGroup] = None) -> bool:
        """
        Check if the current process is the master process (rank is 0). It can accept a sub process group to check the
        rank 0 with respect to the process.

        Args:
            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which
                refers to the default process group.

        Returns:
            bool: True if the current process is the master process, False otherwise
        """
        rank = dist.get_rank(group=process_group)
        return rank == 0

    def is_node_master(self) -> bool:
        """
        Check if the current process is the master process on the current node (local rank is 0).

        Returns:
            bool: True if the current process is the master process on the current node, False otherwise

        Raises:
            AssertionError: if the local rank is unknown (``LOCAL_RANK`` not set).
        """
        self._assert_local_rank_set()
        return self.local_rank == 0

    def is_last_process(self, process_group: Optional[ProcessGroup] = None) -> bool:
        """
        Check if the current process is the last process (rank is world size - 1). It can accept a sub process group to
        check the last rank with respect to the process.

        Args:
            process_group (ProcessGroup, optional): process group to use for the last rank check. Defaults to None,
                which refers to the default process group.

        Returns:
            bool: True if the current process is the last process, False otherwise
        """
        rank = dist.get_rank(group=process_group)
        world_size = dist.get_world_size(group=process_group)
        return rank == world_size - 1

    def print_on_master(self, msg: str, process_group: Optional[ProcessGroup] = None):
        """
        Print message only from rank 0.

        Args:
            msg (str): message to print
            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which
                refers to the default process group.
        """
        # Reuse is_master so the "who is master" rule lives in one place.
        if self.is_master(process_group):
            print(msg)

    def print_on_node_master(self, msg: str):
        """
        Print message only from local rank 0. Local rank 0 refers to the 0th process running the current node.

        Args:
            msg (str): message to print

        Raises:
            AssertionError: if the local rank is unknown (``LOCAL_RANK`` not set).
        """
        # is_node_master performs the LOCAL_RANK check before comparing.
        if self.is_node_master():
            print(msg)

    @contextmanager
    def priority_execution(self, executor_rank: int = 0, process_group: Optional[ProcessGroup] = None):
        """
        This context manager is used to allow one process to execute while blocking all
        other processes in the same process group. This is often useful when downloading is required
        as we only want to download in one process to prevent file corruption.

        ```python
        from colossalai.cluster import DistCoordinator
        dist_coordinator = DistCoordinator()
        with dist_coordinator.priority_execution():
            dataset = CIFAR10(root='./data', download=True)
        ```

        Args:
            executor_rank (int): the process rank (relative to ``process_group``) to execute without blocking, all
                other processes will be blocked
            process_group (ProcessGroup, optional): process group to use for the executor rank check. Defaults to None,
                which refers to the default process group.
        """
        rank = dist.get_rank(group=process_group)
        should_block = rank != executor_rank

        # Non-executor ranks wait here until the executor reaches its closing barrier.
        if should_block:
            self.block_all(process_group)

        try:
            yield
        finally:
            # The executor must always enter the matching barrier, even if its body raised.
            # Otherwise every other rank stays stuck in dist.barrier until the collective times out.
            if not should_block:
                self.block_all(process_group)

    def destroy(self, process_group: Optional[ProcessGroup] = None):
        """
        Destroy the distributed process group.

        Args:
            process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the
                default process group.
        """
        dist.destroy_process_group(process_group)

    def block_all(self, process_group: Optional[ProcessGroup] = None):
        """
        Block all processes in the process group.

        Args:
            process_group (ProcessGroup, optional): process group to block. Defaults to None, which refers to the
                default process group.
        """
        dist.barrier(group=process_group)

    def on_master_only(self, process_group: Optional[ProcessGroup] = None):
        """
        A function wrapper that only executes the wrapped function on the master process (rank 0).

        ```python
        from colossalai.cluster import DistCoordinator
        dist_coordinator = DistCoordinator()

        @dist_coordinator.on_master_only()
        def print_on_master(msg):
            print(msg)
        ```

        Args:
            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which
                refers to the default process group.

        Returns:
            A decorator. The decorated function returns the wrapped function's result on the master
            process and None on every other process.
        """
        # Master status is evaluated once, when on_master_only() is called, not on every invocation.
        is_master = self.is_master(process_group)

        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if is_master:
                    return func(*args, **kwargs)
                # Non-master processes skip the call and implicitly return None.

            return wrapper

        return decorator