# adapted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
import torch

try:
    from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
    from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2

    HAS_LIGHTLLM_KERNEL = True
except ImportError:
    print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
    HAS_LIGHTLLM_KERNEL = False


if HAS_LIGHTLLM_KERNEL:

    def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
        # Two-stage flash decoding: the KV sequence is split into chunks of
        # BLOCK_SEQ tokens; stage 1 computes a partial attention output and a
        # log-sum-exp statistic per chunk, and stage 2 reduces them into the
        # final per-token output.
        BLOCK_SEQ = 256
        batch_size = infer_state.batch_size
        max_len_in_batch = infer_state.max_len_in_batch
        calcu_shape1 = (batch_size, q_head_num, head_dim)

        # Lazily allocate the intermediate buffers once and cache them on
        # infer_state so they are reused across decode steps.
        if getattr(infer_state, "mid_o", None) is None:
            infer_state.mid_o = torch.empty(
                [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim],
                dtype=torch.float32,
                device="cuda",
            )
            infer_state.mid_o_logexpsum = torch.empty(
                [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1],
                dtype=torch.float32,
                device="cuda",
            )

        mid_o = infer_state.mid_o
        mid_o_logexpsum = infer_state.mid_o_logexpsum

        # Stage 1: per-chunk partial attention over the paged KV cache.
        flash_decode_stage1(
            q.view(calcu_shape1),
            cache_k,
            cache_v,
            infer_state.block_loc,
            infer_state.seq_len,
            infer_state.max_len_in_batch,
            mid_o,
            mid_o_logexpsum,
            BLOCK_SEQ,
        )
        # Stage 2: combine the chunk results into o_tensor.
        flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
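

# A minimal usage sketch (not part of the original file). It assumes a
# lightllm TokenAttention-style layout: cache_k / cache_v shaped
# [num_cache_slots, kv_head_num, head_dim], block_loc shaped
# [batch_size, max_len_in_batch] holding cache-slot indices per position,
# and seq_len shaped [batch_size]. The infer_state attributes used here
# (batch_size, max_len_in_batch, block_loc, seq_len) are the ones
# token_flash_decoding reads above; all shapes and values are illustrative
# assumptions, not taken from the original file.
if __name__ == "__main__" and HAS_LIGHTLLM_KERNEL and torch.cuda.is_available():
    from types import SimpleNamespace

    batch_size, q_head_num, kv_head_num, head_dim = 2, 8, 8, 128
    max_len_in_batch = 512
    num_cache_slots = batch_size * max_len_in_batch

    # Dummy KV cache: every (batch, position) pair gets its own cache slot.
    cache_k = torch.randn(num_cache_slots, kv_head_num, head_dim, dtype=torch.float16, device="cuda")
    cache_v = torch.randn_like(cache_k)
    block_loc = torch.arange(num_cache_slots, dtype=torch.int32, device="cuda").view(batch_size, max_len_in_batch)
    seq_len = torch.full((batch_size,), max_len_in_batch, dtype=torch.int32, device="cuda")

    infer_state = SimpleNamespace(
        batch_size=batch_size,
        max_len_in_batch=max_len_in_batch,
        block_loc=block_loc,
        seq_len=seq_len,
    )

    # One decode-step query per sequence; o_tensor receives the attention output.
    q = torch.randn(batch_size, q_head_num * head_dim, dtype=torch.float16, device="cuda")
    o_tensor = torch.empty_like(q)

    token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v)
    print("decode output shape:", o_tensor.shape)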