Commit Graph

3111 Commits (aabc9fb6aada9e7feb2ff8cf1f34e6ac37ade2e7)

Author SHA1 Message Date
Runyu Lu aabc9fb6aa [feat] add use_cuda_kernel option 2024-03-19 13:24:25 +08:00
Runyu Lu 6e30248683 [fix] tmp for test 2024-03-14 16:13:00 +08:00
Runyu Lu d02e257abd
Merge branch 'feature/colossal-infer' into colossal-infer-cuda-graph 2024-03-14 10:37:05 +08:00
Runyu Lu ae24b4f025 diverse tests 2024-03-14 10:35:08 +08:00
Runyu Lu 1821a6dab0 [fix] pytest and fix dyn grid bug 2024-03-13 17:28:32 +08:00
yuehuayingxueluo f366a5ea1f
[Inference/kernel]Add Fused Rotary Embedding and KVCache Memcopy CUDA Kernel (#5418)
* add rotary embedding kernel

* add rotary_embedding_kernel

* add fused rotary_emb and kvcache memcopy

* add fused_rotary_emb_and_cache_kernel.cu

* add fused_rotary_emb_and_memcopy

* fix bugs in fused_rotary_emb_and_cache_kernel.cu

* fix ci bugs

* use vec memcopy and opt the  gloabl memory access

* fix code style

* fix test_rotary_embdding_unpad.py

* codes revised based on the review comments

* fix bugs about include path

* rm inline
2024-03-13 17:20:03 +08:00
Steve Luo ed431de4e4
fix rmsnorm template function invocation problem(template function partial specialization is not allowed in Cpp) and luckily pass e2e precision test (#5454) 2024-03-13 16:00:55 +08:00
傅剑寒 6fd355a5a6
Merge pull request #5452 from Courtesy-Xs/fix_include_path
fix include path
2024-03-13 11:26:41 +08:00
xs_courtesy c1c45e9d8e fix include path 2024-03-13 11:21:06 +08:00
Steve Luo b699f54007
optimize rmsnorm: add vectorized elementwise op, feat loop unrolling (#5441) 2024-03-12 17:48:02 +08:00
傅剑寒 368a2aa543
Merge pull request #5445 from Courtesy-Xs/refactor_infer_compilation
Refactor colossal-infer code arch
2024-03-12 14:14:37 +08:00
xs_courtesy 095c070a6e refactor code 2024-03-11 17:06:57 +08:00
傅剑寒 21e1e3645c
Merge pull request #5435 from Courtesy-Xs/add_gpu_launch_config
Add query and other components
2024-03-11 11:15:29 +08:00
Runyu Lu 633e95b301 [doc] add doc 2024-03-11 10:56:51 +08:00
Runyu Lu 9dec66fad6 [fix] multi graphs capture error 2024-03-11 10:51:16 +08:00
Runyu Lu b2c0d9ff2b [fix] multi graphs capture error 2024-03-11 10:49:31 +08:00
Steve Luo f7aecc0c6b
feat rmsnorm cuda kernel and add unittest, benchmark script (#5417) 2024-03-08 16:21:12 +08:00
xs_courtesy 5eb5ff1464 refactor code 2024-03-08 15:41:14 +08:00
xs_courtesy 01d289d8e5 Merge branch 'feature/colossal-infer' of https://github.com/hpcaitech/ColossalAI into add_gpu_launch_config 2024-03-08 15:04:55 +08:00
xs_courtesy a46598ac59 add reusable utils for cuda 2024-03-08 14:53:29 +08:00
傅剑寒 2b28b54ac6
Merge pull request #5433 from Courtesy-Xs/add_silu_and_mul
【Inference】Add silu_and_mul for infer
2024-03-08 14:44:37 +08:00
Runyu Lu cefaeb5fdd [feat] cuda graph support and refactor non-functional api 2024-03-08 14:19:35 +08:00
xs_courtesy 95c21498d4 add silu_and_mul for infer 2024-03-07 16:57:49 +08:00
Frank Lee 593a72e4d5
Merge pull request #5424 from FrankLeeeee/sync/main
Sync/main
2024-03-04 10:13:59 +08:00
FrankLeeeee 0310b76e9d Merge branch 'main' into sync/main 2024-03-04 10:09:36 +08:00
Camille Zhong 4b8312c08e
fix sft single turn inference example (#5416) 2024-03-01 17:27:50 +08:00
binmakeswell a1c6cdb189 [doc] fix blog link 2024-02-29 15:01:43 +08:00
binmakeswell 5de940de32 [doc] fix blog link 2024-02-29 15:01:43 +08:00
Frank Lee 2461f37886
[workflow] added pypi channel (#5412) 2024-02-29 13:56:55 +08:00
Tong Li a28c971516
update requirements (#5407) 2024-02-28 17:46:27 +08:00
yuehuayingxueluo 0aa27f1961
[Inference]Move benchmark-related code to the example directory. (#5408)
* move benchmark-related code to the example directory.

* fix bugs in test_fused_rotary_embedding.py
2024-02-28 16:46:03 +08:00
yuehuayingxueluo 600881a8ea
[Inference]Add CUDA KVCache Kernel (#5406)
* add cuda KVCache kernel

* annotation benchmark_kvcache_copy

* add use cuda

* fix import path

* move benchmark scripts to example/

* rm benchmark codes in test_kv_cache_memcpy.py

* rm redundancy codes

* rm redundancy codes

* pr was modified according to the review
2024-02-28 14:36:50 +08:00
flybird11111 0a25e16e46
[shardformer]gather llama logits (#5398)
* gather llama logits

* fix
2024-02-27 22:44:07 +08:00
Frank Lee dcdd8a5ef7
[setup] fixed nightly release (#5388) 2024-02-27 15:19:13 +08:00
QinLuo bf34c6fef6
[fsdp] impl save/load shard model/optimizer (#5357) 2024-02-27 13:51:14 +08:00
Hongxin Liu d882d18c65
[example] reuse flash attn patch (#5400) 2024-02-27 11:22:07 +08:00
Hongxin Liu 95c21e3950
[extension] hotfix jit extension setup (#5402) 2024-02-26 19:46:58 +08:00
Yuanheng Zhao 19061188c3
[Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399)
fix dependency in pytest
2024-02-26 16:17:47 +08:00
yuehuayingxueluo bc1da87366
[Fix/Inference] Fix format of input prompts and input model in inference engine (#5395)
* Fix bugs in inference_engine

* fix bugs in engine.py

* rm  CUDA_VISIBLE_DEVICES

* add request_ids in generate

* fix bug in engine.py

* add logger.debug for BatchBucket
2024-02-23 10:51:35 +08:00
yuehuayingxueluo 2a718c8be8
Optimized the execution interval time between cuda kernels caused by view and memcopy (#5390)
* opt_view_and_memcopy

* fix bugs in ci

* fix ci bugs

* update benchmark scripts

* fix ci bugs
2024-02-21 13:23:57 +08:00
Jianghai 730103819d
[Inference]Fused kv copy into rotary calculation (#5383)
* revise rotary embedding

* remove useless print

* adapt

* fix

* add

* fix

* modeling

* fix

* fix

* fix

* fused kv copy

* fused copy

* colossalai/kernel/triton/no_pad_rotary_embedding.py

* del padding llama

* del
2024-02-21 11:31:48 +08:00
Stephan Kölker 5d380a1a21
[hotfix] Fix wrong import in meta_registry (#5392) 2024-02-20 19:24:43 +08:00
CZYCW b833153fd5
[hotfix] fix variable type for top_p (#5313)
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
2024-02-19 18:25:44 +08:00
Yuanheng Zhao b21aac5bae
[Inference] Optimize and Refactor Inference Batching/Scheduling (#5367)
* add kvcache manager funcs for batching

* add batch bucket for batching

* revise RunningList struct in handler

* add kvcache/batch funcs for compatibility

* use new batching methods

* fix indexing bugs

* revise abort logic

* use cpu seq lengths/block tables

* rm unused attr in Sequence

* fix type conversion/default arg

* add and revise pytests

* revise pytests, rm unused tests

* rm unused statements

* fix pop finished indexing issue

* fix: use index in batch when retrieving inputs/update seqs

* use dict instead of odict in batch struct

* arg type hinting

* fix make compress

* refine comments

* fix: pop_n_seqs to pop the first n seqs

* add check in request handler

* remove redundant conversion

* fix test for request handler

* fix pop method in batch bucket

* fix prefill adding
2024-02-19 17:18:20 +08:00
Frank Lee 705a62a565
[doc] updated installation command (#5389) 2024-02-19 16:54:03 +08:00
yixiaoer 69e3ad01ed
[doc] Fix typo (#5361) 2024-02-19 16:53:28 +08:00
Hongxin Liu 7303801854
[llama] fix training and inference scripts (#5384)
* [llama] refactor inference example to fit sft

* [llama] fix training script to fit gemini

* [llama] fix inference script
2024-02-19 16:41:04 +08:00
Hongxin Liu adae123df3
[release] update version (#5380) 2024-02-08 18:50:09 +08:00
Frank Lee efef43b53c
Merge pull request #5372 from hpcaitech/exp/mixtral 2024-02-08 16:30:05 +08:00
yuehuayingxueluo 8c69debdc7
[Inference]Support vllm testing in benchmark scripts (#5379)
* add vllm benchmark scripts

* fix code style

* update run_benchmark.sh

* fix code style
2024-02-08 15:27:26 +08:00