Commit Graph

8 Commits (e8f0642f2841f6aeb6ed0e6695ff9d9ef14f198b)

Author SHA1 Message Date
yuehuayingxueluo 4f28cb43c0
[inference]Optimize the usage of the mid tensors space in flash attn (#5304)
* opt flash attn

* opt tmp tensor

* fix benchmark_llama

* fix code style

* fix None logic for output tensor

* fix adapted to get_xine_cache

* add comment

* fix ci bugs

* fix some codes

* rm duplicated codes

* rm duplicated codes

* fix code style

* add _get_dtype in config.py
2024-01-26 14:00:10 +08:00
Yuanheng Zhao af8359c430
[hotfix] fix boundary check in batch (#5306) 2024-01-25 10:23:12 +08:00
Yuanheng Zhao 3da9993b0d
[Kernel/Fix] Revise flash attention triton kernel API and add benchmark (#5301)
* fix decoding kernel pytest

* revise and add triton context attn benchmark
2024-01-23 17:16:02 +08:00
yuehuayingxueluo bfff9254ac
[inference] Adapted to Rotary Embedding and RMS Norm (#5283)
* adapted to rotary_embedding

* adapted to nopad rms norm

* fix bugs in benchmark

* fix flash_decoding.py
2024-01-22 10:55:34 +08:00
Yuanheng Zhao 6e487e7d3c
[kernel/fix] Performance Optimization for Decoding Kernel and Benchmarking (#5274)
* prevent re-creating intermediate tensors

* add singleton class holding intermediate values

* fix triton kernel api

* add benchmark in pytest

* fix kernel api and add benchmark

* revise flash decoding triton kernel in/out shapes

* fix calling of triton kernel in modeling

* fix pytest: extract to util functions
2024-01-19 15:47:16 +08:00
Yuanheng Zhao 1513f20f4d [kernel] Add flash decoding triton kernel for blocked kv cache (#5249)
* add flash decoding unpad triton kernel

* rename flash decoding kernel

* add kernel testing (draft)

* revise pytest

* support kv group (GQA)

* (trivial) fix api and pytest

* (trivial) func renaming

* (trivial) func/file renaming

* refactor pytest for attention

* (trivial) format and consistent vars of context/decode attn

* (trivial) remove test redundancy
2024-01-11 13:46:14 +00:00
Yuanheng Zhao 2bb92243d4 [Inference/NFC] Clean outdated inference tests and deprecated kernels (#5159)
* [inference/nfc] remove outdated inference tests

* remove outdated kernel tests

* remove deprecated triton kernels

* remove imports from deprecated kernels
2024-01-11 13:39:29 +00:00
Cuiqing Li (李崔卿) bce919708f
[Kernels]added flash-decoidng of triton (#5063)
* added flash-decoidng of triton based on lightllm kernel

* add req

* clean

* clean

* delete build.sh

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
2023-11-20 13:58:29 +08:00