Commit Graph

110 Commits (a37f82629d7b9e3c3a0f430b8dd3ff6f38ddf1d4)

Author SHA1 Message Date
Yuanheng Zhao a37f82629d [Inference/SpecDec] Add Speculative Decoding Implementation (#5423)
* fix flash decoding mask during verification

* add spec-dec

* add test for spec-dec

* revise drafter init

* remove drafter sampling

* retire past kv in drafter

* (trivial) rename attrs

* (trivial) rename arg

* revise how we enable/disable spec-dec
2024-04-10 11:07:52 +08:00
Yuanheng Zhao 5a9b05f7b2 [Inference/SpecDec] Add Basic Drafter Model Container (#5405)
* [Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399)

fix dependency in pytest

* add drafter model container (basic ver)
2024-04-10 11:07:51 +08:00
Yuanheng Zhao 4bb5d8923a
[Fix/Inference] Remove unused and non-functional functions (#5543)
* [fix] remove unused func

* rm non-functional partial
2024-04-02 14:16:59 +08:00
yuehuayingxueluo 04aca9e55b
[Inference/Kernel]Add get_cos_and_sin Kernel (#5528)
* Add get_cos_and_sin kernel

* fix code comments

* fix code typos

* merge common codes of get_cos_and_sin kernel.

* Fixed a typo

* Changed 'asset allclose' to 'assert equal'.
2024-04-01 13:47:14 +08:00
傅剑寒 e6496dd371
[Inference] Optimize request handler of llama (#5512)
* optimize request_handler

* fix ways of writing
2024-03-26 16:37:14 +08:00
Runyu Lu 6251d68dc9
[fix] PR #5354 (#5501)
* [fix]

* [fix]

* Update config.py docstring

* [fix] docstring align

* [fix] docstring align

* [fix] docstring align
2024-03-25 15:24:17 +08:00
Runyu Lu 68e9396bc0 [fix] merge conflicts 2024-03-25 14:48:28 +08:00
yuehuayingxueluo 87079cffe8
[Inference]Support FP16/BF16 Flash Attention 2 And Add high_precision Flag To Rotary Embedding (#5461)
* Support FP16/BF16 Flash Attention 2

* fix bugs in test_kv_cache_memcpy.py

* add context_kv_cache_memcpy_kernel.cu

* rm typename MT

* add tail process

* add high_precision

* add high_precision to config.py

* rm unused code

* change the comment for the high_precision parameter

* update test_rotary_embdding_unpad.py

* fix vector_copy_utils.h

* add comment for self.high_precision when using float32
2024-03-25 13:40:34 +08:00
Runyu Lu ff4998c6f3 [fix] remove unused comment 2024-03-25 12:00:57 +08:00
Runyu Lu 5b017d6324 [fix] 2024-03-21 15:55:25 +08:00
Runyu Lu 4eafe0c814 [fix] unused option 2024-03-21 11:28:42 +08:00
Runyu Lu aabc9fb6aa [feat] add use_cuda_kernel option 2024-03-19 13:24:25 +08:00
Runyu Lu 6e30248683 [fix] tmp for test 2024-03-14 16:13:00 +08:00
Runyu Lu d02e257abd
Merge branch 'feature/colossal-infer' into colossal-infer-cuda-graph 2024-03-14 10:37:05 +08:00
Runyu Lu ae24b4f025 diverse tests 2024-03-14 10:35:08 +08:00
Runyu Lu 1821a6dab0 [fix] pytest and fix dyn grid bug 2024-03-13 17:28:32 +08:00
yuehuayingxueluo f366a5ea1f
[Inference/kernel]Add Fused Rotary Embedding and KVCache Memcopy CUDA Kernel (#5418)
* add rotary embedding kernel

* add rotary_embedding_kernel

* add fused rotary_emb and kvcache memcopy

* add fused_rotary_emb_and_cache_kernel.cu

* add fused_rotary_emb_and_memcopy

* fix bugs in fused_rotary_emb_and_cache_kernel.cu

* fix ci bugs

* use vec memcopy and opt the  gloabl memory access

* fix code style

* fix test_rotary_embdding_unpad.py

* codes revised based on the review comments

* fix bugs about include path

* rm inline
2024-03-13 17:20:03 +08:00
Runyu Lu 633e95b301 [doc] add doc 2024-03-11 10:56:51 +08:00
Runyu Lu 9dec66fad6 [fix] multi graphs capture error 2024-03-11 10:51:16 +08:00
Runyu Lu b2c0d9ff2b [fix] multi graphs capture error 2024-03-11 10:49:31 +08:00
Steve Luo f7aecc0c6b
feat rmsnorm cuda kernel and add unittest, benchmark script (#5417) 2024-03-08 16:21:12 +08:00
Runyu Lu cefaeb5fdd [feat] cuda graph support and refactor non-functional api 2024-03-08 14:19:35 +08:00
yuehuayingxueluo 600881a8ea
[Inference]Add CUDA KVCache Kernel (#5406)
* add cuda KVCache kernel

* annotation benchmark_kvcache_copy

* add use cuda

* fix import path

* move benchmark scripts to example/

* rm benchmark codes in test_kv_cache_memcpy.py

* rm redundancy codes

* rm redundancy codes

* pr was modified according to the review
2024-02-28 14:36:50 +08:00
yuehuayingxueluo bc1da87366
[Fix/Inference] Fix format of input prompts and input model in inference engine (#5395)
* Fix bugs in inference_engine

* fix bugs in engine.py

* rm  CUDA_VISIBLE_DEVICES

* add request_ids in generate

* fix bug in engine.py

* add logger.debug for BatchBucket
2024-02-23 10:51:35 +08:00
yuehuayingxueluo 2a718c8be8
Optimized the execution interval time between cuda kernels caused by view and memcopy (#5390)
* opt_view_and_memcopy

* fix bugs in ci

* fix ci bugs

* update benchmark scripts

* fix ci bugs
2024-02-21 13:23:57 +08:00
Jianghai 730103819d
[Inference]Fused kv copy into rotary calculation (#5383)
* revise rotary embedding

* remove useless print

* adapt

* fix

* add

* fix

* modeling

* fix

* fix

* fix

* fused kv copy

* fused copy

* colossalai/kernel/triton/no_pad_rotary_embedding.py

* del padding llama

* del
2024-02-21 11:31:48 +08:00
Yuanheng Zhao b21aac5bae
[Inference] Optimize and Refactor Inference Batching/Scheduling (#5367)
* add kvcache manager funcs for batching

* add batch bucket for batching

* revise RunningList struct in handler

* add kvcache/batch funcs for compatibility

* use new batching methods

* fix indexing bugs

* revise abort logic

* use cpu seq lengths/block tables

* rm unused attr in Sequence

* fix type conversion/default arg

* add and revise pytests

* revise pytests, rm unused tests

* rm unused statements

* fix pop finished indexing issue

* fix: use index in batch when retrieving inputs/update seqs

* use dict instead of odict in batch struct

* arg type hinting

* fix make compress

* refine comments

* fix: pop_n_seqs to pop the first n seqs

* add check in request handler

* remove redundant conversion

* fix test for request handler

* fix pop method in batch bucket

* fix prefill adding
2024-02-19 17:18:20 +08:00
yuehuayingxueluo 8c69debdc7
[Inference]Support vllm testing in benchmark scripts (#5379)
* add vllm benchmark scripts

* fix code style

* update run_benchmark.sh

* fix code style
2024-02-08 15:27:26 +08:00
Frank Lee 9afa52061f
[inference] refactored config (#5376) 2024-02-08 14:04:14 +08:00
Jianghai 1f8c7e7046
[Inference] User Experience: update the logic of default tokenizer and generation config. (#5337)
* add

* fix

* fix

* pause

* fix

* fix pytest

* align

* fix

* license

* fix

* fix

* fix readme

* fix some bugs

* remove tokenizer config
2024-02-07 17:55:48 +08:00
yuehuayingxueluo 6fb4bcbb24
[Inference/opt] Fused KVCahce Memcopy (#5374)
* fused kv memcopy

* add TODO in test_kvcache_copy.py
2024-02-07 17:15:42 +08:00
Frank Lee 58740b5f68
[inference] added inference template (#5375) 2024-02-07 17:11:43 +08:00
Frank Lee 8106ede07f
Revert "[Inference] Adapt to Fused rotary (#5348)" (#5373)
This reverts commit 9f4ab2eb92.
2024-02-07 14:27:04 +08:00
Jianghai 9f4ab2eb92
[Inference] Adapt to Fused rotary (#5348)
* revise rotary embedding

* remove useless print

* adapt

* fix

* add

* fix

* modeling

* fix

* fix

* fix
2024-02-07 11:36:04 +08:00
yuehuayingxueluo 35382a7fbf
[Inference]Fused the gate and up proj in mlp,and optimized the autograd process. (#5365)
* fused the gate and up proj in mlp

* fix code styles

* opt auto_grad

* rollback test_inference_engine.py

* modifications based on the review feedback.

* fix bugs in flash attn

* Change reshape to view

* fix test_rmsnorm_triton.py
2024-02-06 19:38:25 +08:00
Yuanheng Zhao 1dedb57747
[Fix/Infer] Remove unused deps and revise requirements (#5341)
* remove flash-attn dep

* rm padding llama

* revise infer requirements

* move requirements out of module
2024-02-06 17:27:45 +08:00
yuehuayingxueluo 631862f339
[Inference]Optimize generation process of inference engine (#5356)
* opt inference engine

* fix run_benchmark.sh

* fix generate in engine.py

* rollback tesh_inference_engine.py
2024-02-02 15:38:21 +08:00
yuehuayingxueluo 21ad4a27f9
[Inference/opt]Optimize the mid tensor of RMS Norm (#5350)
* opt rms_norm

* fix bugs in rms_layernorm
2024-02-02 15:06:01 +08:00
Frank Lee 027aa1043f
[doc] updated inference readme (#5343) 2024-02-02 14:31:10 +08:00
Frank Lee db1a763307
[inference] removed redundancy init_batch (#5353) 2024-02-02 11:44:15 +08:00
yuehuayingxueluo 249644c23b
[Inference]Repalce Attention layer and MLP layer by shardformer to optimize the weight transpose operation,add fused_qkv and fused linear_add (#5340)
* add fused qkv

* replace attn and mlp by shardformer

* fix bugs in mlp

* add docstrings

* fix test_inference_engine.py

* add optimize unbind

* add fused_addmm

* rm squeeze(1)

* refactor codes

* fix ci bugs

* rename ShardFormerLlamaMLP and ShardFormerLlamaAttention

* Removed the dependency on LlamaFlashAttention2

* rollback test_inference_engine.py
2024-02-01 15:49:39 +08:00
Frank Lee f8e456d202
[inference] simplified config verification (#5346)
* [inference] simplified config verification

* polish

* polish
2024-02-01 15:31:01 +08:00
Yuanheng Zhao 5f98a9d68a
[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)
* revise shape of kvcache (context attn kernel)

* revise shape of kvcache (flash decoding kernel)

* revise shape of kvcache (kvcache copy) and attn func

* init of kvcache in kvcache manager

* revise llama modeling

* revise block size retrieval

* use torch for rms_norm benchmarking

* revise block size retrieval
2024-01-30 16:06:09 +08:00
yuehuayingxueluo e8f0642f28
[Inference]Add Nopadding Llama Modeling (#5327)
* add nopadding llama modeling

* add nopadding_llama.py

* rm unused codes

* fix bugs in test_xine_copy.py

* fix code style
2024-01-30 10:31:46 +08:00
Jianghai c7c104cb7c
[DOC] Update inference readme (#5280)
* add readme

* add readme

* 1

* update engine

* finish readme

* add readme
2024-01-29 16:21:06 +08:00
yuehuayingxueluo 4f28cb43c0
[inference]Optimize the usage of the mid tensors space in flash attn (#5304)
* opt flash attn

* opt tmp tensor

* fix benchmark_llama

* fix code style

* fix None logic for output tensor

* fix adapted to get_xine_cache

* add comment

* fix ci bugs

* fix some codes

* rm duplicated codes

* rm duplicated codes

* fix code style

* add _get_dtype in config.py
2024-01-26 14:00:10 +08:00
Yuanheng Zhao 3da9993b0d
[Kernel/Fix] Revise flash attention triton kernel API and add benchmark (#5301)
* fix decoding kernel pytest

* revise and add triton context attn benchmark
2024-01-23 17:16:02 +08:00
yuehuayingxueluo cea9c86e45 add utils.py 2024-01-22 16:06:27 +08:00
yuehuayingxueluo bfff9254ac
[inference] Adapted to Rotary Embedding and RMS Norm (#5283)
* adapted to rotary_embedding

* adapted to nopad rms norm

* fix bugs in benchmark

* fix flash_decoding.py
2024-01-22 10:55:34 +08:00
Yuanheng Zhao 6e487e7d3c
[kernel/fix] Performance Optimization for Decoding Kernel and Benchmarking (#5274)
* prevent re-creating intermediate tensors

* add singleton class holding intermediate values

* fix triton kernel api

* add benchmark in pytest

* fix kernel api and add benchmark

* revise flash decoding triton kernel in/out shapes

* fix calling of triton kernel in modeling

* fix pytest: extract to util functions
2024-01-19 15:47:16 +08:00