mirror of https://github.com/hpcaitech/ColossalAI
159 lines
6.0 KiB
Python
159 lines
6.0 KiB
Python
|
import os
|
||
|
from contextlib import contextmanager
|
||
|
|
||
|
import torch.distributed as dist
|
||
|
from torch.distributed import ProcessGroup
|
||
|
|
||
|
from colossalai.context.singleton_meta import SingletonMeta
|
||
|
|
||
|
|
||
|
class DistCoordinator(metaclass=SingletonMeta):
    """
    Coordinate distributed training across processes.

    The class uses ``SingletonMeta`` as its metaclass, so it is intended to have
    a single instance per process. The rank, world size and local rank are read
    once in ``__init__`` and cached on the instance.

    There are some terms that are used in this class:
        - rank: the rank of the current process
        - world size: the total number of processes
        - local rank: the rank of the current process on the current node
        - master: the process with rank 0
        - node master: the process with local rank 0 on the current node

    Example:
        >>> from colossalai.cluster.dist_coordinator import DistCoordinator
        >>> coordinator = DistCoordinator()
        >>>
        >>> if coordinator.is_master():
        >>>     do_something()
        >>>
        >>> coordinator.print_on_master('hello world')

    Attributes:
        rank (int): the rank of the current process
        world_size (int): the total number of processes
        local_rank (int): the rank of the current process on the current node,
            or -1 when the ``LOCAL_RANK`` environment variable is not set
    """

    def __init__(self):
        assert dist.is_initialized(
        ), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.'
        self._rank = dist.get_rank()
        self._world_size = dist.get_world_size()
        # LOCAL_RANK is often passed by launchers such as torchrun. Environment
        # variables are always strings, so convert explicitly: without int(),
        # the `>= 0` and `== 0` checks below would compare a str against an int.
        self._local_rank = int(os.environ.get('LOCAL_RANK', -1))

    @property
    def rank(self) -> int:
        """Rank of the current process in the default process group (cached at construction)."""
        return self._rank

    @property
    def world_size(self) -> int:
        """Number of processes in the default process group (cached at construction)."""
        return self._world_size

    @property
    def local_rank(self) -> int:
        """Value of the ``LOCAL_RANK`` environment variable as an int, or -1 if it was not set."""
        return self._local_rank

    def _assert_local_rank_set(self):
        """
        Assert that the local rank is set. This is often passed by launchers such as torchrun.

        Raises:
            AssertionError: if ``LOCAL_RANK`` was not set when the coordinator was created.
        """
        assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.'

    def is_master(self, process_group: ProcessGroup = None) -> bool:
        """
        Check if the current process is the master process (rank is 0). It can accept a sub process group to check the rank 0 with respect to the process.

        Args:
            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.

        Returns:
            bool: True if the current process is the master process, False otherwise
        """
        # Query the rank on every call (instead of using the cached self._rank)
        # because the answer depends on the process group that is passed in.
        return dist.get_rank(group=process_group) == 0

    def is_node_master(self) -> bool:
        """
        Check if the current process is the master process on the current node (local rank is 0).

        Returns:
            bool: True if the current process is the master process on the current node, False otherwise

        Raises:
            AssertionError: if ``LOCAL_RANK`` was not set when the coordinator was created.
        """
        self._assert_local_rank_set()
        return self.local_rank == 0

    def is_last_process(self, process_group: ProcessGroup = None) -> bool:
        """
        Check if the current process is the last process (rank is world size - 1). It can accept a sub process group to check the last rank with respect to the process.

        Args:
            process_group (ProcessGroup, optional): process group to use for the last rank check. Defaults to None, which refers to the default process group.

        Returns:
            bool: True if the current process is the last process, False otherwise
        """
        rank = dist.get_rank(group=process_group)
        world_size = dist.get_world_size(group=process_group)
        return rank == world_size - 1

    def print_on_master(self, msg: str, process_group: ProcessGroup = None):
        """
        Print message only from rank 0.

        Args:
            msg (str): message to print
            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.
        """
        if self.is_master(process_group=process_group):
            print(msg)

    def print_on_node_master(self, msg: str):
        """
        Print message only from local rank 0. Local rank 0 refers to the 0th process running the current node.

        Args:
            msg (str): message to print

        Raises:
            AssertionError: if ``LOCAL_RANK`` was not set when the coordinator was created.
        """
        if self.is_node_master():
            print(msg)

    @contextmanager
    def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup = None):
        """
        This context manager is used to allow one process to execute while blocking all
        other processes in the same process group. This is often useful when downloading is required
        as we only want to download in one process to prevent file corruption.

        The non-executor processes wait on a barrier before running the body; the
        executor runs the body first and joins the barrier afterwards, which
        releases the others.

        Example:
            >>> from colossalai.cluster import DistCoordinator
            >>> dist_coordinator = DistCoordinator()
            >>> with dist_coordinator.priority_execution():
            >>>     dataset = CIFAR10(root='./data', download=True)

        Args:
            executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
            process_group (ProcessGroup, optional): process group to use for the executor rank check. Defaults to None, which refers to the default process group.
        """
        rank = dist.get_rank(group=process_group)
        should_block = rank != executor_rank

        if should_block:
            dist.barrier(group=process_group)

        try:
            yield
        finally:
            # The executor must reach the barrier even when its body raises;
            # otherwise the waiting processes would never be released. The
            # original exception still propagates after the barrier.
            if not should_block:
                dist.barrier(group=process_group)

    def destroy(self, process_group: ProcessGroup = None):
        """
        Destroy the distributed process group.

        Args:
            process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group.
        """
        dist.destroy_process_group(process_group)
|