mirror of https://github.com/hpcaitech/ColossalAI
[Kernels] added flash-decoding of triton (#5063)
* added flash-decoding of triton based on lightllm kernel
* add req
* clean
* clean
* delete build.sh
---------
Co-authored-by: cuiqing.li <lixx336@gmail.com>
pull/5070/head
parent fd6482ad8c
commit bce919708f
build.sh (deleted)
@@ -1,24 +0,0 @@
-#!/usr/bin/env bash
-
-# install triton
-pip install triton
-pip install transformers
-
-# install lightllm and flash-attention
-mkdir 3rdParty
-cd 3rdParty
-git clone https://github.com/ModelTC/lightllm
-cd lightllm
-git checkout 28c1267cfca536b7b4f28e921e03de735b003039
-pip install -e .
-cd ..
-
-git clone -recursive https://github.com/Dao-AILab/flash-attention
-cd flash-attention
-pip install -e .
-
-cd ../../
-
-
-
-
@@ -28,8 +28,14 @@ except:
     HAS_LIGHTLLM_KERNEL = False
 
 try:
-    from flash_attn import flash_attn_with_kvcache
-
+    from colossalai.kernel.triton.flash_decoding import token_flash_decoding
+    HAS_TRITON_FLASH_DECODING_KERNEL = True
+except:
+    print("no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
+    HAS_TRITON_FLASH_DECODING_KERNEL = False
+
+try:
+    from flash_attn import flash_attn_with_kvcache
     HAS_FLASH_KERNEL = True
 except:
     HAS_FLASH_KERNEL = False
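The hunk above only defines availability flags for the new import. A minimal sketch (helper name hypothetical, not part of this commit) of how such flags are typically consumed to pick a decoding backend:

# Assumes the three HAS_* flags defined in the hunk above are in scope.
def pick_token_attention_backend() -> str:
    # Prefer the lightllm-backed triton flash-decoding path when its import succeeded.
    if HAS_TRITON_FLASH_DECODING_KERNEL:
        return "triton_flash_decoding"
    # Otherwise fall back to flash-attn's kv-cache kernel, then to the lightllm/torch path.
    if HAS_FLASH_KERNEL:
        return "flash_attn_with_kvcache"
    return "lightllm_token_attention" if HAS_LIGHTLLM_KERNEL else "torch_fallback"
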
@@ -42,7 +48,6 @@ def rotate_half(x):
     x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)
 
-
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
@@ -67,7 +72,6 @@ def llama_triton_context_attention(
                 attn_output,
                 infer_state.start_loc,
                 infer_state.seq_len,
-                # infer_state.cache_manager.past_key_values_length,
                 infer_state.max_len_in_batch,
             )
         else:
@@ -78,7 +82,6 @@ def llama_triton_context_attention(
                 attn_output,
                 infer_state.start_loc,
                 infer_state.seq_len,
-                # infer_state.cache_manager.past_key_values_length,
                 infer_state.max_len_in_batch,
             )
     else:
@@ -90,13 +93,20 @@ def llama_triton_context_attention(
             attn_output,
             infer_state.start_loc,
             infer_state.seq_len,
-            # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
         )
 
 
-def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
-    assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
+def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num = -1, head_dim = -1):
+    if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1:
+        token_flash_decoding(q = query_states,
+                             o_tensor = attn_output,
+                             infer_state = infer_state,
+                             q_head_num = q_head_num,
+                             head_dim = head_dim,
+                             cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+                             cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id])
+        return
     if num_key_value_groups == 1:
         token_attention_fwd(
             query_states,
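The new signature lets callers opt into flash decoding by passing q_head_num and head_dim; leaving them at their -1 defaults keeps the old token_attention_fwd path. A hedged usage sketch (the tensors and the lightllm-style infer_state are placeholders, not values from this commit):

# Routed through token_flash_decoding when the triton kernel imported and head info is given:
llama_triton_token_attention(query_states, attn_output, infer_state,
                             num_key_value_groups=1,
                             q_head_num=num_heads, head_dim=head_dim)
# With the defaults (q_head_num=-1, head_dim=-1) the call falls back to token_attention_fwd:
llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1)
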
@@ -106,7 +116,6 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
             infer_state.block_loc,
             infer_state.start_loc,
             infer_state.seq_len,
-            # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
         )
     else:
@@ -118,7 +127,6 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
             infer_state.block_loc,
             infer_state.start_loc,
             infer_state.seq_len,
-            # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
             infer_state.other_kv_index,
         )
@@ -451,10 +459,14 @@ class LlamaInferenceForwards:
             )
 
             if HAS_LIGHTLLM_KERNEL:
+
                 attn_output = torch.empty_like(query_states)
-                llama_triton_token_attention(
-                    query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
-                )
+                llama_triton_token_attention(query_states = query_states,
+                                             attn_output = attn_output,
+                                             infer_state = infer_state,
+                                             num_key_value_groups = self.num_key_value_groups,
+                                             q_head_num = q_len * self.num_heads,
+                                             head_dim = self.head_dim)
             else:
                 self.num_heads // self.num_key_value_heads
                 cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
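In the decode phase each sequence contributes a single new token, so q_len is 1 and the q_head_num passed above reduces to the per-rank query-head count. A tiny illustrative sketch (the numbers are examples, not taken from this commit):

q_len = 1                       # one token per sequence during decoding
num_heads = 32                  # example per-rank query-head count
q_head_num = q_len * num_heads  # == num_heads, matching the mid_o buffers in flash_decoding.py
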
@@ -137,6 +137,7 @@ if HAS_TRITON:
             tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
             return
     else:
+        # this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11
         @triton.jit
         def _context_flash_attention_kernel_2(
             Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,
colossalai/kernel/triton/flash_decoding.py (new file)
@@ -0,0 +1,50 @@
+# adapted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
+import torch
+try:
+    from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
+    from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
+    HAS_LIGHTLLM_KERNEL = False
+
+
+if HAS_LIGHTLLM_KERNEL:
+    def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
+        BLOCK_SEQ = 256
+        batch_size = infer_state.batch_size
+        max_len_in_batch = infer_state.max_len_in_batch
+
+
+        calcu_shape1 = (batch_size, q_head_num, head_dim)
+
+        if getattr(infer_state, 'mid_o', None) is None:
+            infer_state.mid_o = torch.empty([batch_size,
+                                             q_head_num,
+                                             max_len_in_batch // BLOCK_SEQ + 1,
+                                             head_dim],
+                                            dtype=torch.float32,
+                                            device="cuda")
+            infer_state.mid_o_logexpsum = torch.empty([batch_size,
+                                                       q_head_num,
+                                                       max_len_in_batch // BLOCK_SEQ + 1],
+                                                      dtype=torch.float32,
+                                                      device="cuda")
+
+        mid_o = infer_state.mid_o
+        mid_o_logexpsum = infer_state.mid_o_logexpsum
+
+        flash_decode_stage1(q.view(calcu_shape1),
+                            cache_k,
+                            cache_v,
+                            infer_state.block_loc,
+                            infer_state.seq_len,
+                            infer_state.max_len_in_batch,
+                            mid_o,
+                            mid_o_logexpsum,
+                            BLOCK_SEQ)
+        flash_decode_stage2(mid_o,
+                            mid_o_logexpsum,
+                            infer_state.seq_len,
+                            o_tensor.view(calcu_shape1),
+                            BLOCK_SEQ)
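The two lightllm kernels used above split the KV cache into BLOCK_SEQ-sized chunks: stage 1 produces a partial attention output plus a log-sum-exp per chunk (mid_o, mid_o_logexpsum), and stage 2 reweights and sums the chunks. A plain PyTorch sketch of that reduction for a single head, written only to illustrate the math (it is not the triton implementation above, and the function name is made up):

import torch

def flash_decode_reference(q, k, v, block_seq=256):
    # Single-head reference: q is [head_dim], k and v are [seq_len, head_dim].
    scale = q.shape[-1] ** -0.5
    chunk_o, chunk_lse = [], []
    # "Stage 1": one partial output and one log-sum-exp per BLOCK_SEQ chunk of the cache.
    for start in range(0, k.shape[0], block_seq):
        s = (k[start:start + block_seq] @ q) * scale                          # [chunk]
        chunk_lse.append(torch.logsumexp(s, dim=0))                           # scalar, ~ mid_o_logexpsum
        chunk_o.append(torch.softmax(s, dim=0) @ v[start:start + block_seq])  # [head_dim], ~ mid_o
    chunk_o = torch.stack(chunk_o)        # [num_chunks, head_dim]
    chunk_lse = torch.stack(chunk_lse)    # [num_chunks]
    # "Stage 2": weight each chunk by exp(lse_i - logsumexp(lse)) and sum.
    weights = torch.softmax(chunk_lse, dim=0)
    return (weights[:, None] * chunk_o).sum(dim=0)

# Sanity check against ordinary full-length softmax attention:
q, k, v = torch.randn(64), torch.randn(1000, 64), torch.randn(1000, 64)
full = torch.softmax((k @ q) * 64 ** -0.5, dim=0) @ v
assert torch.allclose(flash_decode_reference(q, k, v), full, atol=1e-4)
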
@@ -75,11 +75,11 @@ def run_tp_pipeline_inference(rank, world_size, port, args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
-    parser.add_argument("-tp", "--tp_size", type=int, default=2, help="Tensor parallel size")
-    parser.add_argument("-pp", "--pp_size", type=int, default=2, help="Tensor parallel size")
-    parser.add_argument("-b", "--batch_size", type=int, default=8, help="Maximum batch size")
-    parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
-    parser.add_argument("--max_output_len", type=int, default=16, help="Maximum output length")
+    parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("-pp", "--pp_size", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("-b", "--batch_size", type=int, default=64, help="Maximum batch size")
+    parser.add_argument("--max_input_len", type=int, default=512, help="Maximum input length")
+    parser.add_argument("--max_output_len", type=int, default=256, help="Maximum output length")
     parser.add_argument("--micro_batch_size", type=int, default=2, help="Micro batch size")
 
     args = parser.parse_args()
@@ -2,6 +2,6 @@ transformers==4.34.0
 packaging
 ninja
 auto-gptq==0.5.0
-git+https://github.com/ModelTC/lightllm.git@28c1267cfca536b7b4f28e921e03de735b003039
+git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8
 git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
 git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9
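The lightllm pin moves from 28c1267 to ece7b43, the revision that provides the flash_decoding_stage1/flash_decoding_stage2 kernels imported by the new file above. A quick, hedged way to confirm the installed packages expose the optional kernels (it simply mirrors the guarded imports in the diff; module paths are taken from it):

import importlib

# Maps a readable label to the module the guarded imports above try to load.
checks = {
    "lightllm flash_decode_stage1": "lightllm.models.llama.triton_kernel.flash_decoding_stage1",
    "lightllm flash_decode_stage2": "lightllm.models.llama.triton_kernel.flash_decoding_stage2",
    "flash_attn (flash_attn_with_kvcache)": "flash_attn",
}
for label, module in checks.items():
    try:
        importlib.import_module(module)
        print(f"[ok]      {label}")
    except Exception as exc:
        print(f"[missing] {label}: {exc}")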