# Adapted from ModelTC https://github.com/ModelTC/lightllm
import torch

try:
    import triton
    import triton.language as tl

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

try:
    from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import (
        token_att_fwd2 as lightllm_llama_token_att_fwd2,
    )
    from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import (
        token_att_fwd as lightllm_llama_token_att_fwd,
    )
    from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import (
        token_softmax_fwd as lightllm_llama_token_softmax_fwd,
    )
    from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import (
        token_att_fwd as lightllm_bloom_token_att_fwd,
    )

    HAS_TRITON_TOKEN_ATTENTION = True
except ImportError:
    print("unable to import lightllm kernels")
    HAS_TRITON_TOKEN_ATTENTION = False

if HAS_TRITON:

    @torch.no_grad()
    def token_attention_fwd(
        q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None
    ):
        head_num = k.shape[1]
        batch_size = kv_cache_seq_len.shape[0]
        calcu_shape1 = (batch_size, head_num, k.shape[2])
        total_token_num = k.shape[0]

        # attention score for every (head, cached token) pair
        att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")

        if alibi is None:
            lightllm_llama_token_att_fwd(
                q.view(calcu_shape1),
                k,
                att_m_tensor,
                kv_cache_loc,
                kv_cache_start_loc,
                kv_cache_seq_len,
                max_len_in_batch,
            )
        else:
            # the bloom kernel additionally applies the per-head alibi bias to the scores
            lightllm_bloom_token_att_fwd(
                q.view(calcu_shape1),
                k,
                att_m_tensor,
                alibi,
                kv_cache_loc,
                kv_cache_start_loc,
                kv_cache_seq_len,
                max_len_in_batch,
            )

        # softmax over each sequence's scores, then the weighted sum over V
        prob = torch.empty_like(att_m_tensor)
        lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
        att_m_tensor = None
        lightllm_llama_token_att_fwd2(
            prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch
        )
        prob = None
        return
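
    # Usage sketch (illustrative only, not adapted from lightllm): the shapes below are assumptions
    # inferred from how token_attention_fwd reshapes its arguments -- q/attn_out hold one token per
    # sequence, k/v hold the whole KV cache, and kv_cache_loc maps each (sequence, position) pair to
    # a row of that cache. Running it needs a CUDA device and the lightllm kernels imported above.
    def _example_token_attention_fwd(batch_size=2, head_num=8, head_dim=64, seq_len=16):
        total_token_num = batch_size * seq_len
        q = torch.randn(batch_size, head_num, head_dim, dtype=torch.float16, device="cuda")
        k = torch.randn(total_token_num, head_num, head_dim, dtype=torch.float16, device="cuda")
        v = torch.randn_like(k)
        attn_out = torch.empty_like(q)
        # toy cache layout: each sequence owns a contiguous block of cache rows
        kv_cache_loc = torch.arange(total_token_num, dtype=torch.int32, device="cuda").view(batch_size, seq_len)
        kv_cache_start_loc = torch.arange(0, total_token_num, seq_len, dtype=torch.int32, device="cuda")
        kv_cache_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
        token_attention_fwd(q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
        return attn_out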

    class Llama2TokenAttentionForwards:
        # this function is adapted from
        # https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L8
        @staticmethod
        @triton.jit
        def _fwd_kernel(
            Logics,
            V,
            Out,
            B_Loc,
            B_Start_Loc,
            B_Seqlen,
            max_input_len,
            stride_logic_h,
            stride_logic_bs,
            stride_vbs,
            stride_vh,
            stride_vd,
            stride_obs,
            stride_oh,
            stride_od,
            stride_b_loc_b,
            stride_b_loc_s,
            other_kv_index,  # any valid cache index, read by masked lanes to avoid NaNs
            kv_group_num,
            BLOCK_DMODEL: tl.constexpr,
            BLOCK_N: tl.constexpr,
        ):
            # one program per (sequence, query head)
            cur_batch = tl.program_id(0)
            cur_head = tl.program_id(1)
            cur_kv_head = cur_head // kv_group_num

            cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
            cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)

            offs_n = tl.arange(0, BLOCK_N)
            offs_d = tl.arange(0, BLOCK_DMODEL)

            off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
            # b_loc rows are right-aligned: the last cur_batch_seq_len slots hold valid cache indices
            off_b_loc = cur_batch * stride_b_loc_b + (max_input_len - cur_batch_seq_len) * stride_b_loc_s

            v_ptrs = V + off_v

            e_max = float("-inf")
            e_sum = 0.0
            acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

            for start_n in range(0, cur_batch_seq_len, BLOCK_N):
                start_n = tl.multiple_of(start_n, BLOCK_N)
                v_index = tl.load(
                    B_Loc + off_b_loc + (start_n + offs_n) * stride_b_loc_s,
                    mask=(start_n + offs_n) < cur_batch_seq_len,
                    other=other_kv_index,
                )

                qk = tl.load(
                    Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs,
                    mask=start_n + offs_n < cur_batch_seq_len,
                    other=float("-inf"),
                )

                # streaming softmax: rescale the running sum and accumulator whenever the max grows
                n_e_max = tl.maximum(tl.max(qk, 0), e_max)
                old_scale = tl.exp(e_max - n_e_max)
                p = tl.exp(qk - n_e_max)
                e_sum = e_sum * old_scale + tl.sum(p, 0)
                v = tl.load(v_ptrs + v_index[:, None] * stride_vbs)
                acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
                e_max = n_e_max

            acc = acc / e_sum
            off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
            out_ptrs = Out + off_o
            tl.store(out_ptrs, acc)
            return
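
        # Reference for the loop above (an illustrative sketch, not part of the adapted lightllm
        # code): the same numerically-stable streaming-softmax recurrence in plain PyTorch, for a
        # single (sequence, head) pair whose score and V blocks arrive one chunk at a time.
        @staticmethod
        @torch.no_grad()
        def _streaming_softmax_reference(qk_blocks, v_blocks):
            d_model = v_blocks[0].shape[-1]
            e_max = torch.full((), float("-inf"), device=v_blocks[0].device)
            e_sum = torch.zeros((), device=v_blocks[0].device)
            acc = torch.zeros(d_model, device=v_blocks[0].device)
            for qk, v in zip(qk_blocks, v_blocks):
                qk, v = qk.float(), v.float()
                n_e_max = torch.maximum(qk.max(), e_max)
                old_scale = torch.exp(e_max - n_e_max)  # rescale earlier partial results
                p = torch.exp(qk - n_e_max)
                e_sum = e_sum * old_scale + p.sum()
                acc = acc * old_scale + (p[:, None] * v).sum(0)
                e_max = n_e_max
            return acc / e_sum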

        # this function is adapted from
        # https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L36
        @staticmethod
        @torch.no_grad()
        def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index):
            BLOCK = 64
            batch, head = b_seq_len.shape[0], logics.shape[0]
            grid = (batch, head)
            kv_group_num = logics.shape[0] // v.shape[1]

            num_warps = 1
            Llama2TokenAttentionForwards._fwd_kernel[grid](
                logics,
                v,
                o,
                b_loc,
                b_start_loc,
                b_seq_len,
                max_input_len,
                logics.stride(0),
                logics.stride(1),
                v.stride(0),
                v.stride(1),
                v.stride(2),
                o.stride(0),
                o.stride(1),
                o.stride(2),
                b_loc.stride(0),
                b_loc.stride(1),
                other_kv_index,
                kv_group_num,
                BLOCK_DMODEL=v.shape[-1],
                BLOCK_N=BLOCK,
                num_warps=num_warps,
                num_stages=3,
            )
            return
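
        # Reference check (an illustrative sketch with assumed shapes, not adapted from lightllm):
        # a plain PyTorch version of the fused softmax + reduceV step above, handy for comparing the
        # Triton kernel's output on small inputs. It assumes logics is (head, total_token_num), v is
        # (total_token_num, kv_head_num, head_dim) and b_loc rows are right-aligned, matching what
        # _fwd_kernel reads.
        @staticmethod
        @torch.no_grad()
        def _token_softmax_reducev_torch_reference(logics, v, b_loc, b_start_loc, b_seq_len, max_input_len):
            head = logics.shape[0]
            kv_group_num = head // v.shape[1]
            batch = b_seq_len.shape[0]
            out = torch.empty((batch, head, v.shape[-1]), dtype=v.dtype, device=v.device)
            for b in range(batch):
                seq_len = int(b_seq_len[b])
                start = int(b_start_loc[b])
                loc = b_loc[b, max_input_len - seq_len :].long()  # cache rows of this sequence
                for h in range(head):
                    prob = torch.softmax(logics[h, start : start + seq_len].float(), dim=0)
                    out[b, h] = (prob[:, None] * v[loc, h // kv_group_num].float()).sum(0).to(v.dtype)
            return out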

        # this is the interface of the llama2 attention forward
        @staticmethod
        @torch.no_grad()
        def token_attn(
            q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, other_kv_index
        ):
            total_token_num = k.shape[0]
            batch_size, head_num, head_dim = q.shape
            calcu_shape1 = (batch_size, head_num, head_dim)
            att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")

            lightllm_llama_token_att_fwd(
                q,
                k,
                att_m_tensor,
                kv_cache_loc,
                kv_cache_start_loc,
                kv_cache_seq_len,
                max_len_in_batch,
            )

            if triton.__version__ == "2.0.0":
                # triton 2.0.0: run the separate softmax and reduceV kernels
                prob = torch.empty_like(att_m_tensor)
                lightllm_llama_token_softmax_fwd(
                    att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch
                )
                att_m_tensor = None

                lightllm_llama_token_att_fwd2(
                    prob,
                    v,
                    attn_out.view(calcu_shape1),
                    kv_cache_loc,
                    kv_cache_start_loc,
                    kv_cache_seq_len,
                    max_len_in_batch,
                )
                prob = None
                return

            elif triton.__version__ >= "2.1.0":
                # newer triton: use the fused softmax + reduceV kernel defined above
                Llama2TokenAttentionForwards.token_softmax_reducev_fwd(
                    att_m_tensor,
                    v,
                    attn_out.view(calcu_shape1),
                    kv_cache_loc,
                    kv_cache_start_loc,
                    kv_cache_seq_len,
                    max_len_in_batch,
                    other_kv_index,
                )
            else:
                raise Exception("unsupported triton version")
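

# Smoke test (an illustrative sketch, not part of the adapted lightllm code): exercises
# Llama2TokenAttentionForwards.token_attn on a dummy, contiguously laid-out KV cache with equal
# query and kv head counts. Shapes and dtypes are assumptions; running it requires a CUDA device
# plus the triton and lightllm kernels imported above.
if __name__ == "__main__" and HAS_TRITON and HAS_TRITON_TOKEN_ATTENTION and torch.cuda.is_available():
    batch_size, head_num, head_dim, seq_len = 2, 8, 64, 16
    total_token_num = batch_size * seq_len
    q = torch.randn(batch_size, head_num, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn(total_token_num, head_num, head_dim, dtype=torch.float16, device="cuda")
    v = torch.randn_like(k)
    attn_out = torch.empty_like(q)
    kv_cache_loc = torch.arange(total_token_num, dtype=torch.int32, device="cuda").view(batch_size, seq_len)
    kv_cache_start_loc = torch.arange(0, total_token_num, seq_len, dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
    # other_kv_index just needs to be a valid cache row; masked lanes read it to avoid NaNs
    Llama2TokenAttentionForwards.token_attn(
        q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len, int(kv_cache_loc[0, 0])
    )
    print(attn_out.shape)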