ColossalAI/colossalai/cluster/dist_coordinator.py


import functools
import os
from contextlib import contextmanager

import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.context.singleton_meta import SingletonMeta


class DistCoordinator(metaclass=SingletonMeta):
    """
    This class is used to coordinate distributed training. It is a singleton class, which means that there is only one instance of this
    class in the whole program.

    There are some terms that are used in this class:
        - rank: the rank of the current process
        - world size: the total number of processes
        - local rank: the rank of the current process on the current node
        - master: the process with rank 0
        - node master: the process with local rank 0 on the current node

    ```python
    from colossalai.cluster.dist_coordinator import DistCoordinator
    coordinator = DistCoordinator()

    if coordinator.is_master():
        do_something()

    coordinator.print_on_master('hello world')
    ```

    Attributes:
        rank (int): the rank of the current process
        world_size (int): the total number of processes
        local_rank (int): the rank of the current process on the current node
    """

    def __init__(self):
        assert (
            dist.is_initialized()
        ), "Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first."
        self._rank = dist.get_rank()
        self._world_size = dist.get_world_size()
        # this is often passed by launchers such as torchrun
        self._local_rank = int(os.environ.get("LOCAL_RANK", -1))

    @property
    def rank(self) -> int:
        return self._rank

    @property
    def world_size(self) -> int:
        return self._world_size

    @property
    def local_rank(self) -> int:
        return self._local_rank

    def _assert_local_rank_set(self):
        """
        Assert that the local rank is set. This is often passed by launchers such as torchrun.
        """
        assert (
            self.local_rank >= 0
        ), "The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process."

    def is_master(self, process_group: ProcessGroup = None) -> bool:
        """
        Check if the current process is the master process (rank is 0). It can accept a sub process group, in which case rank 0 is checked with respect to that process group.

        Args:
            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.

        Returns:
            bool: True if the current process is the master process, False otherwise
        """
        rank = dist.get_rank(group=process_group)
        return rank == 0

    def is_node_master(self) -> bool:
        """
        Check if the current process is the master process on the current node (local rank is 0).

        Returns:
            bool: True if the current process is the master process on the current node, False otherwise
        """
        self._assert_local_rank_set()
        return self.local_rank == 0

    def is_last_process(self, process_group: ProcessGroup = None) -> bool:
        """
        Check if the current process is the last process (rank is world size - 1). It can accept a sub process group, in which case the last rank is checked with respect to that process group.

        Args:
            process_group (ProcessGroup, optional): process group to use for the last rank check. Defaults to None, which refers to the default process group.

        Returns:
            bool: True if the current process is the last process, False otherwise
        """
        rank = dist.get_rank(group=process_group)
        world_size = dist.get_world_size(group=process_group)
        return rank == world_size - 1

    def print_on_master(self, msg: str, process_group: ProcessGroup = None):
        """
        Print message only from rank 0.

        Args:
            msg (str): message to print
            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.
        """
        rank = dist.get_rank(group=process_group)
        if rank == 0:
            print(msg)

    def print_on_node_master(self, msg: str):
        """
        Print message only from local rank 0. Local rank 0 refers to the 0th process running on the current node.

        Args:
            msg (str): message to print
        """
        self._assert_local_rank_set()
        if self.local_rank == 0:
            print(msg)

    @contextmanager
    def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup = None):
        """
        This context manager lets one process execute the enclosed block first while all
        other processes in the same process group wait at a barrier. Once the executor has finished,
        the other processes run the block as well. This is often useful when downloading is required,
        as we only want one process to download at a time to prevent file corruption.

        ```python
        from colossalai.cluster import DistCoordinator
        dist_coordinator = DistCoordinator()
        with dist_coordinator.priority_execution():
            dataset = CIFAR10(root='./data', download=True)
        ```

        Args:
            executor_rank (int): the rank (within the process group) of the process that executes first without blocking; all other processes are blocked until it finishes. Defaults to 0.
            process_group (ProcessGroup, optional): process group to use for the executor rank check. Defaults to None, which refers to the default process group.
        """
        rank = dist.get_rank(group=process_group)
        should_block = rank != executor_rank

        # non-executor processes wait here until the executor reaches its barrier below
        if should_block:
            self.block_all(process_group)

        yield

        # the executor reaches the barrier after running the block, releasing the others
        if not should_block:
            self.block_all(process_group)

    def destroy(self, process_group: ProcessGroup = None):
        """
        Destroy the distributed process group.

        Args:
            process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group.
        """
        dist.destroy_process_group(process_group)

    def block_all(self, process_group: ProcessGroup = None):
        """
        Block all processes in the process group.

        Args:
            process_group (ProcessGroup, optional): process group to block. Defaults to None, which refers to the default process group.
        """
        dist.barrier(group=process_group)
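
    def on_master_only(self, process_group: ProcessGroup = None):
        """
        A function wrapper that only executes the wrapped function on the master process (rank 0).
        On all other processes, calling the wrapped function does nothing and returns None.
        The master check is performed once, when the decorator is created.

        ```python
        from colossalai.cluster import DistCoordinator
        dist_coordinator = DistCoordinator()

        @dist_coordinator.on_master_only()
        def print_on_master(msg):
            print(msg)
        ```

        Args:
            process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.
        """
        is_master = self.is_master(process_group)

        # define an inner function
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if is_master:
                    return func(*args, **kwargs)

            return wrapper

        return decorator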
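
The ordering that `priority_execution` enforces can be illustrated without a distributed backend. The following is a minimal sketch, not part of ColossalAI: it assumes threads as stand-ins for ranks and `threading.Barrier` as a stand-in for `dist.barrier`, and the names `WORLD_SIZE`, `worker`, and `events` are hypothetical.

```python
import threading
from contextlib import contextmanager

WORLD_SIZE = 4  # hypothetical number of "ranks", simulated with threads
barrier = threading.Barrier(WORLD_SIZE)  # stand-in for dist.barrier()
events = []  # records the order in which ranks enter the guarded block
events_lock = threading.Lock()


@contextmanager
def priority_execution(rank: int, executor_rank: int = 0):
    """Mirror the control flow of DistCoordinator.priority_execution."""
    should_block = rank != executor_rank
    if should_block:
        barrier.wait()  # non-executors wait until the executor is done
    yield
    if not should_block:
        barrier.wait()  # the executor arrives last and releases the others


def worker(rank: int) -> None:
    with priority_execution(rank):
        with events_lock:
            events.append(rank)


threads = [threading.Thread(target=worker, args=(r,)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print("first rank to run the block:", events[0])
```

Because the barrier only opens once the executor has left the block, rank 0 is always recorded first; the remaining ranks then run the block in an unspecified order.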