You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_booster/test_plugin/test_gemini_plugin.py

179 lines
6.1 KiB

from contextlib import nullcontext
from typing import Optional
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.fx import is_compatible_with_meta
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]:
    """Boost one model with ``GeminiPlugin`` and run a single training step.

    Args:
        init_method: ``"lazy"`` builds the model inside ``LazyInitContext``;
            any other value builds it inside a ``nullcontext``.
        model_fn: Zero-argument callable that returns the model.
        data_gen_fn: Zero-argument callable that returns a dict of model inputs.
        output_transform_fn: Callable that maps the raw model output to a dict;
            the value under its first key is reduced with ``mean()`` to get the loss.
        zero_size: Used together with ``tp_size`` to derive
            ``extra_dp_size = world_size // (zero_size * tp_size)``.
        tp_size: Tensor-parallel size handed to ``GeminiPlugin``; values greater
            than 1 also switch on ``enable_all_optimization``.

    Returns:
        ``None`` when the step completes, or when a ``NotImplementedError`` was
        raised (that case is only reported via ``print``). For any other
        exception, ``repr`` of that exception.
    """
    # Stays None until the model has been built, so the NotImplementedError
    # handler below never touches an unbound local.
    model = None
    try:
        ctx = LazyInitContext() if init_method == "lazy" else nullcontext()
        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
        plugin = GeminiPlugin(
            max_norm=1.0,
            initial_scale=2**5,
            tp_size=tp_size,
            extra_dp_size=extra_dp_size,
            enable_all_optimization=tp_size > 1,
        )
        booster = Booster(plugin=plugin)

        # Only model construction happens under the (possibly lazy) init context.
        with ctx:
            model = model_fn()
        optimizer = HybridAdam(model.parameters(), lr=1e-3)

        def criterion(x):
            return x.mean()

        data = data_gen_fn()
        # Move tensors, and objects whose class name contains "Tensor", to the GPU;
        # every other value is passed through untouched.
        data = {
            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v
            for k, v in data.items()
        }

        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

        for n, p in model.named_parameters():
            assert isinstance(p, ColoParameter), f"{n} is not a ColoParameter"

        output = model(**data)
        output = output_transform_fn(output)
        output_key = list(output.keys())[0]
        loss = criterion(output[output_key])

        booster.backward(loss, optimizer)
        optimizer.step()
    except NotImplementedError:
        # The error may be raised before the model exists (e.g. while building
        # the plugin); name the factory in that case instead of crashing here.
        target = model.__class__ if model is not None else model_fn
        print(f"Tensor Parallelism policy for {target} is not implemented yet\n.")
    except Exception as e:
        # Hand the failure back as text instead of raising it.
        return repr(e)
# TODO(ver217): CI does not support lazy init for now
# @parameterize('init_method', ['lazy', 'none', 'colo'])
@parameterize("subset", ["torchvision", "transformers", "diffusers"])
@parameterize("init_method", ["none"])
@parameterize("zero_size", [2])
@parameterize("tp_size", [2])
def check_gemini_plugin(
subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1
):
"""check gemini plugin over model zoo
Args:
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
"""
is_support_meta = is_compatible_with_meta()
if not is_support_meta and init_method == "lazy":
return
passed_models = []
failed_info = {} # (model_name, error) pair
[gemini] improve compatibility and add static placement policy (#4479) * [gemini] remove distributed-related part from colotensor (#4379) * [gemini] remove process group dependency * [gemini] remove tp part from colo tensor * [gemini] patch inplace op * [gemini] fix param op hook and update tests * [test] remove useless tests * [test] remove useless tests * [misc] fix requirements * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [test] fix model zoo * [misc] update requirements * [gemini] refactor gemini optimizer and gemini ddp (#4398) * [gemini] update optimizer interface * [gemini] renaming gemini optimizer * [gemini] refactor gemini ddp class * [example] update gemini related example * [example] update gemini related example * [plugin] fix gemini plugin args * [test] update gemini ckpt tests * [gemini] fix checkpoint io * [example] fix opt example requirements * [example] fix opt example * [example] fix opt example * [example] fix opt example * [gemini] add static placement policy (#4443) * [gemini] add static placement policy * [gemini] fix param offload * [test] update gemini tests * [plugin] update gemini plugin * [plugin] update gemini plugin docstr * [misc] fix flash attn requirement * [test] fix gemini checkpoint io test * [example] update resnet example result (#4457) * [example] update bert example result (#4458) * [doc] update gemini doc (#4468) * [example] update gemini related examples (#4473) * [example] update gpt example * [example] update dreambooth example * [example] update vit * [example] update opt * [example] update palm * [example] update vit and opt benchmark * [hotfix] fix bert in model zoo (#4480) * [hotfix] fix bert in model zoo * [test] remove chatglm gemini test * [test] remove sam gemini test * [test] remove vit gemini test * [hotfix] fix opt tutorial example (#4497) * [hotfix] fix opt tutorial example * [hotfix] fix opt tutorial example
1 year ago
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(subset).items():
# These models lead to CUDA error
if name in (
"diffusers_auto_encoder_kl",
"diffusers_vq_model",
"diffusers_unet2d_model",
"timm_resmlp",
"timm_gmixer_12_224",
"timm_gmlp_b16_224",
"timm_mixer_b16_224",
"timm_convnext",
"torchvision_convnext_base",
):
continue
# These models are not compatible with gemini
if name in [
"timm_convit",
"timm_dm_nfnet",
"torchvision_vit_b_16",
"transformers_t5",
"transformers_t5_for_conditional_generation",
"transformers_t5_encoder_model", # does not support apex rmsnorm
"transformers_chatglm",
"transformers_sam",
"transformers_vit",
"transformers_gpt_double_heads", # TODO check why does the model fail to run using Gemini
"transformers_falcon", # TODO check why falcon fails to run Gemini
"transformers_falcon_for_causal_lm",
"transformers_falcon_for_sequence_classification",
"transformers_falcon_for_token_classification",
"transformers_falcon_for_question_answering",
]:
continue
if init_method == "lazy" and name in [
"timm_convmixer",
"timm_vision_transformer",
"timm_deit",
"timm_deit3",
"timm_inception_v3",
"timm_tnt_b_patch16_224",
"timm_rexnet",
"torchvision_densenet121",
"torchvision_efficientnet_b0",
"torchvision_mobilenet_v2",
"torchvision_mnasnet0_5",
"torchvision_regnet_x_16gf",
"torchvision_shufflenet_v2_x0_5",
"torchvision_efficientnet_v2_s",
]:
continue
[gemini] gemini support tensor parallelism. (#4942) * [colossalai]fix typo * [inference] Add smmoothquant for llama (#4904) * [inference] add int8 rotary embedding kernel for smoothquant (#4843) * [inference] add smoothquant llama attention (#4850) * add smoothquant llama attention * remove uselss code * remove useless code * fix import error * rename file name * [inference] add silu linear fusion for smoothquant llama mlp (#4853) * add silu linear * update skip condition * catch smoothquant cuda lib exception * prcocess exception for tests * [inference] add llama mlp for smoothquant (#4854) * add llama mlp for smoothquant * fix down out scale * remove duplicate lines * add llama mlp check * delete useless code * [inference] add smoothquant llama (#4861) * add smoothquant llama * fix attention accuracy * fix accuracy * add kv cache and save pretrained * refactor example * delete smooth * refactor code * [inference] add smooth function and delete useless code for smoothquant (#4895) * add smooth function and delete useless code * update datasets * remove duplicate import * delete useless file * refactor codes (#4902) * rafactor code * add license * add torch-int and smoothquant license * Update flash_attention_patch.py To be compatible with the new change in the Transformers library, where a new argument 'padding_mask' was added to forward function of attention layer. 
https://github.com/huggingface/transformers/pull/25598 * [kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921) * [kernel] support pure fp16 for cpu adam (#4896) * [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919) * [kernel] fix cpu adam * [test] update gemini optim test * [format] applied code formatting on changed files in pull request 4908 (#4918) Co-authored-by: github-actions <github-actions@github.com> * [gemini] support gradient accumulation (#4869) * add test * fix no_sync bug in low level zero plugin * fix test * add argument for grad accum * add grad accum in backward hook for gemini * finish implementation, rewrite tests * fix test * skip stuck model in low level zero test * update doc * optimize communication & fix gradient checkpoint * modify doc * cleaning codes * update cpu adam fp16 case * [hotfix] fix torch 2.0 compatibility (#4936) * [hotfix] fix launch * [test] fix test gemini optim * [shardformer] fix vit * [test] add no master test for low level zero plugin (#4934) * [format] applied code formatting on changed files in pull request 4820 (#4886) Co-authored-by: github-actions <github-actions@github.com> * [nfc] fix some typo with colossalai/ docs/ etc. 
(#4920) * [Refactor] Integrated some lightllm kernels into token-attention (#4946) * add some req for inference * clean codes * add codes * add some lightllm deps * clean codes * hello * delete rms files * add some comments * add comments * add doc * add lightllm deps * add lightllm cahtglm2 kernels * add lightllm cahtglm2 kernels * replace rotary embedding with lightllm kernel * add some commnets * add some comments * add some comments * add * replace fwd kernel att1 * fix a arg * add * add * fix token attention * add some comments * clean codes * modify comments * fix readme * fix bug * fix bug --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com> * [test] merge old components to test to model zoo (#4945) * [test] add custom models in model zoo * [test] update legacy test * [test] update model zoo * [test] update gemini test * [test] remove components to test * [inference] add reference and fix some bugs (#4937) * add reference and fix some bugs * update gptq init --------- Co-authored-by: Xu Kai <xukai16@foxamil.com> * [Inference]ADD Bench Chatglm2 script (#4963) * add bench chatglm * fix bug and make utils --------- Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [Pipeline inference] Combine kvcache with pipeline inference (#4938) * merge kvcache with pipeline inference and refactor the code structure * support ppsize > 2 * refactor pipeline code * do pre-commit * modify benchmark * fix bench mark * polish code * add docstring and update readme * refactor the code * fix some logic bug of ppinfer * polish readme * fix typo * skip infer test * updated c++17 compiler flags (#4983) * [Inference] Dynamic Batching Inference, online and offline (#4953) * [inference] Dynamic Batching for Single and Multiple GPUs (#4831) * finish batch manager * 1 * first * fix * fix dynamic batching * llama infer * finish test * support different lengths generating * del prints * del prints * fix * fix bug --------- 
Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [inference] Async dynamic batching (#4894) * finish input and output logic * add generate * test forward * 1 * [inference]Re push async dynamic batching (#4901) * adapt to ray server * finish async * finish test * del test --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * Revert "[inference]Re push async dynamic batching (#4901)" (#4905) This reverts commit fbf3c09e673794ed18c91d4bab1a7dfea052e95a. * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Revert "[inference] Async dynamic batching (#4894)" (#4909) This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * [infer]Add Ray Distributed Environment Init Scripts (#4911) * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. 
* Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * support dynamic batch for bloom model and is_running function * [Inference]Test for new Async engine (#4935) * infer engine * infer engine * test engine * test engine * new manager * change step * add * test * fix * fix * finish test * finish test * finish test * finish test * add license --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * add assertion for config (#4947) * [Inference] Finish dynamic batching offline test (#4948) * test * fix test * fix quant * add default * fix * fix some bugs * fix some bugs * fix * fix bug * fix bugs * reset param --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding for llama token attention (#4965) * adding flash-decoding * clean * adding kernel * adding flash-decoding * add integration * add * adding kernel * adding kernel * adding triton 2.1.0 features for inference * update bloom triton kernel * remove useless vllm kernels * clean codes * fix * adding files * fix readme * update llama flash-decoding --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> * fix ColossalEval (#4992) Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> * [doc]Update doc for colossal-inference (#4989) * update doc * Update README.md --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> * [hotfix] Fix the bug where process groups were not being properly released. (#4940) * Fix the bug where process groups were not being properly released. 
* test * Revert "test" This reverts commit 479900c1398637310abf92eefa3cd168038ea02f. * [hotfix] fix the bug of repeatedly storing param group (#4951) * [doc] add supported feature diagram for hybrid parallel plugin (#4996) * [Pipeline Inference] Merge pp with tp (#4993) * refactor pipeline into new CaiInferEngine * updata llama modeling forward * merge tp with pp * update docstring * optimize test workflow and example * fix typo * add assert and todo * [release] update version (#4995) * [release] update version * [hotfix] fix ci * [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp * fix fix fix * update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO * support fused layernorm support fused layernorm support fused layernorm * update fusedlayernorm update fusedlayernorm update fusedlayernorm * add sequence parallel to gemini add sequence parallel to gemini * fix * fix comments fix comments fix comments * fix * fix t5 * clear cache * fix * activate ci * activate ci * fix * fix * fix * fix * revert * modify tp gather method modify tp gather method modify tp gather method modify tp gather method * fix test --------- Co-authored-by: Xu Kai <xukai16@foxmail.com> Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: cuiqing.li <lixx336@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com> Co-authored-by: Xu Kai 
<xukai16@foxamil.com> Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com> Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> Co-authored-by: littsk <1214689160@qq.com> Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
1 year ago
# TODO debug blip2 when using tp, something wrong with shift_logits's shape
if "transformers_blip2" in name:
tp_size = 1
[gemini] gemini support tensor parallelism. (#4942) * [colossalai]fix typo * [inference] Add smmoothquant for llama (#4904) * [inference] add int8 rotary embedding kernel for smoothquant (#4843) * [inference] add smoothquant llama attention (#4850) * add smoothquant llama attention * remove uselss code * remove useless code * fix import error * rename file name * [inference] add silu linear fusion for smoothquant llama mlp (#4853) * add silu linear * update skip condition * catch smoothquant cuda lib exception * prcocess exception for tests * [inference] add llama mlp for smoothquant (#4854) * add llama mlp for smoothquant * fix down out scale * remove duplicate lines * add llama mlp check * delete useless code * [inference] add smoothquant llama (#4861) * add smoothquant llama * fix attention accuracy * fix accuracy * add kv cache and save pretrained * refactor example * delete smooth * refactor code * [inference] add smooth function and delete useless code for smoothquant (#4895) * add smooth function and delete useless code * update datasets * remove duplicate import * delete useless file * refactor codes (#4902) * rafactor code * add license * add torch-int and smoothquant license * Update flash_attention_patch.py To be compatible with the new change in the Transformers library, where a new argument 'padding_mask' was added to forward function of attention layer. 
https://github.com/huggingface/transformers/pull/25598 * [kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921) * [kernel] support pure fp16 for cpu adam (#4896) * [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919) * [kernel] fix cpu adam * [test] update gemini optim test * [format] applied code formatting on changed files in pull request 4908 (#4918) Co-authored-by: github-actions <github-actions@github.com> * [gemini] support gradient accumulation (#4869) * add test * fix no_sync bug in low level zero plugin * fix test * add argument for grad accum * add grad accum in backward hook for gemini * finish implementation, rewrite tests * fix test * skip stuck model in low level zero test * update doc * optimize communication & fix gradient checkpoint * modify doc * cleaning codes * update cpu adam fp16 case * [hotfix] fix torch 2.0 compatibility (#4936) * [hotfix] fix launch * [test] fix test gemini optim * [shardformer] fix vit * [test] add no master test for low level zero plugin (#4934) * [format] applied code formatting on changed files in pull request 4820 (#4886) Co-authored-by: github-actions <github-actions@github.com> * [nfc] fix some typo with colossalai/ docs/ etc. 
(#4920) * [Refactor] Integrated some lightllm kernels into token-attention (#4946) * add some req for inference * clean codes * add codes * add some lightllm deps * clean codes * hello * delete rms files * add some comments * add comments * add doc * add lightllm deps * add lightllm cahtglm2 kernels * add lightllm cahtglm2 kernels * replace rotary embedding with lightllm kernel * add some commnets * add some comments * add some comments * add * replace fwd kernel att1 * fix a arg * add * add * fix token attention * add some comments * clean codes * modify comments * fix readme * fix bug * fix bug --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com> * [test] merge old components to test to model zoo (#4945) * [test] add custom models in model zoo * [test] update legacy test * [test] update model zoo * [test] update gemini test * [test] remove components to test * [inference] add reference and fix some bugs (#4937) * add reference and fix some bugs * update gptq init --------- Co-authored-by: Xu Kai <xukai16@foxamil.com> * [Inference]ADD Bench Chatglm2 script (#4963) * add bench chatglm * fix bug and make utils --------- Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [Pipeline inference] Combine kvcache with pipeline inference (#4938) * merge kvcache with pipeline inference and refactor the code structure * support ppsize > 2 * refactor pipeline code * do pre-commit * modify benchmark * fix bench mark * polish code * add docstring and update readme * refactor the code * fix some logic bug of ppinfer * polish readme * fix typo * skip infer test * updated c++17 compiler flags (#4983) * [Inference] Dynamic Batching Inference, online and offline (#4953) * [inference] Dynamic Batching for Single and Multiple GPUs (#4831) * finish batch manager * 1 * first * fix * fix dynamic batching * llama infer * finish test * support different lengths generating * del prints * del prints * fix * fix bug --------- 
Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [inference] Async dynamic batching (#4894) * finish input and output logic * add generate * test forward * 1 * [inference]Re push async dynamic batching (#4901) * adapt to ray server * finish async * finish test * del test --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * Revert "[inference]Re push async dynamic batching (#4901)" (#4905) This reverts commit fbf3c09e673794ed18c91d4bab1a7dfea052e95a. * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Revert "[inference] Async dynamic batching (#4894)" (#4909) This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * [infer]Add Ray Distributed Environment Init Scripts (#4911) * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. 
* Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * support dynamic batch for bloom model and is_running function * [Inference]Test for new Async engine (#4935) * infer engine * infer engine * test engine * test engine * new manager * change step * add * test * fix * fix * finish test * finish test * finish test * finish test * add license --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * add assertion for config (#4947) * [Inference] Finish dynamic batching offline test (#4948) * test * fix test * fix quant * add default * fix * fix some bugs * fix some bugs * fix * fix bug * fix bugs * reset param --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497outlook.com> * [Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding for llama token attention (#4965) * adding flash-decoding * clean * adding kernel * adding flash-decoding * add integration * add * adding kernel * adding kernel * adding triton 2.1.0 features for inference * update bloom triton kernel * remove useless vllm kernels * clean codes * fix * adding files * fix readme * update llama flash-decoding --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> * fix ColossalEval (#4992) Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> * [doc]Update doc for colossal-inference (#4989) * update doc * Update README.md --------- Co-authored-by: cuiqing.li <lixx336@gmail.com> * [hotfix] Fix the bug where process groups were not being properly released. (#4940) * Fix the bug where process groups were not being properly released. 
* test * Revert "test" This reverts commit 479900c1398637310abf92eefa3cd168038ea02f. * [hotfix] fix the bug of repeatedly storing param group (#4951) * [doc] add supported feature diagram for hybrid parallel plugin (#4996) * [Pipeline Inference] Merge pp with tp (#4993) * refactor pipeline into new CaiInferEngine * updata llama modeling forward * merge tp with pp * update docstring * optimize test workflow and example * fix typo * add assert and todo * [release] update version (#4995) * [release] update version * [hotfix] fix ci * [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp * fix fix fix * update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO * support fused layernorm support fused layernorm support fused layernorm * update fusedlayernorm update fusedlayernorm update fusedlayernorm * add sequence parallel to gemini add sequence parallel to gemini * fix * fix comments fix comments fix comments * fix * fix t5 * clear cache * fix * activate ci * activate ci * fix * fix * fix * fix * revert * modify tp gather method modify tp gather method modify tp gather method modify tp gather method * fix test --------- Co-authored-by: Xu Kai <xukai16@foxmail.com> Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: cuiqing.li <lixx336@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com> Co-authored-by: Xu Kai 
<xukai16@foxamil.com> Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com> Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> Co-authored-by: littsk <1214689160@qq.com> Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
1 year ago
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size)
torch.cuda.empty_cache()
if err is None:
passed_models.append(name)
else:
failed_info[name] = err
if early_stop:
break
if dist.get_rank() == 0:
print(f"Init method: {init_method}")
print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
def run_dist(rank, world_size, port, early_stop: bool = True):
    """Set up the distributed environment, then run the Gemini plugin check.

    Args:
        rank: Rank handed to ``colossalai.launch``.
        world_size: World size handed to ``colossalai.launch``.
        port: Port handed to ``colossalai.launch``; the host is fixed to
            ``"localhost"``.
        early_stop: Forwarded unchanged to ``check_gemini_plugin``.
    """
    # The distributed environment must be initialised before the check runs.
    launch_kwargs = dict(
        config=dict(),
        rank=rank,
        world_size=world_size,
        port=port,
        host="localhost",
    )
    colossalai.launch(**launch_kwargs)
    check_gemini_plugin(early_stop=early_stop)
@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
    """Run the Gemini plugin check by passing ``run_dist`` to ``spawn``.

    Args:
        early_stop: Forwarded to ``run_dist`` as a keyword argument.
    """
    # NOTE(review): 4 is presumably the number of processes ``spawn`` starts;
    # confirm against the ``spawn`` helper's signature.
    num_procs = 4
    spawn(run_dist, num_procs, early_stop=early_stop)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_gemini_plugin_3d(early_stop: bool = True):
    """Run the Gemini plugin check via ``spawn`` with ``8`` instead of ``4``.

    Carries the ``largedist`` pytest marker.

    Args:
        early_stop: Forwarded to ``run_dist`` as a keyword argument.
    """
    # NOTE(review): 8 is presumably the number of processes ``spawn`` starts;
    # confirm against the ``spawn`` helper's signature.
    num_procs = 8
    spawn(run_dist, num_procs, early_stop=early_stop)
if __name__ == "__main__":
    # Direct script execution runs the 4-way test with early_stop disabled.
    test_gemini_plugin(early_stop=False)