Commit Graph

56 Commits (62cdac6b7b655e11626382d64e56503146a516ee)

Author SHA1 Message Date
Yuanheng Zhao bd38fe6b91
[NFC] Fix code factors on inference triton kernels (#5743) 2024-05-21 22:12:15 +08:00
Yuanheng Zhao 537a3cbc4d
[kernel] Support New KCache Layout - Triton Kernel (#5677)
* kvmemcpy triton for new kcache layout

* revise tests for new kcache layout

* naive triton flash decoding - new kcache layout

* rotary triton kernel - new kcache layout

* remove redundancy - triton decoding

* remove redundancy - triton kvcache copy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-05-03 17:20:45 +08:00
Yuanheng Zhao 5be590b99e
[kernel] Support new KCache Layout - Context Attention Triton Kernel (#5658)
* add context attn triton kernel - new kcache layout

* add benchmark triton

* tiny revise

* trivial - code style, comment
2024-04-26 17:51:49 +08:00
yuehuayingxueluo 3c91e3f176
[Inference]Adapt to baichuan2 13B (#5614)
* adapt to baichuan2 13B

* adapt to baichuan2 13B

* change BAICHUAN_MODEL_NAME_OR_PATH

* fix test_decoding_attn.py

* Modifications based on review comments.

* change BAICHUAN_MODEL_NAME_OR_PATH

* mv attn mask processes to test flash decoding

* mv get_alibi_slopes baichuan modeling

* fix bugs in test_baichuan.py
2024-04-25 23:11:30 +08:00
Yuanheng Zhao 5d4c1fe8f5
[Fix/Inference] Fix GQA Triton and Support Llama3 (#5624)
* [fix] GQA calling of flash decoding triton

* fix kv cache alloc shape

* fix rotary triton - GQA

* fix sequence max length assigning

* Sequence max length logic

* fix scheduling and spec-dec

* skip without import error

* fix pytest - skip without ImportError

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-04-23 13:09:55 +08:00
Yuanheng Zhao a37f82629d [Inference/SpecDec] Add Speculative Decoding Implementation (#5423)
* fix flash decoding mask during verification

* add spec-dec

* add test for spec-dec

* revise drafter init

* remove drafter sampling

* retire past kv in drafter

* (trivial) rename attrs

* (trivial) rename arg

* revise how we enable/disable spec-dec
2024-04-10 11:07:52 +08:00
Yuanheng Zhao d63c469f45 [Infer] Revise and Adapt Triton Kernels for Spec-Dec (#5401)
* [Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399)

fix dependency in pytest

* resolve conflicts for revising flash-attn

* adapt kv cache copy kernel for spec-dec

* fix seqlen-n kvcache copy kernel/tests

* test kvcache copy - use torch.equal

* add assertions

* (trivial) comment out
2024-04-10 11:07:51 +08:00
Yuanheng 7ca1d1c545 remove outdated triton test 2024-04-08 17:00:55 +08:00
Yuanheng ce9401ad52 remove unused triton kernels 2024-04-08 16:25:12 +08:00
Yuanheng ed5ebd1735 [Fix] resolve conflicts of merging main 2024-04-08 16:21:47 +08:00
Hongxin Liu 641b1ee71a
[devops] remove post commit ci (#5566)
* [devops] remove post commit ci

* [misc] run pre-commit on all files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-04-08 15:09:40 +08:00
Runyu Lu b2c0d9ff2b [fix] multi graphs capture error 2024-03-11 10:49:31 +08:00
Runyu Lu cefaeb5fdd [feat] cuda graph support and refactor non-functional api 2024-03-08 14:19:35 +08:00
yuehuayingxueluo 2a718c8be8
Optimized the execution interval time between cuda kernels caused by view and memcopy (#5390)
* opt_view_and_memcopy

* fix bugs in ci

* fix ci bugs

* update benchmark scripts

* fix ci bugs
2024-02-21 13:23:57 +08:00
Jianghai 730103819d
[Inference]Fused kv copy into rotary calculation (#5383)
* revise rotary embedding

* remove useless print

* adapt

* fix

* add

* fix

* modeling

* fix

* fix

* fix

* fused kv copy

* fused copy

* colossalai/kernel/triton/no_pad_rotary_embedding.py

* del padding llama

* del
2024-02-21 11:31:48 +08:00
yuehuayingxueluo 6fb4bcbb24
[Inference/opt] Fused KVCahce Memcopy (#5374)
* fused kv memcopy

* add TODO in test_kvcache_copy.py
2024-02-07 17:15:42 +08:00
Frank Lee 8106ede07f
Revert "[Inference] Adapt to Fused rotary (#5348)" (#5373)
This reverts commit 9f4ab2eb92.
2024-02-07 14:27:04 +08:00
Jianghai 9f4ab2eb92
[Inference] Adapt to Fused rotary (#5348)
* revise rotary embedding

* remove useless print

* adapt

* fix

* add

* fix

* modeling

* fix

* fix

* fix
2024-02-07 11:36:04 +08:00
yuehuayingxueluo 35382a7fbf
[Inference]Fused the gate and up proj in mlp,and optimized the autograd process. (#5365)
* fused the gate and up proj in mlp

* fix code styles

* opt auto_grad

* rollback test_inference_engine.py

* modifications based on the review feedback.

* fix bugs in flash attn

* Change reshape to view

* fix test_rmsnorm_triton.py
2024-02-06 19:38:25 +08:00
yuehuayingxueluo 21ad4a27f9
[Inference/opt]Optimize the mid tensor of RMS Norm (#5350)
* opt rms_norm

* fix bugs in rms_layernorm
2024-02-02 15:06:01 +08:00
yuehuayingxueluo 249644c23b
[Inference]Repalce Attention layer and MLP layer by shardformer to optimize the weight transpose operation,add fused_qkv and fused linear_add (#5340)
* add fused qkv

* replace attn and mlp by shardformer

* fix bugs in mlp

* add docstrings

* fix test_inference_engine.py

* add optimize unbind

* add fused_addmm

* rm squeeze(1)

* refactor codes

* fix ci bugs

* rename ShardFormerLlamaMLP and ShardFormerLlamaAttention

* Removed the dependency on LlamaFlashAttention2

* rollback test_inference_engine.py
2024-02-01 15:49:39 +08:00
Jianghai df0aa49585
[Inference] Kernel Fusion, fused copy kv cache into rotary embedding (#5336)
* revise rotary embedding

* remove useless print

* adapt
2024-01-31 16:31:29 +08:00
Yuanheng Zhao 5f98a9d68a
[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)
* revise shape of kvcache (context attn kernel)

* revise shape of kvcache (flash decoding kernel)

* revise shape of kvcache (kvcache copy) and attn func

* init of kvcache in kvcache manager

* revise llama modeling

* revise block size retrieval

* use torch for rms_norm benchmarking

* revise block size retrieval
2024-01-30 16:06:09 +08:00
Jianghai 1f8a75d470
[Inference] Update rms norm kernel, benchmark with vLLM (#5315)
* add

* xi

* del

* del

* fix
2024-01-29 10:22:33 +08:00
Jianghai 7ddd8b37f0
fix (#5311) 2024-01-26 15:02:12 +08:00
yuehuayingxueluo 4f28cb43c0
[inference]Optimize the usage of the mid tensors space in flash attn (#5304)
* opt flash attn

* opt tmp tensor

* fix benchmark_llama

* fix code style

* fix None logic for output tensor

* fix adapted to get_xine_cache

* add comment

* fix ci bugs

* fix some codes

* rm duplicated codes

* rm duplicated codes

* fix code style

* add _get_dtype in config.py
2024-01-26 14:00:10 +08:00
Yuanheng Zhao af8359c430
[hotfix] fix boundary check in batch (#5306) 2024-01-25 10:23:12 +08:00
Jianghai c647e00e3c
[Inference]Add fused rotary kernel and get cos cache kernel (#5302)
* add fused rotary and get cos cache func

* staged

* fix bugs

* fix bugs
2024-01-24 16:20:42 +08:00
Yuanheng Zhao 3da9993b0d
[Kernel/Fix] Revise flash attention triton kernel API and add benchmark (#5301)
* fix decoding kernel pytest

* revise and add triton context attn benchmark
2024-01-23 17:16:02 +08:00
yuehuayingxueluo bfff9254ac
[inference] Adapted to Rotary Embedding and RMS Norm (#5283)
* adapted to rotary_embedding

* adapted to nopad rms norm

* fix bugs in benchmark

* fix flash_decoding.py
2024-01-22 10:55:34 +08:00
Yuanheng Zhao 6e487e7d3c
[kernel/fix] Performance Optimization for Decoding Kernel and Benchmarking (#5274)
* prevent re-creating intermediate tensors

* add singleton class holding intermediate values

* fix triton kernel api

* add benchmark in pytest

* fix kernel api and add benchmark

* revise flash decoding triton kernel in/out shapes

* fix calling of triton kernel in modeling

* fix pytest: extract to util functions
2024-01-19 15:47:16 +08:00
Yaozheng Fang 5ae9099f92
[kernel] Add RMSLayerNorm triton kernel (#5262)
* add layerrmsnorm triton kernel

* add layerrmsnorm kernel

* modify the atol and rtol in test file

* Remove the logics of mean computations, and update the name of ther kernel functions and files

* add benchmark of rms norm
2024-01-18 10:21:03 +08:00
Yuanheng Zhao 0f2b46a41c
[kernel] Revise KVCache copy triton kernel API (#5273)
* [kernel/fix] revise kvcache copy kernel api

* fix benchmark
2024-01-16 14:41:02 +08:00
Yuanheng Zhao fa85e02b3b
[kernel] Add KV cache copy kernel during decoding (#5261)
* add kv copy triton kernel during decoding stage

* add pytest and fix kernel

* fix test utilities

* revise kernel config

* add benchmark for kvcache copy
2024-01-15 17:37:20 +08:00
Yuanheng Zhao 1513f20f4d [kernel] Add flash decoding triton kernel for blocked kv cache (#5249)
* add flash decoding unpad triton kernel

* rename flash decoding kernel

* add kernel testing (draft)

* revise pytest

* support kv group (GQA)

* (trivial) fix api and pytest

* (trivial) func renaming

* (trivial) func/file renaming

* refactor pytest for attention

* (trivial) format and consistent vars of context/decode attn

* (trivial) remove test redundancy
2024-01-11 13:46:14 +00:00
Jianghai fded91d049 [Inference] Kernel: no pad rotary embedding (#5252)
* fix bugs

* comment

* use more accurate atol

* fix
2024-01-11 13:46:14 +00:00
Yuanheng Zhao 07b5283b6a [kernel] Add triton kernel for context attention (FAv2) without padding (#5192)
* add context attn unpadded triton kernel

* test compatibility

* kv cache copy (testing)

* fix k/v cache copy

* fix kv cache copy and test

* fix boundary of block ptrs

* add support for GQA/MQA and testing

* fix import statement

---------

Co-authored-by: Round Heng <yuanhengzhao@Rounds-MacBook-Pro.local>
2024-01-11 13:39:56 +00:00
Yuanheng Zhao 2bb92243d4 [Inference/NFC] Clean outdated inference tests and deprecated kernels (#5159)
* [inference/nfc] remove outdated inference tests

* remove outdated kernel tests

* remove deprecated triton kernels

* remove imports from deprecated kernels
2024-01-11 13:39:29 +00:00
Cuiqing Li (李崔卿) bce919708f
[Kernels]added flash-decoidng of triton (#5063)
* added flash-decoidng of triton based on lightllm kernel

* add req

* clean

* clean

* delete build.sh

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
2023-11-20 13:58:29 +08:00
Cuiqing Li (李崔卿) 28052a71fb
[Kernels]Update triton kernels into 2.1.0 (#5046)
* update flash-context-attention

* adding kernels

* fix

* reset

* add build script

* add building process

* add llama2 exmaple

* add colossal-llama2 test

* clean

* fall back test setting

* fix test file

* clean

* clean

* clean

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
2023-11-16 16:43:15 +08:00
Xuanlei Zhao dc003c304c
[moe] merge moe into main (#4978)
* update moe module
* support openmoe
2023-11-02 02:21:24 +00:00
Cuiqing Li 459a88c806
[Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding for llama token attention (#4965)
* adding flash-decoding

* clean

* adding kernel

* adding flash-decoding

* add integration

* add

* adding kernel

* adding kernel

* adding triton 2.1.0 features for inference

* update bloom triton kernel

* remove useless vllm kernels

* clean codes

* fix

* adding files

* fix readme

* update llama flash-decoding

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
2023-10-30 14:04:37 +08:00
Jianghai cf579ff46d
[Inference] Dynamic Batching Inference, online and offline (#4953)
* [inference] Dynamic Batching for Single and Multiple GPUs (#4831)

* finish batch manager

* 1

* first

* fix

* fix dynamic batching

* llama infer

* finish test

* support different lengths generating

* del prints

* del prints

* fix

* fix bug

---------

Co-authored-by: CjhHa1 <cjh18671720497outlook.com>

* [inference] Async dynamic batching  (#4894)

* finish input and output logic

* add generate

* test forward

* 1

* [inference]Re push async dynamic batching (#4901)

* adapt to ray server

* finish async

* finish test

* del test

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* Revert "[inference]Re push async dynamic batching (#4901)" (#4905)

This reverts commit fbf3c09e67.

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140250.

* Revert "[inference] Async dynamic batching  (#4894)" (#4909)

This reverts commit fced140250.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* [infer]Add Ray Distributed Environment Init Scripts (#4911)

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140250.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* support dynamic batch for bloom model and is_running function

* [Inference]Test for new Async engine (#4935)

* infer engine

* infer engine

* test engine

* test engine

* new manager

* change step

* add

* test

* fix

* fix

* finish test

* finish test

* finish test

* finish test

* add license

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* add assertion for config (#4947)

* [Inference] Finish dynamic batching offline test (#4948)

* test

* fix test

* fix quant

* add default

* fix

* fix some bugs

* fix some bugs

* fix

* fix bug

* fix bugs

* reset param

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
2023-10-30 10:52:19 +08:00
Xu Kai 785802e809
[inference] add reference and fix some bugs (#4937)
* add reference and fix some bugs

* update gptq init

---------

Co-authored-by: Xu Kai <xukai16@foxamil.com>
2023-10-20 13:39:34 +08:00
Cuiqing Li 3a41e8304e
[Refactor] Integrated some lightllm kernels into token-attention (#4946)
* add some req for inference

* clean codes

* add codes

* add some lightllm deps

* clean codes

* hello

* delete rms files

* add some comments

* add comments

* add doc

* add lightllm deps

* add lightllm cahtglm2 kernels

* add lightllm cahtglm2 kernels

* replace rotary embedding with lightllm kernel

* add some commnets

* add some comments

* add some comments

* add

* replace fwd kernel att1

* fix a arg

* add

* add

* fix token attention

* add some comments

* clean codes

* modify comments

* fix readme

* fix bug

* fix bug

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
2023-10-19 22:22:47 +08:00
Xu Kai 611a5a80ca
[inference] Add smmoothquant for llama (#4904)
* [inference] add int8 rotary embedding kernel for smoothquant (#4843)

* [inference] add smoothquant llama attention (#4850)

* add smoothquant llama attention

* remove uselss code

* remove useless code

* fix import error

* rename file name

* [inference] add silu linear fusion for smoothquant llama mlp  (#4853)

* add silu linear

* update skip condition

* catch smoothquant cuda lib exception

* prcocess exception for tests

* [inference] add llama mlp for smoothquant (#4854)

* add llama mlp for smoothquant

* fix down out scale

* remove duplicate lines

* add llama mlp check

* delete useless code

* [inference] add smoothquant llama (#4861)

* add smoothquant llama

* fix attention accuracy

* fix accuracy

* add kv cache and save pretrained

* refactor example

* delete smooth

* refactor code

* [inference] add smooth function and delete useless code for smoothquant (#4895)

* add smooth function and delete useless code

* update datasets

* remove duplicate import

* delete useless file

* refactor codes (#4902)

* rafactor code

* add license

* add torch-int and smoothquant license
2023-10-16 11:28:44 +08:00
Xu Kai 77a9328304
[inference] add llama2 support (#4898)
* add llama2 support

* fix multi group bug
2023-10-13 13:09:23 +08:00
Jianghai 013a4bedf0
[inference]fix import bug and delete down useless init (#4830)
* fix import bug and release useless init

* fix

* fix

* fix
2023-10-04 09:18:45 +08:00
Xu Kai c3bef20478
add autotune (#4822) 2023-09-28 13:47:35 +08:00
Jianghai ce7ade3882
[inference] chatglm2 infer demo (#4724)
* add chatglm2

* add

* gather needed kernels

* fix some bugs

* finish context forward

* finish context stage

* fix

* add

* pause

* add

* fix bugs

* finish chatglm

* fix bug

* change some logic

* fix bugs

* change some logics

* add

* add

* add

* fix

* fix tests

* fix
2023-09-22 11:12:50 +08:00