From bd38fe6b912379080673a43d77fd3bdf0e5c852e Mon Sep 17 00:00:00 2001 From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Date: Tue, 21 May 2024 22:12:15 +0800 Subject: [PATCH] [NFC] Fix code factors on inference triton kernels (#5743) --- colossalai/kernel/triton/flash_decoding.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py index 2fb8231cc..0012f8ec9 100644 --- a/colossalai/kernel/triton/flash_decoding.py +++ b/colossalai/kernel/triton/flash_decoding.py @@ -111,10 +111,10 @@ def _flash_decoding_fwd_kernel( m = tl.max(S_ij, 0) S_ij -= m p_ij_hat = tl.exp(S_ij) - l = tl.sum(p_ij_hat, 0) + l_i = tl.sum(p_ij_hat, 0) p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty) acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0) - acc = acc / l + acc = acc / l_i offsets_mid_o = ( cur_token_idx * stride_mid_ot @@ -126,8 +126,8 @@ def _flash_decoding_fwd_kernel( offsets_mid_o_lse = ( cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb ) - # logsumexp L^(j) = m^(j) + log(l^(j)) - tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l)) + # logsumexp L^(j) = m^(j) + log(l_i^(j)) + tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l_i)) # Triton 2.1.0 @@ -234,10 +234,10 @@ def _alibi_flash_decoding_fwd_kernel( m = tl.max(S_ij, 0) S_ij -= m p_ij_hat = tl.exp(S_ij) - l = tl.sum(p_ij_hat, 0) + l_i = tl.sum(p_ij_hat, 0) p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty) acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0) - acc = acc / l + acc = acc / l_i offsets_mid_o = ( cur_token_idx * stride_mid_ot @@ -249,8 +249,8 @@ def _alibi_flash_decoding_fwd_kernel( offsets_mid_o_lse = ( cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb ) - # logsumexp L^(j) = m^(j) + log(l^(j)) - tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l)) + # 
logsumexp L^(j) = m^(j) + log(l_i^(j)) + tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l_i)) # Triton 2.1.0 @@ -290,7 +290,7 @@ def _flash_decoding_fwd_reduce_kernel( # BLOCK_KV == BLOCK_SIZE for now. We might want to decrease the number of blocks of kv splitted. kv_split_num = (cur_kv_seq_len + BLOCK_KV - 1) // BLOCK_KV m_i = float("-inf") # max logic - l = 0.0 # sum exp + l_i = 0.0 # sum exp acc = tl.zeros([HEAD_DIM], dtype=tl.float32) offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel @@ -304,10 +304,10 @@ def _flash_decoding_fwd_reduce_kernel( lse -= m_ij exp_logic = tl.exp(lse) acc += exp_logic * mid_o_block - l = scale * l + exp_logic + l_i = scale * l_i + exp_logic m_i = m_ij - acc = acc / l + acc = acc / l_i offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel tl.store(O + offsets_O, acc.to(O.type.element_ty)) return