ColossalAI/colossalai/kernel/triton/flash_decoding.py

48 lines
1.8 KiB
Python

# adapted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
import torch
# Optional dependency probe: the flash-decoding kernels come from lightllm.
# HAS_LIGHTLLM_KERNEL records whether they could be imported; the wrapper
# below is only defined when it is True.
try:
    from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
    from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2

    HAS_LIGHTLLM_KERNEL = True
except ImportError:
    # Only a missing/broken install is expected here; a bare ``except`` would also
    # swallow KeyboardInterrupt/SystemExit and hide unrelated errors.
    print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
    HAS_LIGHTLLM_KERNEL = False
if HAS_LIGHTLLM_KERNEL:

    def _get_flash_decoding_buffers(infer_state, batch_size, q_head_num, head_dim, num_blocks, device):
        """Return the ``(mid_o, mid_o_logexpsum)`` float32 scratch buffers cached on *infer_state*.

        The buffers are reused across calls while their shapes and device still
        match the current request. They are (re)allocated when missing, or when
        ``batch_size``, ``q_head_num``, ``head_dim`` or ``num_blocks`` changed, or
        when they live on another device. This ensures a reused *infer_state*
        whose ``max_len_in_batch`` has grown never hands the kernels a stale,
        undersized buffer.

        Args:
            infer_state: object on which the buffers are cached as the attributes
                ``mid_o`` and ``mid_o_logexpsum``.
            batch_size: number of sequences in the batch.
            q_head_num: number of query heads.
            head_dim: size of each head.
            num_blocks: number of ``BLOCK_SEQ``-sized sequence blocks to reserve.
            device: device to allocate the buffers on.

        Returns:
            Tuple ``(mid_o, mid_o_logexpsum)`` with shapes
            ``(batch_size, q_head_num, num_blocks, head_dim)`` and
            ``(batch_size, q_head_num, num_blocks)``.
        """
        mid_o_shape = (batch_size, q_head_num, num_blocks, head_dim)
        logexpsum_shape = (batch_size, q_head_num, num_blocks)

        mid_o = getattr(infer_state, "mid_o", None)
        mid_o_logexpsum = getattr(infer_state, "mid_o_logexpsum", None)

        needs_alloc = (
            mid_o is None
            or mid_o_logexpsum is None
            or tuple(mid_o.shape) != mid_o_shape
            or tuple(mid_o_logexpsum.shape) != logexpsum_shape
            or mid_o.device != device
            or mid_o_logexpsum.device != device
        )
        if needs_alloc:
            mid_o = torch.empty(mid_o_shape, dtype=torch.float32, device=device)
            mid_o_logexpsum = torch.empty(logexpsum_shape, dtype=torch.float32, device=device)
            infer_state.mid_o = mid_o
            infer_state.mid_o_logexpsum = mid_o_logexpsum

        return mid_o, mid_o_logexpsum

    def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
        """Run lightllm's two-stage flash-decoding attention for one decode step.

        Stage 1 (``flash_decode_stage1``) receives the query, the K/V caches, and
        the float32 scratch buffers ``mid_o``/``mid_o_logexpsum``.
        Stage 2 (``flash_decode_stage2``) receives those buffers and writes into
        ``o_tensor``. The split into per-``BLOCK_SEQ`` partial results followed by
        a log-sum-exp reduction is presumed from the kernel/buffer names; the
        kernels themselves live in lightllm.

        Args:
            q: query tensor; must be viewable as ``(batch_size, q_head_num, head_dim)``.
            o_tensor: output tensor, viewed as ``(batch_size, q_head_num, head_dim)``
                and passed to stage 2 (presumably filled in place — confirm in lightllm).
            infer_state: batch inference state. Reads ``batch_size``,
                ``max_len_in_batch``, ``block_loc`` and ``seq_len``. The scratch
                buffers are cached on it as ``mid_o``/``mid_o_logexpsum``.
            q_head_num: number of query heads.
            head_dim: size of each head.
            cache_k: key cache, passed unchanged to stage 1.
            cache_v: value cache, passed unchanged to stage 1.

        Returns:
            None.
        """
        BLOCK_SEQ = 256
        batch_size = infer_state.batch_size
        max_len_in_batch = infer_state.max_len_in_batch
        calcu_shape1 = (batch_size, q_head_num, head_dim)
        # Same block count as the upstream lightllm implementation; always
        # >= ceil(max_len_in_batch / BLOCK_SEQ).
        num_blocks = max_len_in_batch // BLOCK_SEQ + 1

        # Allocate next to the query rather than on the implicit current CUDA device.
        mid_o, mid_o_logexpsum = _get_flash_decoding_buffers(
            infer_state, batch_size, q_head_num, head_dim, num_blocks, q.device
        )

        flash_decode_stage1(
            q.view(calcu_shape1),
            cache_k,
            cache_v,
            infer_state.block_loc,
            infer_state.seq_len,
            max_len_in_batch,
            mid_o,
            mid_o_logexpsum,
            BLOCK_SEQ,
        )
        flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)