Commit Graph

14 Commits (5f8c0a0ac3b52a71b664c3e36dd1a8cef40f428d)

Author SHA1 Message Date
yuehuayingxueluo 5f00002e43
[Inference] Adapt Baichuan2-13B TP (#5659)
* adapt to baichuan2 13B

* add baichuan2 13B TP

* update baichuan tp logic

* rm unused code

* Fix TP logic

* fix alibi slopes tp logic

* rm nn.Module

* Polished the code.

* change BAICHUAN_MODEL_NAME_OR_PATH

* Modified the logic for loading Baichuan weights.

* fix typos
2024-04-30 15:47:07 +08:00
yuehuayingxueluo 35382a7fbf
[Inference]Fused the gate and up proj in mlp,and optimized the autograd process. (#5365)
* fused the gate and up proj in mlp

* fix code styles

* opt auto_grad

* rollback test_inference_engine.py

* modifications based on the review feedback.

* fix bugs in flash attn

* Change reshape to view

* fix test_rmsnorm_triton.py
2024-02-06 19:38:25 +08:00
Frank Lee 027aa1043f
[doc] updated inference readme (#5343) 2024-02-02 14:31:10 +08:00
Yuanheng Zhao 5f98a9d68a
[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)
* revise shape of kvcache (context attn kernel)

* revise shape of kvcache (flash decoding kernel)

* revise shape of kvcache (kvcache copy) and attn func

* init of kvcache in kvcache manager

* revise llama modeling

* revise block size retrieval

* use torch for rms_norm benchmarking

* revise block size retrieval
2024-01-30 16:06:09 +08:00
Yuanheng Zhao 3da9993b0d
[Kernel/Fix] Revise flash attention triton kernel API and add benchmark (#5301)
* fix decoding kernel pytest

* revise and add triton context attn benchmark
2024-01-23 17:16:02 +08:00
Jianghai 9e2342bde2
[Hotfix] Fix bugs in testing continuous batching (#5270)
* fix bug

* fix bugs

* fix bugs

* fix bugs and add padding

* add funcs and fix bugs

* fix typos

* fix bugs

* add func
2024-01-18 16:31:14 +08:00
yuehuayingxueluo 86b63f720c
[Inference]Adapted to the triton attn kernels (#5264)
* adapted to the triton attn kernels

* fix pad input

* adapted to copy_kv_to_blocked_cache

* fix ci test

* update kv memcpy

* remove print
2024-01-17 16:03:10 +08:00
Yuanheng Zhao fa85e02b3b
[kernel] Add KV cache copy kernel during decoding (#5261)
* add kv copy triton kernel during decoding stage

* add pytest and fix kernel

* fix test utilities

* revise kernel config

* add benchmark for kvcache copy
2024-01-15 17:37:20 +08:00
FrankLeeeee 1ded7e81ef [git] fixed rebased files 2024-01-11 13:50:45 +00:00
yuehuayingxueluo fab294c7f4 fix CI bugs 2024-01-11 13:46:14 +00:00
yuehuayingxueluo 2a73e828eb fix bugs related to processing padding mask 2024-01-11 13:46:14 +00:00
Jianghai e545a871b8 [Hotfix] Fix accuracy and align attention method api with Triton kernel (#5229)
* fix accuracy

* alignment in attention

* fix attention

* fix

* fix bugs

* fix bugs

* fix bugs
2024-01-11 13:46:14 +00:00
yuehuayingxueluo 47e53eaa1c fix bugs in attention.py and request_handler.py 2024-01-11 13:44:06 +00:00
Jianghai bfd9b1b494 [Inference] Pytorch Attention func, pad&nopad input support (#5219)
* add attn

* add attention test

* fix attn forward

* fix decoding
2024-01-11 13:44:06 +00:00