ColossalAI/colossalai/cluster/process_group_mesh.py


import gc
import itertools
from functools import reduce
from operator import mul
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch.distributed as dist
from torch.distributed import ProcessGroup


def prod(nums: List[int]) -> int:
    """Product of a list of numbers.

    Args:
        nums (List[int]): A list of numbers.

    Returns:
        int: The product of the numbers.
    """
    return reduce(mul, nums)
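
# For example, prod([2, 2, 2]) == 8, matching a world size of 8 processes arranged as a 2 x 2 x 2 mesh.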


class ProcessGroupMesh:
    """A helper class to manage the process group mesh. It only describes how to organize process groups, and it's decoupled from the parallel method.
    It just initializes process groups and caches them. The parallel method should manage them and use them to do the parallel computation.

    We use an N-D tuple to represent the process group mesh, and an N-D coordinate to represent each process.
    For example, ``(0, 1, 0)`` represents the process whose rank is 2 in a 3D process group mesh with size ``(2, 2, 2)``.

    Args:
        *size (int): The size of each dimension of the process group mesh. The product of the sizes must be equal to the world size.

    Attributes:
        shape (Tuple[int, ...]): The shape of the process group mesh.
        rank (int): The rank of the current process.
    """

    def __init__(self, *size: int) -> None:
        assert dist.is_initialized(), "Please initialize torch.distributed first."
        assert prod(size) == dist.get_world_size(), "The product of the size must be equal to the world size."
        self._shape = size
        self._rank = dist.get_rank()
        self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
        self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
        self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}

    def __del__(self):
        r"""
        Destructor method for the ProcessGroupMesh class.

        When the ProcessGroupMesh object is deleted or goes out of scope, this method is called. It is responsible for
        cleaning up any process groups that were created during the lifetime of the object.

        Note:
            All process groups in PyTorch are represented as global variables, and they may not be automatically destroyed
            when the ProcessGroupMesh's lifetime ends. This method manually destroys the process groups to release
            system resources.
        """
        for group in self._ranks_to_group.values():
            dist.destroy_process_group(group)
        # Trigger garbage collection after destroying the cached process groups to reclaim memory.
        gc.collect()

    @property
    def shape(self) -> Tuple[int, ...]:
        return self._shape

    @property
    def rank(self) -> int:
        return self._rank

    def size(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]:
        """Get the size of the process group mesh.

        Args:
            dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None.

        Returns:
            Union[int, Tuple[int, ...]]: Size of the target dimension or the whole process group mesh.
        """
        if dim is None:
            return self._shape
        else:
            return self._shape[dim]

    def coordinate(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]:
        """Get the coordinate of the process group mesh.

        Args:
            dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None.

        Returns:
            Union[int, Tuple[int, ...]]: Coordinate of the target dimension or the whole process group mesh.
        """
        if dim is None:
            return self._coord
        else:
            return self._coord[dim]

    @staticmethod
    def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
        """Convert a rank to a coordinate.

        Args:
            rank (int): Rank to be converted.
            shape (Tuple[int, ...]): Shape of the process group mesh.

        Returns:
            Tuple[int, ...]: Coordinate of the rank.
        """
        return np.unravel_index(rank, shape)
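    # Worked example (illustrative): for shape (2, 2, 2), unravel(5, (2, 2, 2)) gives (1, 0, 1),
    # since 5 = 1 * 4 + 0 * 2 + 1 * 1 in row-major order.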

    @staticmethod
    def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = "raise") -> int:
        """Convert a coordinate to a rank.

        Args:
            coord (Tuple[int, ...]): Coordinate to be converted.
            shape (Tuple[int, ...]): Shape of the process group mesh.
            mode (str): One of ['raise', 'wrap', 'clip'], passed to numpy.ravel_multi_index
                (see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html).
                With 'wrap', an index out of range is wrapped around.
                For instance, ravel((0, i, 0), (1, 2, 1), 'wrap') returns i % 2.

        Returns:
            int: Rank of the coordinate.
        """
        assert mode in ["raise", "wrap", "clip"]
        return np.ravel_multi_index(coord, shape, mode)
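    # Worked example (illustrative): ravel((1, 0, 1), (2, 2, 2)) == 5, the inverse of the unravel example above;
    # with mode='wrap', ravel((3, 0, 1), (2, 2, 2), 'wrap') also returns 5, since 3 wraps to 3 % 2 == 1.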

    def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
        """Get the process group with the given ranks. If the process group doesn't exist, it will be created.

        Args:
            ranks_in_group (List[int]): Ranks in the process group.
            backend (Optional[str], optional): Backend of the process group. Defaults to None.

        Returns:
            ProcessGroup: The process group with the given ranks.
        """
        ranks_in_group = sorted(ranks_in_group)
        # Look up the cache keyed by the sorted rank tuple; only create a new group on a cache miss.
        if tuple(ranks_in_group) not in self._ranks_to_group:
            group = dist.new_group(ranks_in_group, backend=backend)
            self._ranks_to_group[tuple(ranks_in_group)] = group
            self._group_to_ranks[group] = tuple(ranks_in_group)
        return self._ranks_to_group[tuple(ranks_in_group)]
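    # Note: `dist.new_group` is a collective call, so every process in the default group must reach
    # `get_group` with the same `ranks_in_group` and in the same creation order, even processes that
    # are not members of the group being created.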

    def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:
        """Get the ranks in the given process group. The process group must be created by this class.

        Args:
            group (ProcessGroup): The process group.

        Returns:
            List[int]: Ranks in the process group.
        """
        return list(self._group_to_ranks[group])

    @staticmethod
    def get_coords_along_axis(
        base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int]
    ) -> List[Tuple[int, ...]]:
        """Get coordinates along the given axis.

        Args:
            base_coord (Tuple[int, ...]): Base coordinate which the coordinates along the axis are based on.
            axis (int): Axis along which the coordinates are generated.
            indices_at_axis (List[int]): Indices at the axis.

        Returns:
            List[Tuple[int, ...]]: Coordinates along the axis.
        """
        coords_in_group = []
        for idx in indices_at_axis:
            coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
        return coords_in_group
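    # Worked example (illustrative): get_coords_along_axis((1, 0, 1), axis=1, indices_at_axis=[0, 1])
    # returns [(1, 0, 1), (1, 1, 1)]: only dimension 1 varies, every other dimension stays fixed.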

    def create_group_along_axis(
        self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
    ) -> ProcessGroup:
        """Create all process groups along the given axis, and return the one which the current process belongs to.

        Args:
            axis (int): Axis along which the process groups are created.
            indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
            backend (Optional[str], optional): Backend of the process group. Defaults to None.

        Returns:
            ProcessGroup: The process group along the given axis which the current process belongs to.
        """
        indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
        reduced_shape = list(self._shape)
        # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
        reduced_shape[axis] = 1
        target_group = None
        # use Cartesian product to generate all combinations of coordinates
        for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
            coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
            ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
            group = self.get_group(ranks_in_group, backend=backend)
            if self._rank in ranks_in_group:
                target_group = group
        return target_group
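    # Worked example (illustrative): on a (2, 2, 2) mesh, create_group_along_axis(0) creates the groups
    # with ranks {0, 4}, {1, 5}, {2, 6} and {3, 7} on every process (group creation is collective),
    # and returns the one that contains the caller's own rank.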

    def get_group_along_axis(
        self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
    ) -> ProcessGroup:
        """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.

        Args:
            axis (int): Axis along which the process groups are created.
            indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
            backend (Optional[str], optional): Backend of the process group. Defaults to None.

        Returns:
            ProcessGroup: The process group along the given axis which the current process belongs to.
        """
        indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
        coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
        ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
        if ranks_in_group not in self._ranks_to_group:
            # no need to cache it explicitly, since it will be cached in `create_group_along_axis`
            return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
        return self._ranks_to_group[ranks_in_group]
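

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module): assumes this file is launched
    # with e.g. `torchrun --nproc_per_node=8 process_group_mesh.py` so that 8 processes are available.
    # The axis names (DP/PP/TP) are only an example of how a caller might interpret the mesh dimensions.
    dist.init_process_group(backend="gloo")

    DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
    mesh = ProcessGroupMesh(2, 2, 2)  # 2-way data x 2-way pipeline x 2-way tensor parallelism

    # Each call is collective: every rank creates all groups along the axis and keeps its own.
    tp_group = mesh.get_group_along_axis(TP_AXIS)
    dp_group = mesh.get_group_along_axis(DP_AXIS)
    print(
        f"rank={mesh.rank} coord={mesh.coordinate()} "
        f"tp_ranks={mesh.get_ranks_in_group(tp_group)} dp_ranks={mesh.get_ranks_in_group(dp_group)}"
    )

    # Destroy the cached sub-groups (via __del__) before tearing down the default group.
    del mesh
    dist.destroy_process_group()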