ColossalAI/colossalai/tensor/d_tensor/api.py

541 lines
18 KiB
Python
Raw Normal View History

import copy
import operator
from functools import reduce
from typing import Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.device.device_mesh import DeviceMesh
[Feature] Distributed optimizers: Lamb, Galore, CAME and Adafactor (#5694) * [feat] Add distributed lamb; minor fixes in DeviceMesh (#5476) * init: add dist lamb; add debiasing for lamb * dist lamb tester mostly done * all tests passed * add comments * all tests passed. Removed debugging statements * moved setup_distributed inside plugin. Added dist layout caching * organize better --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [hotfix] Improve tester precision by removing ZeRO on vanilla lamb (#5576) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [optim] add distributed came (#5526) * test CAME under LowLevelZeroOptimizer wrapper * test CAME TP row and col pass * test CAME zero pass * came zero add master and worker param id convert * came zero test pass * came zero test pass * test distributed came passed * reform code, Modify some expressions and add comments * minor fix of test came * minor fix of dist_came and test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix of dist_came and test * rebase dist-optim * rebase dist-optim * fix remaining comments * add test dist came using booster api --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [optim] Distributed Adafactor (#5484) * [feature] solve conflict; update optimizer readme; * [feature] update optimize readme; * [fix] fix testcase; * [feature] Add transformer-bert to testcase;solve a bug related to indivisible shape (induction in use_zero and tp is row parallel); * [feature] Add transformers_bert model zoo in testcase; * [feature] add user documentation to docs/source/feature. * [feature] add API Reference & Sample to optimizer Readme; add state check for bert exam; * [feature] modify user documentation; * [fix] fix readme format issue; * [fix] add zero=0 in testcase; cached augment in dict; * [fix] fix percision issue; * [feature] add distributed rms; * [feature] remove useless comment in testcase; * [fix] Remove useless test; open zero test; remove fp16 test in bert exam; * [feature] Extract distributed rms function; * [feature] add booster + lowlevelzeroPlugin in test; * [feature] add Start_with_booster_API case in md; add Supporting Information in md; * [fix] Also remove state movement in base adafactor; * [feature] extract factor function; * [feature] add LowLevelZeroPlugin test; * [fix] add tp=False and zero=True in logic; * [fix] fix use zero logic; * [feature] add row residue logic in column parallel factor; * [feature] add check optim state func; * [feature] Remove duplicate logic; * [feature] update optim state check func and percision test bug; * [fix] update/fix optim state; Still exist percision issue; * [fix] Add use_zero check in _rms; Add plugin support info in Readme; Add Dist Adafactor init Info; * [feature] removed print & comments in utils; * [feature] uodate Readme; * [feature] add LowLevelZeroPlugin test with Bert model zoo; * [fix] fix logic in _rms; * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [fix] remove comments in testcase; * [feature] add zh-Han Readme; --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] refractor dist came; fix percision error; add low level zero test with bert model zoo; (#5676) * [feature] daily update; * [fix] fix dist came; * [feature] refractor dist came; fix percision error; add low level zero test with bert model zoo; * [fix] open rms; fix low level zero test; fix dist came test function name; * [fix] remove redundant test; * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] Add Galore (Adam, Adafactor) and distributed GaloreAdamW8bit (#5570) * init: add dist lamb; add debiasing for lamb * dist lamb tester mostly done * all tests passed * add comments * all tests passed. Removed debugging statements * moved setup_distributed inside plugin. Added dist layout caching * organize better * update comments * add initial distributed galore * add initial distributed galore * add galore set param utils; change setup_distributed interface * projected grad precision passed * basic precision tests passed * tests passed; located svd precision issue in fwd-bwd; banned these tests * Plugin DP + TP tests passed * move get_shard_dim to d_tensor * add comments * remove useless files * remove useless files * fix zero typo * improve interface * remove moe changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix import * fix deepcopy * update came & adafactor to main * fix param map * fix typo --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Hotfix] Remove one buggy test case from dist_adafactor for now (#5692) Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: chongqichuizi875 <107315010+chongqichuizi875@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: duanjunwen <54985467+duanjunwen@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
2024-05-14 05:52:45 +00:00
from colossalai.tensor.d_tensor.sharding_spec import DimSpec
from .layout import Layout
from .layout_converter import LayoutConverter
from .sharding_spec import ShardingSpec
layout_converter = LayoutConverter()
[Feature] Distributed optimizers: Lamb, Galore, CAME and Adafactor (#5694) * [feat] Add distributed lamb; minor fixes in DeviceMesh (#5476) * init: add dist lamb; add debiasing for lamb * dist lamb tester mostly done * all tests passed * add comments * all tests passed. Removed debugging statements * moved setup_distributed inside plugin. Added dist layout caching * organize better --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [hotfix] Improve tester precision by removing ZeRO on vanilla lamb (#5576) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [optim] add distributed came (#5526) * test CAME under LowLevelZeroOptimizer wrapper * test CAME TP row and col pass * test CAME zero pass * came zero add master and worker param id convert * came zero test pass * came zero test pass * test distributed came passed * reform code, Modify some expressions and add comments * minor fix of test came * minor fix of dist_came and test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix of dist_came and test * rebase dist-optim * rebase dist-optim * fix remaining comments * add test dist came using booster api --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [optim] Distributed Adafactor (#5484) * [feature] solve conflict; update optimizer readme; * [feature] update optimize readme; * [fix] fix testcase; * [feature] Add transformer-bert to testcase;solve a bug related to indivisible shape (induction in use_zero and tp is row parallel); * [feature] Add transformers_bert model zoo in testcase; * [feature] add user documentation to docs/source/feature. * [feature] add API Reference & Sample to optimizer Readme; add state check for bert exam; * [feature] modify user documentation; * [fix] fix readme format issue; * [fix] add zero=0 in testcase; cached augment in dict; * [fix] fix percision issue; * [feature] add distributed rms; * [feature] remove useless comment in testcase; * [fix] Remove useless test; open zero test; remove fp16 test in bert exam; * [feature] Extract distributed rms function; * [feature] add booster + lowlevelzeroPlugin in test; * [feature] add Start_with_booster_API case in md; add Supporting Information in md; * [fix] Also remove state movement in base adafactor; * [feature] extract factor function; * [feature] add LowLevelZeroPlugin test; * [fix] add tp=False and zero=True in logic; * [fix] fix use zero logic; * [feature] add row residue logic in column parallel factor; * [feature] add check optim state func; * [feature] Remove duplicate logic; * [feature] update optim state check func and percision test bug; * [fix] update/fix optim state; Still exist percision issue; * [fix] Add use_zero check in _rms; Add plugin support info in Readme; Add Dist Adafactor init Info; * [feature] removed print & comments in utils; * [feature] uodate Readme; * [feature] add LowLevelZeroPlugin test with Bert model zoo; * [fix] fix logic in _rms; * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [fix] remove comments in testcase; * [feature] add zh-Han Readme; --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] refractor dist came; fix percision error; add low level zero test with bert model zoo; (#5676) * [feature] daily update; * [fix] fix dist came; * [feature] refractor dist came; fix percision error; add low level zero test with bert model zoo; * [fix] open rms; fix low level zero test; fix dist came test function name; * [fix] remove redundant test; * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] Add Galore (Adam, Adafactor) and distributed GaloreAdamW8bit (#5570) * init: add dist lamb; add debiasing for lamb * dist lamb tester mostly done * all tests passed * add comments * all tests passed. Removed debugging statements * moved setup_distributed inside plugin. Added dist layout caching * organize better * update comments * add initial distributed galore * add initial distributed galore * add galore set param utils; change setup_distributed interface * projected grad precision passed * basic precision tests passed * tests passed; located svd precision issue in fwd-bwd; banned these tests * Plugin DP + TP tests passed * move get_shard_dim to d_tensor * add comments * remove useless files * remove useless files * fix zero typo * improve interface * remove moe changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix import * fix deepcopy * update came & adafactor to main * fix param map * fix typo --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Hotfix] Remove one buggy test case from dist_adafactor for now (#5692) Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: chongqichuizi875 <107315010+chongqichuizi875@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: duanjunwen <54985467+duanjunwen@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
2024-05-14 05:52:45 +00:00
_SHARD_DIM = DimSpec([0])
def get_shard_dim_1d(p: torch.Tensor):
"""
Get the dimension along which the tensor is sharded, for example in 1D Tensor Parallel.
Args:
p (torch.Tensor): the input tensor
Returns:
int: the dimension along which the tensor is sharded
"""
if not is_distributed_tensor(p):
raise ValueError("p is not a distributed tensor")
sharding = p.dist_layout.sharding_spec.sharding_sequence
return sharding.index(_SHARD_DIM)
def clear_layout_converter():
global layout_converter
layout_converter.cached_solution.clear()
def is_distributed_tensor(tensor: torch.Tensor) -> bool:
"""
Check whether the given tensor is a distributed tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
bool: Whether the given tensor is a distributed tensor.
"""
return hasattr(tensor, "dist_layout")
def is_sharded(dtensor: torch.Tensor) -> bool:
"""
Check if a tensor is sharded.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
bool: True if the tensor is sharded, False otherwise.
"""
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return list(dtensor.shape) == list(dtensor.dist_layout.global_shape)
def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor:
"""
Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.
Args:
tensor (torch.Tensor): The tensor to be hijacked.
Returns:
torch.Tensor: The hijacked tensor.
"""
dtensor._old_detach = dtensor.detach
dtensor._old_clone = dtensor.clone
def new_detach(self):
t_ = self._old_detach()
t_.dist_layout = copy.deepcopy(self.dist_layout)
return t_
def new_clone(self, *args, **kwargs):
t_ = self._old_clone(*args, **kwargs)
t_.dist_layout = copy.deepcopy(self.dist_layout)
return t_
# bind the new methods to the tensor
dtensor.detach = new_detach.__get__(dtensor)
dtensor.clone = new_clone.__get__(dtensor)
return dtensor
def _construct_default_sharding_spec(
tensor: torch.Tensor,
) -> ShardingSpec:
"""
Construct the default sharding specification for the tensor.
Args:
tensor (`torch.Tensor`): the tensor to be sharded.
Returns:
A `ShardingSpec` object without any sharding specified.
"""
return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
def _apply_layout(tensor, layout):
"""
Apply the layout to the local tensor during initializing process.
"""
2024-01-25 05:56:27 +00:00
# layout converter requires a source and target layout
# we construct the source layer for an unsharded tensor
2024-01-25 05:56:27 +00:00
# and use self.dist_layer as the target layout for the sharded tensor
source_spec = _construct_default_sharding_spec(tensor)
source_layout = Layout(device_mesh=layout.device_mesh, sharding_spec=source_spec, global_shape=tensor.shape)
sharded_tensor = layout_converter.apply(tensor=tensor, source_layout=source_layout, target_layout=layout)
return sharded_tensor
def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
"""
Convert the given tensor to a distributed tensor.
Args:
tensor (torch.Tensor): The tensor to be converted.
device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices.
sharding_spec (ShardingSpec): The sharding specification which describes how the tensor will be sharded.
Returns:
torch.Tensor: The distributed tensor.
"""
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape)
# shard tensor
sharded_tensor = _apply_layout(tensor, dist_layout)
# hack some tensor methods
_hijack_detach_and_clone(sharded_tensor)
return sharded_tensor
def init_as_dtensor(
tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size
) -> torch.Tensor:
[gemini] gemini support tensor parallelism. (#4942) * [colossalai]fix typo * [inference] Add smmoothquant for llama (#4904) * [inference] add int8 rotary embedding kernel for smoothquant (#4843) * [inference] add smoothquant llama attention (#4850) * add smoothquant llama attention * remove uselss code * remove useless code * fix import error * rename file name * [inference] add silu linear fusion for smoothquant llama mlp (#4853) * add silu linear * update skip condition * catch smoothquant cuda lib exception * prcocess exception for tests * [inference] add llama mlp for smoothquant (#4854) * add llama mlp for smoothquant * fix down out scale * remove duplicate lines * add llama mlp check * delete useless code * [inference] add smoothquant llama (#4861) * add smoothquant llama * fix attention accuracy * fix accuracy * add kv cache and save pretrained * refactor example * delete smooth * refactor code * [inference] add smooth function and delete useless code for smoothquant (#4895) * add smooth function and delete useless code * update datasets * remove duplicate import * delete useless file * refactor codes (#4902) * rafactor code * add license * add torch-int and smoothquant license * Update flash_attention_patch.py To be compatible with the new change in the Transformers library, where a new argument 'padding_mask' was added to forward function of attention layer. https://github.com/huggingface/transformers/pull/25598 * [kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921) * [kernel] support pure fp16 for cpu adam (#4896) * [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919) * [kernel] fix cpu adam * [test] update gemini optim test * [format] applied code formatting on changed files in pull request 4908 (#4918) Co-authored-by: github-actions <github-actions@github.com> * [gemini] support gradient accumulation (#4869) * add test * fix no_sync bug in low level zero plugin * fix test * add argument for grad accum * add grad accum in backward hook for gemini * finish implementation, rewrite tests * fix test * skip stuck model in low level zero test * update doc * optimize communication & fix gradient checkpoint * modify doc * cleaning codes * update cpu adam fp16 case * [hotfix] fix torch 2.0 compatibility (#4936) * [hotfix] fix launch * [test] fix test gemini optim * [shardformer] fix vit * [test] add no master test for low level zero plugin (#4934) * [format] applied code formatting on changed files in pull request 4820 (#4886) Co-authored-by: github-actions <github-actions@github.com> * [nfc] fix some typo with colossalai/ docs/ etc. (#4920) * [Refactor] Integrated some lightllm kernels into token-attention (#4946) * add some req for inference * clean codes * add codes * add some lightllm deps * clean codes * hello * delete rms files * add some comments * add comments * add doc * add lightllm deps * add lightllm cahtglm2 kernels * add lightllm cahtglm2 kernels * replace rotary embedding with lightllm kernel * add some commnets * add some comments * add some comments * add * replace fwd kernel att1 * fix a arg * add * add * fix token attention * add some comments * clean codes * modify comments * fix readme * fix bug * fix bug --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com> * [test] merge old components to test to model zoo (#4945) * [test] add custom models in model zoo * [test] update legacy test * [test] update model zoo * [test] update gemini test * [test] remove components to test * [inference] add reference and fix some bugs (#4937) * add reference and fix some bugs * update gptq init --------- Co-authored-by: Xu Kai <xukai16@foxamil.com> * [Inference]ADD Bench Chatglm2 script (#4963) * add bench chatglm * fix bug and make utils --------- Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [Pipeline inference] Combine kvcache with pipeline inference (#4938) * merge kvcache with pipeline inference and refactor the code structure * support ppsize > 2 * refactor pipeline code * do pre-commit * modify benchmark * fix bench mark * polish code * add docstring and update readme * refactor the code * fix some logic bug of ppinfer * polish readme * fix typo * skip infer test * updated c++17 compiler flags (#4983) * [Inference] Dynamic Batching Inference, online and offline (#4953) * [inference] Dynamic Batching for Single and Multiple GPUs (#4831) * finish batch manager * 1 * first * fix * fix dynamic batching * llama infer * finish test * support different lengths generating * del prints * del prints * fix * fix bug --------- Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [inference] Async dynamic batching (#4894) * finish input and output logic * add generate * test forward * 1 * [inference]Re push async dynamic batching (#4901) * adapt to ray server * finish async * finish test * del test --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * Revert "[inference]Re push async dynamic batching (#4901)" (#4905) This reverts commit fbf3c09e673794ed18c91d4bab1a7dfea052e95a. * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Revert "[inference] Async dynamic batching (#4894)" (#4909) This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * [infer]Add Ray Distributed Environment Init Scripts (#4911) * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * support dynamic batch for bloom model and is_running function * [Inference]Test for new Async engine (#4935) * infer engine * infer engine * test engine * test engine * new manager * change step * add * test * fix * fix * finish test * finish test * finish test * finish test * add license --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * add assertion for config (#4947) * [Inference] Finish dynamic batching offline test (#4948) * test * fix test * fix quant * add default * fix * fix some bugs * fix some bugs * fix * fix bug * fix bugs * reset param --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding for llama token attention (#4965) * adding flash-decoding * clean * adding kernel * adding flash-decoding * add integration * add * adding kernel * adding kernel * adding triton 2.1.0 features for inference * update bloom triton kernel * remove useless vllm kernels * clean codes * fix * adding files * fix readme * update llama flash-decoding --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> * fix ColossalEval (#4992) Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> * [doc]Update doc for colossal-inference (#4989) * update doc * Update README.md --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> * [hotfix] Fix the bug where process groups were not being properly released. (#4940) * Fix the bug where process groups were not being properly released. * test * Revert "test" This reverts commit 479900c1398637310abf92eefa3cd168038ea02f. * [hotfix] fix the bug of repeatedly storing param group (#4951) * [doc] add supported feature diagram for hybrid parallel plugin (#4996) * [Pipeline Inference] Merge pp with tp (#4993) * refactor pipeline into new CaiInferEngine * updata llama modeling forward * merge tp with pp * update docstring * optimize test workflow and example * fix typo * add assert and todo * [release] update version (#4995) * [release] update version * [hotfix] fix ci * [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp * fix fix fix * update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO * support fused layernorm support fused layernorm support fused layernorm * update fusedlayernorm update fusedlayernorm update fusedlayernorm * add sequence parallel to gemini add sequence parallel to gemini * fix * fix comments fix comments fix comments * fix * fix t5 * clear cache * fix * activate ci * activate ci * fix * fix * fix * fix * revert * modify tp gather method modify tp gather method modify tp gather method modify tp gather method * fix test --------- Co-authored-by: Xu Kai <xukai16@foxmail.com> Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: cuiqing.li <lixx336@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com> Co-authored-by: Xu Kai <xukai16@foxamil.com> Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com> Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> Co-authored-by: littsk <1214689160@qq.com> Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
2023-11-10 02:15:16 +00:00
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
# shard tensor
tensor.dist_layout = dist_layout
# hack some tensor methods
_hijack_detach_and_clone(tensor)
return tensor
def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
"""
Convert the layout of the tensor from source_spec to target_spec.
This will update the `local_tensor` and `dist_layout` in place.
Args:
dtensor (torch.Tensor): the distributed tensor to be converted.
device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
target_layout (Layout): the target layout specification.
"""
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
global_shape = get_global_shape(dtensor)
target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
resharded_tensor = layout_converter.apply(
tensor=dtensor, source_layout=dtensor.dist_layout, target_layout=target_layout
)
return resharded_tensor
def to_global(dtensor: torch.Tensor) -> torch.Tensor:
"""
Convert a distributed tensor to the global tensor with the given layout.
This function returns a native `torch.Tensor` object.
Args:
dtensor (torch.Tensor): the distributed tensor to be converted.
Returns:
torch.Tensor: the global tensor.
"""
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
layout_converter = LayoutConverter()
global_sharding_spec = ShardingSpec(dtensor.dim(), {})
device_mesh = get_device_mesh(dtensor)
global_shape = get_global_shape(dtensor)
global_layout = Layout(device_mesh=device_mesh, sharding_spec=global_sharding_spec, global_shape=global_shape)
global_tensor = layout_converter.apply(dtensor, dtensor.dist_layout, global_layout)
return global_tensor
def shard_rowwise(
tensor: torch.Tensor,
group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
) -> torch.Tensor:
"""
Shard the first dim of the given tensor.
Args:
tensor (torch.Tensor): The tensor to be sharded.
group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
If None, the tensor will be sharded with respect to the global process group.
Defaults to None.
inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.
Returns:
torch.Tensor: The sharded tensor.
"""
# if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
if group_or_device_mesh is None:
group_or_device_mesh = dist.GroupMember.WORLD
if isinstance(group_or_device_mesh, ProcessGroup):
device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
else:
assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding."
device_mesh = group_or_device_mesh
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
return distribute_tensor(tensor, device_mesh, sharding_spec)
def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> torch.Tensor:
"""
Shard the first dim of the given tensor.
Args:
tensor (torch.Tensor): The tensor to be sharded.
group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
If None, the tensor will be sharded with respect to the global process group.
Defaults to None.
inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.
Returns:
torch.Tensor: The sharded tensor.
"""
# if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
if group_or_device_mesh is None:
group_or_device_mesh = dist.GroupMember.WORLD
if isinstance(group_or_device_mesh, ProcessGroup):
device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
else:
assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding."
device_mesh = group_or_device_mesh
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
return distribute_tensor(tensor, device_mesh, sharding_spec)
def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
# make it distributed as well
param.dist_layout = dtensor.dist_layout
_hijack_detach_and_clone(param)
return param
def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
param.data = dtensor
# make it distributed as well
param.dist_layout = dtensor.dist_layout
_hijack_detach_and_clone(param)
def compute_global_numel(dtensor: torch.Tensor) -> int:
"""
Compute the global number of elements in the distributed tensor.
Args:
dtensor (torch.Tensor): The distributed tensor.
Returns:
int: The global number of elements in the distributed tensor.
"""
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
numel = reduce(operator.mul, dtensor.dist_layout.global_shape)
return numel
def get_layout(dtensor: torch.Tensor) -> Layout:
"""
Get the layout of the distributed tensor.
Args:
dtensor (torch.Tensor): The distributed tensor.
Returns:
Layout: The layout of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout
def get_global_shape(dtensor: torch.Tensor) -> torch.Size:
"""
Get the global shape of the distributed tensor.
Args:
dtensor (torch.Tensor): The distributed tensor.
Returns:
torch.Size: The global shape of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout.global_shape
def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh:
"""
Get the device mesh of the distributed tensor.
Args:
dtensor (torch.Tensor): The distributed tensor.
Returns:
DeviceMesh: The device mesh of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout.device_mesh
def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
"""
Get the sharding spec of the distributed tensor.
Args:
dtensor (torch.Tensor): The distributed tensor.
Returns:
ShardingSpec: The sharding spec of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
return dtensor.dist_layout.sharding_spec
# ======================================================
# Some sharding does not obey the SPMD style
# e.g. Fused QKV layer in GPT2
# we support customize sharding with the following APIs
# ======================================================
def is_customized_distributed_tensor(tensor: torch.Tensor):
"""
Check whether the given tensor is a customized distributed tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
bool: Whether the given tensor is a customized distributed tensor.
"""
return hasattr(tensor, "shard_fn") and hasattr(tensor, "gather_fn")
def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
"""
Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.
Args:
tensor (torch.Tensor): The tensor to be hijacked.
Returns:
torch.Tensor: The hijacked tensor.
"""
dtensor._old_detach = dtensor.detach
dtensor._old_clone = dtensor.clone
def new_detach(self):
t_ = self._old_detach()
t_.shard_fn = self.shard_fn
t_.gather_fn = self.gather_fn
return t_
def new_clone(self, *args, **kwargs):
t_ = self._old_clone(*args, **kwargs)
t_.shard_fn = self.shard_fn
t_.gather_fn = self.gather_fn
return t_
# bind the new methods to the tensor
dtensor.detach = new_detach.__get__(dtensor)
dtensor.clone = new_clone.__get__(dtensor)
return dtensor
def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn, gather_fn: callable):
"""
Distribute the given tensor with the given shard_fn and gather_fn.
Example:
```python
# define shard and gather functions
def shard_fn(tensor):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
return tensor.chunk(world_size, dim=0)[rank]
def gather_fn(tensor):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
torch.distributed.all_gather(shard_list, tensor)
return torch.cat(shard_list, dim=0)
# create a distributed tensor
tensor = torch.rand(4, 4)
dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn)
```
Args:
tensor (torch.Tensor): The tensor to be distributed.
shard_fn (callable): The function to shard the tensor.
gather_fn (callable): The function to gather the tensor.
Returns:
torch.Tensor: The distributed tensor.
"""
assert callable(shard_fn), "The shard_fn must be callable."
assert callable(gather_fn), "The gather_fn must be callable."
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
sharded_tensor = shard_fn(tensor)
# set the shard_fn and gather_fn as attributes of the distributed tensor
sharded_tensor.shard_fn = shard_fn
sharded_tensor.gather_fn = gather_fn
# set the shard_fn and gather_fn as attributes of the distributed tensor
_hijack_detach_and_clone_for_customized_distributed_tensor(sharded_tensor)
return sharded_tensor
[gemini] gemini support tensor parallelism. (#4942) * [colossalai]fix typo * [inference] Add smmoothquant for llama (#4904) * [inference] add int8 rotary embedding kernel for smoothquant (#4843) * [inference] add smoothquant llama attention (#4850) * add smoothquant llama attention * remove uselss code * remove useless code * fix import error * rename file name * [inference] add silu linear fusion for smoothquant llama mlp (#4853) * add silu linear * update skip condition * catch smoothquant cuda lib exception * prcocess exception for tests * [inference] add llama mlp for smoothquant (#4854) * add llama mlp for smoothquant * fix down out scale * remove duplicate lines * add llama mlp check * delete useless code * [inference] add smoothquant llama (#4861) * add smoothquant llama * fix attention accuracy * fix accuracy * add kv cache and save pretrained * refactor example * delete smooth * refactor code * [inference] add smooth function and delete useless code for smoothquant (#4895) * add smooth function and delete useless code * update datasets * remove duplicate import * delete useless file * refactor codes (#4902) * rafactor code * add license * add torch-int and smoothquant license * Update flash_attention_patch.py To be compatible with the new change in the Transformers library, where a new argument 'padding_mask' was added to forward function of attention layer. https://github.com/huggingface/transformers/pull/25598 * [kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921) * [kernel] support pure fp16 for cpu adam (#4896) * [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919) * [kernel] fix cpu adam * [test] update gemini optim test * [format] applied code formatting on changed files in pull request 4908 (#4918) Co-authored-by: github-actions <github-actions@github.com> * [gemini] support gradient accumulation (#4869) * add test * fix no_sync bug in low level zero plugin * fix test * add argument for grad accum * add grad accum in backward hook for gemini * finish implementation, rewrite tests * fix test * skip stuck model in low level zero test * update doc * optimize communication & fix gradient checkpoint * modify doc * cleaning codes * update cpu adam fp16 case * [hotfix] fix torch 2.0 compatibility (#4936) * [hotfix] fix launch * [test] fix test gemini optim * [shardformer] fix vit * [test] add no master test for low level zero plugin (#4934) * [format] applied code formatting on changed files in pull request 4820 (#4886) Co-authored-by: github-actions <github-actions@github.com> * [nfc] fix some typo with colossalai/ docs/ etc. (#4920) * [Refactor] Integrated some lightllm kernels into token-attention (#4946) * add some req for inference * clean codes * add codes * add some lightllm deps * clean codes * hello * delete rms files * add some comments * add comments * add doc * add lightllm deps * add lightllm cahtglm2 kernels * add lightllm cahtglm2 kernels * replace rotary embedding with lightllm kernel * add some commnets * add some comments * add some comments * add * replace fwd kernel att1 * fix a arg * add * add * fix token attention * add some comments * clean codes * modify comments * fix readme * fix bug * fix bug --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com> * [test] merge old components to test to model zoo (#4945) * [test] add custom models in model zoo * [test] update legacy test * [test] update model zoo * [test] update gemini test * [test] remove components to test * [inference] add reference and fix some bugs (#4937) * add reference and fix some bugs * update gptq init --------- Co-authored-by: Xu Kai <xukai16@foxamil.com> * [Inference]ADD Bench Chatglm2 script (#4963) * add bench chatglm * fix bug and make utils --------- Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [Pipeline inference] Combine kvcache with pipeline inference (#4938) * merge kvcache with pipeline inference and refactor the code structure * support ppsize > 2 * refactor pipeline code * do pre-commit * modify benchmark * fix bench mark * polish code * add docstring and update readme * refactor the code * fix some logic bug of ppinfer * polish readme * fix typo * skip infer test * updated c++17 compiler flags (#4983) * [Inference] Dynamic Batching Inference, online and offline (#4953) * [inference] Dynamic Batching for Single and Multiple GPUs (#4831) * finish batch manager * 1 * first * fix * fix dynamic batching * llama infer * finish test * support different lengths generating * del prints * del prints * fix * fix bug --------- Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [inference] Async dynamic batching (#4894) * finish input and output logic * add generate * test forward * 1 * [inference]Re push async dynamic batching (#4901) * adapt to ray server * finish async * finish test * del test --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * Revert "[inference]Re push async dynamic batching (#4901)" (#4905) This reverts commit fbf3c09e673794ed18c91d4bab1a7dfea052e95a. * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Revert "[inference] Async dynamic batching (#4894)" (#4909) This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * [infer]Add Ray Distributed Environment Init Scripts (#4911) * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * support dynamic batch for bloom model and is_running function * [Inference]Test for new Async engine (#4935) * infer engine * infer engine * test engine * test engine * new manager * change step * add * test * fix * fix * finish test * finish test * finish test * finish test * add license --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * add assertion for config (#4947) * [Inference] Finish dynamic batching offline test (#4948) * test * fix test * fix quant * add default * fix * fix some bugs * fix some bugs * fix * fix bug * fix bugs * reset param --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding for llama token attention (#4965) * adding flash-decoding * clean * adding kernel * adding flash-decoding * add integration * add * adding kernel * adding kernel * adding triton 2.1.0 features for inference * update bloom triton kernel * remove useless vllm kernels * clean codes * fix * adding files * fix readme * update llama flash-decoding --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> * fix ColossalEval (#4992) Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> * [doc]Update doc for colossal-inference (#4989) * update doc * Update README.md --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> * [hotfix] Fix the bug where process groups were not being properly released. (#4940) * Fix the bug where process groups were not being properly released. * test * Revert "test" This reverts commit 479900c1398637310abf92eefa3cd168038ea02f. * [hotfix] fix the bug of repeatedly storing param group (#4951) * [doc] add supported feature diagram for hybrid parallel plugin (#4996) * [Pipeline Inference] Merge pp with tp (#4993) * refactor pipeline into new CaiInferEngine * updata llama modeling forward * merge tp with pp * update docstring * optimize test workflow and example * fix typo * add assert and todo * [release] update version (#4995) * [release] update version * [hotfix] fix ci * [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp * fix fix fix * update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO * support fused layernorm support fused layernorm support fused layernorm * update fusedlayernorm update fusedlayernorm update fusedlayernorm * add sequence parallel to gemini add sequence parallel to gemini * fix * fix comments fix comments fix comments * fix * fix t5 * clear cache * fix * activate ci * activate ci * fix * fix * fix * fix * revert * modify tp gather method modify tp gather method modify tp gather method modify tp gather method * fix test --------- Co-authored-by: Xu Kai <xukai16@foxmail.com> Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: cuiqing.li <lixx336@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com> Co-authored-by: Xu Kai <xukai16@foxamil.com> Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com> Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> Co-authored-by: littsk <1214689160@qq.com> Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
2023-11-10 02:15:16 +00:00
def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn, gather_fn: callable):
"""
Distribute the given tensor with the given shard_fn and gather_fn.
Example:
```python
# define shard and gather functions
def shard_fn(tensor):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
return tensor.chunk(world_size, dim=0)[rank]
def gather_fn(tensor):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
torch.distributed.all_gather(shard_list, tensor)
return torch.cat(shard_list, dim=0)
# create a distributed tensor
tensor = torch.rand(4, 4)
dtensor = init_tensor_as_customization_distributed(tensor, shard_fn, gather_fn)
```
Args:
tensor (torch.Tensor): The tensor to be distributed.
shard_fn (callable): The function to shard the tensor.
gather_fn (callable): The function to gather the tensor.
Returns:
torch.Tensor: The distributed tensor.
"""
assert callable(shard_fn), "The shard_fn must be callable."
assert callable(gather_fn), "The gather_fn must be callable."
assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
# set the shard_fn and gather_fn as attributes of the distributed tensor
tensor.shard_fn = shard_fn
tensor.gather_fn = gather_fn
# set the shard_fn and gather_fn as attributes of the distributed tensor
_hijack_detach_and_clone_for_customized_distributed_tensor(tensor)
return tensor
def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
"""
Gather the given tensor to the global tensor.
Args:
dtensor (torch.Tensor): The distributed tensor.
Returns:
torch.Tensor: The global tensor.
"""
assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
return dtensor.gather_fn(dtensor)
def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
"""
Convert the given customized distributed tensor to a parameter.
"""
assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
# make it distributed as well
param.shard_fn = dtensor.shard_fn
param.gather_fn = dtensor.gather_fn
_hijack_detach_and_clone_for_customized_distributed_tensor(param)
return param
def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter):
"""
Convert the given customized distributed tensor to an existing parameter.
"""
assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
param.data = dtensor.data
param.shard_fn = dtensor.shard_fn
param.gather_fn = dtensor.gather_fn
_hijack_detach_and_clone_for_customized_distributed_tensor(param)