mirror of https://github.com/hpcaitech/ColossalAI
51 lines
2.3 KiB
Python
51 lines
2.3 KiB
Python
# adepted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
|
|
import torch
|
|
try:
|
|
from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
|
|
from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2
|
|
HAS_LIGHTLLM_KERNEL = True
|
|
except:
|
|
print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
|
|
HAS_LIGHTLLM_KERNEL = False
|
|
|
|
|
|
if HAS_LIGHTLLM_KERNEL:
|
|
def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
|
|
BLOCK_SEQ = 256
|
|
batch_size = infer_state.batch_size
|
|
max_len_in_batch = infer_state.max_len_in_batch
|
|
|
|
|
|
calcu_shape1 = (batch_size, q_head_num, head_dim)
|
|
|
|
if getattr(infer_state, 'mid_o', None) is None:
|
|
infer_state.mid_o = torch.empty([batch_size,
|
|
q_head_num,
|
|
max_len_in_batch // BLOCK_SEQ + 1,
|
|
head_dim],
|
|
dtype=torch.float32,
|
|
device="cuda")
|
|
infer_state.mid_o_logexpsum = torch.empty([batch_size,
|
|
q_head_num,
|
|
max_len_in_batch // BLOCK_SEQ + 1],
|
|
dtype=torch.float32,
|
|
device="cuda")
|
|
|
|
mid_o = infer_state.mid_o
|
|
mid_o_logexpsum = infer_state.mid_o_logexpsum
|
|
|
|
flash_decode_stage1(q.view(calcu_shape1),
|
|
cache_k,
|
|
cache_v,
|
|
infer_state.block_loc,
|
|
infer_state.seq_len,
|
|
infer_state.max_len_in_batch,
|
|
mid_o,
|
|
mid_o_logexpsum,
|
|
BLOCK_SEQ)
|
|
flash_decode_stage2(mid_o,
|
|
mid_o_logexpsum,
|
|
infer_state.seq_len,
|
|
o_tensor.view(calcu_shape1),
|
|
BLOCK_SEQ)
|