From 1513f20f4d80f782fab381996368ff2c2f3c95c3 Mon Sep 17 00:00:00 2001 From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Date: Thu, 11 Jan 2024 18:06:39 +0800 Subject: [PATCH] [kernel] Add flash decoding triton kernel for blocked kv cache (#5249) * add flash decoding unpad triton kernel * rename flash decoding kernel * add kernel testing (draft) * revise pytest * support kv group (GQA) * (trivial) fix api and pytest * (trivial) func renaming * (trivial) func/file renaming * refactor pytest for attention * (trivial) format and consistent vars of context/decode attn * (trivial) remove test redundancy --- colossalai/kernel/triton/__init__.py | 2 + .../kernel/triton/context_attn_unpad.py | 88 +++--- colossalai/kernel/triton/flash_decoding.py | 279 ++++++++++++++++++ tests/test_infer_ops/triton/kernel_utils.py | 115 ++++++-- .../triton/test_context_attn_unpad.py | 130 +++----- .../triton/test_decoding_attn.py | 115 ++++++++ 6 files changed, 576 insertions(+), 153 deletions(-) create mode 100644 colossalai/kernel/triton/flash_decoding.py create mode 100644 tests/test_infer_ops/triton/test_decoding_attn.py diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index f5f530c92..4ac71ac64 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -9,6 +9,7 @@ except ImportError: # There may exist import error even if we have triton installed. if HAS_TRITON: from .context_attn_unpad import context_attention_unpadded + from .flash_decoding import flash_decoding_fwd from .fused_layernorm import layer_norm from .gptq_triton import gptq_fused_linear_triton from .no_pad_rotary_embedding import rotary_embedding @@ -16,6 +17,7 @@ if HAS_TRITON: __all__ = [ "context_attention_unpadded", + "flash_decoding_fwd", "softmax", "layer_norm", "gptq_fused_linear_triton", diff --git a/colossalai/kernel/triton/context_attn_unpad.py b/colossalai/kernel/triton/context_attn_unpad.py index e4e09302e..64efa3491 100644 --- a/colossalai/kernel/triton/context_attn_unpad.py +++ b/colossalai/kernel/triton/context_attn_unpad.py @@ -42,7 +42,7 @@ def _fwd_context_paged_attention_kernel( sm_scale, KV_GROUPS: tl.constexpr, BLOCK_SIZE: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, + HEAD_DIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -66,38 +66,38 @@ def _fwd_context_paged_attention_kernel( for i in range(0, cur_seq_idx): prev_seq_len_sum += tl.load(context_lengths + i) - q_offset = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh - kv_offset = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh + offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh + offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, - shape=(cur_seq_len, BLOCK_DMODEL), + base=Q + offset_q, + shape=(cur_seq_len, HEAD_DIM), strides=(stride_qt, stride_qd), offsets=(block_start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), + block_shape=(BLOCK_M, HEAD_DIM), order=(1, 0), ) K_block_ptr = tl.make_block_ptr( - base=K + kv_offset, - shape=(BLOCK_DMODEL, cur_seq_len), + base=K + offset_kv, + shape=(HEAD_DIM, cur_seq_len), strides=(stride_kd, stride_kt), offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), + block_shape=(HEAD_DIM, BLOCK_N), order=(0, 1), ) V_block_ptr = tl.make_block_ptr( - base=V + kv_offset, - shape=(cur_seq_len, BLOCK_DMODEL), + base=V + offset_kv, + shape=(cur_seq_len, HEAD_DIM), strides=(stride_vt, stride_vd), offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), + block_shape=(BLOCK_N, HEAD_DIM), order=(1, 0), ) O_block_ptr = tl.make_block_ptr( - base=O + q_offset, - shape=(cur_seq_len, BLOCK_DMODEL), + base=O + offset_q, + shape=(cur_seq_len, HEAD_DIM), strides=(stride_ot, stride_od), offsets=(block_start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), + block_shape=(BLOCK_M, HEAD_DIM), order=(1, 0), ) @@ -108,13 +108,13 @@ def _fwd_context_paged_attention_kernel( # as we have BLOCK_M the same size as the block size. cur_block_table_idx = block_start_m cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb) - kvcache_offset = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh + offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh offsets_m = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M) offsets_n = tl.arange(0, BLOCK_N) m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) if block_start_m * BLOCK_M >= cur_seq_len: return @@ -152,43 +152,41 @@ def _fwd_context_paged_attention_kernel( if cur_head_idx % KV_GROUPS == 0: # Copy k to corresponding cache block - kd_offsets = tl.arange(0, BLOCK_DMODEL) - kt_offsets = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M) - k_offsets = K + kv_offset + kd_offsets[:, None] * stride_kd + kt_offsets[None, :] * stride_kt - k = tl.load(k_offsets, mask=kt_offsets[None, :] < cur_seq_len, other=0.0) - kcached_offsets = tl.arange(0, BLOCK_DMODEL) - kcachebs_offsets = tl.arange(0, BLOCK_SIZE) - kcache_offsets = ( + offsets_dmodel = tl.arange(0, HEAD_DIM) + offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offsets_k = K + offset_kv + offsets_dmodel[:, None] * stride_kd + offsets_kt[None, :] * stride_kt + k = tl.load(offsets_k, mask=offsets_kt[None, :] < cur_seq_len, other=0.0) + offsets_kcachebs = tl.arange(0, BLOCK_SIZE) + offsets_kcache = ( KCache - + kvcache_offset - + kcached_offsets[:, None] * stride_cached - + kcachebs_offsets[None, :] * stride_cachebs + + offset_kvcache + + offsets_dmodel[:, None] * stride_cached + + offsets_kcachebs[None, :] * stride_cachebs ) - tl.store(kcache_offsets, k, mask=kcachebs_offsets[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE) + tl.store(offsets_kcache, k, mask=offsets_kcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE) # Copy v to corresponding cache block - vd_offsets = kd_offsets - vt_offsets = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N) - v_offsets = V + kv_offset + vt_offsets[:, None] * stride_vt + vd_offsets[None, :] * stride_vd - v = tl.load(v_offsets, mask=vt_offsets[:, None] < cur_seq_len, other=0.0) - vcached_offsets = kcached_offsets - vcachebs_offsets = kcachebs_offsets - vcache_offsets = ( + offsets_vd = offsets_dmodel + offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N) + offsets_v = V + offset_kv + offsets_vt[:, None] * stride_vt + offsets_vd[None, :] * stride_vd + v = tl.load(offsets_v, mask=offsets_vt[:, None] < cur_seq_len, other=0.0) + offsets_vcachebs = offsets_kcachebs # same block size range, just to notify here + offsets_vcache = ( VCache - + kvcache_offset - + vcachebs_offsets[:, None] * stride_cachebs - + vcached_offsets[None, :] * stride_cached + + offset_kvcache + + offsets_vcachebs[:, None] * stride_cachebs + + offsets_dmodel[None, :] * stride_cached ) - tl.store(vcache_offsets, v, mask=vcachebs_offsets[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE) + tl.store(offsets_vcache, v, mask=offsets_vcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE) return def context_attention_unpadded( - q: torch.Tensor, # [num_tokens, num_heads, head_size] - k: torch.Tensor, # [num_tokens, num_kv_heads, head_size] - v: torch.Tensor, # [num_tokens, num_kv_heads, head_size] - k_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_size, block_size] - v_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_size, block_size] + q: torch.Tensor, # [num_tokens, num_heads, head_dim] + k: torch.Tensor, # [num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [num_tokens, num_kv_heads, head_dim] + k_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size] + v_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size] context_lengths: torch.Tensor, # [num_seqs] block_tables: torch.Tensor, # [num_seqs, max_blocks_per_sequence], block_size: int, @@ -254,7 +252,7 @@ def context_attention_unpadded( sm_scale, num_kv_group, block_size, - BLOCK_DMODEL=Lk, + HEAD_DIM=Lk, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ) diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py new file mode 100644 index 000000000..ed1629e96 --- /dev/null +++ b/colossalai/kernel/triton/flash_decoding.py @@ -0,0 +1,279 @@ +# Applying Flash-Decoding as descibed in +# https://pytorch.org/blog/flash-decoding/ +# by Tri Dao, 2023 +import torch +import triton +import triton.language as tl + + +# Triton 2.1.0 +@triton.jit +def _flash_decoding_fwd_kernel( + Q, # [batch_size, head_num, head_dim] + KCache, # [num_blocks, num_kv_heads, head_dim, block_size] + VCache, # [num_blocks, num_kv_heads, head_dim, block_size] + block_tables, # [batch_size, max_blocks_per_sequence] + mid_o, # [batch_size, head_num, kv_split_num, head_dim] + mid_o_lse, # [batch_size, head_num, kv_split_num] + context_lengths, # [batch_size] + stride_qt, + stride_qh, + stride_qd, + stride_cacheb, + stride_cacheh, + stride_cached, + stride_cachebs, + stride_bts, + stride_btb, + stride_mid_ot, + stride_mid_oh, + stride_mid_ob, + stride_mid_od, + stride_mid_o_lset, + stride_mid_o_lseh, + stride_mid_o_lseb, + sm_scale, + KV_GROUPS: tl.constexpr, + BLOCK_KV: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + cur_seq_idx = tl.program_id(0) + cur_head_idx = tl.program_id(1) + block_start_kv = tl.program_id(2) # for splitting k/v + + cur_kv_head_idx = cur_head_idx // KV_GROUPS + offsets_dmodel = tl.arange(0, HEAD_DIM) + + # NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same + # TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE) + # and then support calculating multiple kv cache blocks on an instance + tl.static_assert(BLOCK_KV == BLOCK_SIZE) + + # get the current (kv) sequence length from provided context lengths tensor + cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx) + + offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd + q = tl.load(Q + offsets_q) + + # block table for the current sequence + block_table_ptr = block_tables + cur_seq_idx * stride_bts + + # actually current block table current block start idx + # cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE) + cur_bt_start_idx = block_start_kv + cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb) + + if block_start_kv * BLOCK_KV >= cur_kv_seq_len: + # TODO might want to remove if-else block? + return + + cur_occupied_size = tl.where( + (block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE + ) + tl.device_assert(cur_occupied_size >= 0) + + offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh + + K_block_ptr = tl.make_block_ptr( + base=KCache + offset_kvcache, + shape=(HEAD_DIM, cur_occupied_size), + strides=(stride_cached, stride_cachebs), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_SIZE), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=VCache + offset_kvcache, + shape=(HEAD_DIM, cur_occupied_size), + strides=(stride_cached, stride_cachebs), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_SIZE), + order=(0, 1), + ) + k_cur_block = tl.load(K_block_ptr) + v_cur_block = tl.load(V_block_ptr) + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + # use block size of the paged/blocked kv cache + S_ij = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + + # NOTE a trick to come across triton's requirement that values in both first and second input shapes must be >= 16, + # Multiplying two tensors with shapes [1, d] * [d, block_size] will fail. + # Refer to https://github.com/openai/triton/discussions/895 + S_ij += tl.sum(q[:, None] * k_cur_block, 0) + S_ij *= sm_scale + S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf")) + + m = tl.max(S_ij, 0) + S_ij -= m + p_ij_hat = tl.exp(S_ij) + l = tl.sum(p_ij_hat, 0) + p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty) + acc += tl.sum(v_cur_block * p_ij_hat[None, :], 1) + acc = acc / l + + offsets_mid_o = ( + cur_seq_idx * stride_mid_ot + + cur_head_idx * stride_mid_oh + + block_start_kv * stride_mid_ob + + offsets_dmodel * stride_mid_od + ) + tl.store(mid_o + offsets_mid_o, acc) + offsets_mid_o_lse = ( + cur_seq_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb + ) + # logsumexp L^(j) = m^(j) + log(l^(j)) + tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l)) + + +# Triton 2.1.0 +@triton.jit +def _flash_decoding_fwd_reduce_kernel( + mid_o, # [batch_size, head_num, kv_split_num, head_dim] + mid_o_lse, # [batch_size, head_num, kv_split_num] + O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim] + context_lengths, + stride_mid_ot, + stride_mid_oh, + stride_mid_ob, + stride_mid_od, + stride_o_lset, + stride_o_lseh, + stride_o_lseb, + stride_ob, + stride_oh, + stride_od, + BLOCK_KV: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + cur_seq_idx = tl.program_id(0) + cur_head_idx = tl.program_id(1) + + cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx) + offsets_dmodel = tl.arange(0, HEAD_DIM) + + # NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have + # BLOCK_KV == BLOCK_SIZE for now. We might want to decrease the number of blocks of kv splitted. + kv_split_num = (cur_kv_seq_len + BLOCK_KV - 1) // BLOCK_KV + m_i = float("-inf") # max logic + l = 0.0 # sum exp + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + + offsets_mid_o = cur_seq_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel + offset_mid_lse = cur_seq_idx * stride_o_lset + cur_head_idx * stride_o_lseh + for block_i in range(0, kv_split_num, 1): + mid_o_block = tl.load(mid_o + offsets_mid_o + block_i * stride_mid_ob) + lse = tl.load(mid_o_lse + offset_mid_lse + block_i * stride_o_lseb) + m_ij = tl.maximum(m_i, lse) + scale = tl.exp(m_i - m_ij) + acc = acc * scale + lse -= m_ij + exp_logic = tl.exp(lse) + acc += exp_logic * mid_o_block + l = scale * l + exp_logic + m_i = m_ij + + acc = acc / l + offsets_O = cur_seq_idx * stride_ob + cur_head_idx * stride_oh + offsets_dmodel + tl.store(O + offsets_O, acc.to(O.type.element_ty)) + return + + +# Decoding Stage +# Used with blocked KV Cache (PagedAttention) +def flash_decoding_fwd( + q: torch.Tensor, # [bsz(e.g.num_tokens), 1, num_heads, head_dim] + k_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size] + v_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size] + context_lengths: torch.Tensor, # [batch_size] + block_tables: torch.Tensor, # [batch_size, max_blocks_per_sequence] + block_size: int, + num_kv_group: int = 1, +): + bsz, _, num_heads, head_dim = q.shape + + assert head_dim in {32, 64, 128, 256} + assert context_lengths.shape[0] == block_tables.shape[0] == bsz, ( + f"Got incompatible batch size (number of seqs):\n" + f" Conext lengths bsz {context_lengths.shape[0]}, Block tables bsz {block_tables.shape[0]}, " + f"batch size {bsz}" + ) + assert k_cache.size(-1) == v_cache.size(-1) == block_size, ( + f"Got incompatible block size on kv caches:\n" + f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-1)}, " + f"v_cache block_size {v_cache.size(-1)}" + ) + # NOTE `context_lengths` records the (kv) sequence lengths incorporating past kv sequence lengths. + bsz = context_lengths.size(0) # e.g. the number of seqs + max_seq_len = context_lengths.max().item() + sm_scale = 1.0 / (head_dim**0.5) + + # NOTE BLOCK_KV could be considered as block splitting the sequence on k/v + # For now, BLOCK_KV is supposed to be equivalent with the size of physical cache block (i.e.`block_size`) + assert block_size in {16, 32, 64, 128} + BLOCK_KV = block_size + + kv_max_split_num = (max_seq_len + BLOCK_KV - 1) // BLOCK_KV + mid_o = torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device) + mid_o_lse = torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device) + + if q.dim() == 4: + assert q.size(1) == 1, f"q_len is supposed to be 1 but is {q.size(1)}" + q = q.squeeze(1) + + grid = (bsz, num_heads, triton.cdiv(max_seq_len, BLOCK_KV)) + _flash_decoding_fwd_kernel[grid]( + q, + k_cache, + v_cache, + block_tables, + mid_o, + mid_o_lse, + context_lengths, + q.stride(0), + q.stride(1), + q.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + mid_o.stride(0), + mid_o.stride(1), + mid_o.stride(2), + mid_o.stride(3), + mid_o_lse.stride(0), + mid_o_lse.stride(1), + mid_o_lse.stride(2), + sm_scale, + KV_GROUPS=num_kv_group, + BLOCK_KV=block_size, + BLOCK_SIZE=block_size, + HEAD_DIM=head_dim, + ) + + output = torch.zeros_like(q) + output = output.view(-1, output.size(-2), output.size(-1)) + + grid = (bsz, num_heads) + _flash_decoding_fwd_reduce_kernel[grid]( + mid_o, + mid_o_lse, + output, + context_lengths, + mid_o.stride(0), + mid_o.stride(1), + mid_o.stride(2), + mid_o.stride(3), + mid_o_lse.stride(0), + mid_o_lse.stride(1), + mid_o_lse.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + BLOCK_KV=block_size, + HEAD_DIM=head_dim, + ) + + return output diff --git a/tests/test_infer_ops/triton/kernel_utils.py b/tests/test_infer_ops/triton/kernel_utils.py index 0732ace1e..2f34c5463 100644 --- a/tests/test_infer_ops/triton/kernel_utils.py +++ b/tests/test_infer_ops/triton/kernel_utils.py @@ -1,27 +1,102 @@ -import math - import torch from torch.nn import functional as F -def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): +# This function is adapted from src/transformers/models/llama/modeling_llama.py +# in huggingface transformers repository +# https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/llama/modeling_llama.py#L273 +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ - adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). + The hidden states go from (bsz, num_key_value_heads, seq_len, head_dim) to (bsz, num_attention_heads, seq_len, head_dim) """ - xq = xq.view(bs, seqlen, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - xv = xv.view(bs, seqlen, num_head, head_dim) - mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() - mask[mask == 0.0] = -100000000.0 - mask = mask.repeat(bs, num_head, 1, 1) - keys = xk - values = xv - xq = xq.transpose(1, 2) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) - sm_scale = 1 / math.sqrt(head_dim) - scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale - scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16) + if n_rep == 1: + return hidden_states + bsz, num_key_value_heads, seq_len, head_dim = hidden_states.shape + hidden_states = hidden_states[:, :, None, :, :].expand(bsz, num_key_value_heads, n_rep, seq_len, head_dim) + return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim) - output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) - return output + +# Attention calculation adapted from HuggingFace transformers repository +# src/transformers/models/llama/modeling_llama.py +# https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350 +def torch_attn_ref( + q: torch.Tensor, # [bsz, seq_len, num_heads, head_dim] + k: torch.Tensor, # [bsz, kv_seq_len, num_heads, head_dim] + v: torch.Tensor, # [bsz, kv_seq_len, num_heads, head_dim] + attention_mask: torch.Tensor, # [bsz, 1, seq_len, kv_seq_len] + bsz: int, + seq_len: int, + kv_seq_len: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, +): + assert q.shape[-1] == k.shape[-1] == v.shape[-1] == head_dim + q = q.view(bsz, seq_len, num_heads, head_dim) + k = k.view(bsz, kv_seq_len, num_kv_heads, head_dim) + v = v.view(bsz, kv_seq_len, num_kv_heads, head_dim) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # repeat kv for GQA and MQA + # k/v won't change if kv_group_num is 1 + assert num_heads % num_kv_heads == 0, "Number of heads is not multiple of kv heads" + kv_group_num = num_heads // num_kv_heads + k = repeat_kv(k, kv_group_num) + v = repeat_kv(v, kv_group_num) + + qk = torch.matmul(q, k.transpose(2, 3)) + attn_scores = qk / (head_dim**0.5) + + assert attn_scores.shape == (bsz, num_heads, seq_len, kv_seq_len), "Invalid shape of attention scores" + # for left-side padding + if attention_mask.size() != (bsz, 1, seq_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, seq_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_scores = attn_scores + attention_mask + attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype) + out = torch.matmul(attn_weights, v) + if out.size() != (bsz, num_heads, seq_len, head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is" f" {out.size()}" + ) + out = out.transpose(1, 2).contiguous() + return out + + +def mock_alloc_block_table_and_kvcache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + context_lengths: torch.Tensor, + num_seqs: int, + max_num_blocks_per_seq: int, + block_size: int, +): + """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache.""" + block_id = 0 + block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32) + num_tokens_processed = 0 + for i, seq_len in enumerate(context_lengths.tolist()): + right_bound = (seq_len + block_size - 1) // block_size # open bound + block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32) + # Manually fill kv caches by copying from k and v + for i in range(right_bound): + if i == right_bound - 1: + allocated_locs = seq_len % block_size or block_size + else: + allocated_locs = block_size + k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0) + v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0) + k_cache[block_id, :, :, :allocated_locs] = k_block + v_cache[block_id, :, :, :allocated_locs] = v_block + + num_tokens_processed += allocated_locs + block_id += 1 + + return block_tables diff --git a/tests/test_infer_ops/triton/test_context_attn_unpad.py b/tests/test_infer_ops/triton/test_context_attn_unpad.py index 8cca2af1a..60459a3c2 100644 --- a/tests/test_infer_ops/triton/test_context_attn_unpad.py +++ b/tests/test_infer_ops/triton/test_context_attn_unpad.py @@ -1,10 +1,10 @@ import pytest import torch -import torch.nn.functional as F from packaging import version from colossalai.kernel.triton import context_attention_unpadded from colossalai.utils import get_current_device +from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache, torch_attn_ref try: import triton # noqa @@ -17,60 +17,40 @@ except ImportError: TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") -def torch_attn_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_len: int, num_heads: int, head_size: int): - # For a single sequence, q,k,v [seq_len, num_heads, head_size] - assert q.shape[-1] == k.shape[-1] == v.shape[-1] == head_size - q = q.view(seq_len, num_heads, head_size) - k = k.view(seq_len, num_heads, head_size) - v = v.view(seq_len, num_heads, head_size) - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - - mask = torch.tril(torch.ones(1, seq_len, seq_len), diagonal=0).to(device=get_current_device()) - mask[mask == 0.0] = float("-inf") - mask = mask.repeat(num_heads, 1, 1) - - qk = torch.matmul(q, k.transpose(1, 2)) - attn_scores = qk / (head_size**0.5) - attn_weights = F.softmax(attn_scores.to(dtype=torch.float32) + mask, dim=-1).to(dtype=q.dtype) - out = torch.matmul(attn_weights, v).transpose(0, 1).contiguous() - out = out.reshape(-1, num_heads, head_size) - return out - - -def torch_attn_unpad(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context_lengths: torch.Tensor): - # Process sequence one by one and cat them together. - # q,k,v [num_tokens(sum(context_lengths)), num_heads, head_size] +def torch_attn_unpad( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context_lengths: torch.Tensor, num_heads: int, num_kv_heads: int +): + # Process sequence one by one and concatenate them together. + # q,k,v [num_tokens(sum(context_lengths)), num_heads, head_dim] assert context_lengths.dim() == 1, "context_lengths should be a 1D tensor" - _, num_heads, head_size = q.shape + + _, num_heads, head_dim = q.shape out_torch = [] start_idx = 0 - for i in range(len(context_lengths)): - end_idx = start_idx + context_lengths[i].item() + for seq_i in range(len(context_lengths)): + end_idx = start_idx + context_lengths[seq_i].item() + seq_len = end_idx - start_idx + mask = torch.tril(torch.ones(1, 1, seq_len, seq_len), diagonal=0).to(device=q.device) + mask[mask == 0.0] = float("-inf") + torch_attn_ref_out = torch_attn_ref( - q[start_idx:end_idx], k[start_idx:end_idx], v[start_idx:end_idx], end_idx - start_idx, num_heads, head_size + q[start_idx:end_idx].unsqueeze(0), + k[start_idx:end_idx].unsqueeze(0), + v[start_idx:end_idx].unsqueeze(0), + mask, + 1, # set bsz as 1 as we're processing sequence one by one + seq_len, + seq_len, + num_heads, + num_kv_heads, + head_dim, ) - out_torch.append(torch_attn_ref_out) + out_torch.append(torch_attn_ref_out.squeeze(0)) start_idx = end_idx + return torch.cat(out_torch, dim=0) -# This method is adapted from src/transformers/models/llama/modeling_llama.py -# in transformers repository https://github.com/huggingface/transformers -# https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/llama/modeling_llama.py#L273 -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (num_tokens, - num_key_value_heads, head_dim) to (num_tokens, num_attention_heads, head_dim) - """ - num_tokens, num_key_value_heads, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :].expand(num_tokens, num_key_value_heads, n_rep, head_dim) - return hidden_states.reshape(num_tokens, num_key_value_heads * n_rep, head_dim) - - @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") @pytest.mark.parametrize("bsz", [4, 7, 32]) @pytest.mark.parametrize("block_size", [16, 32, 64]) @@ -87,72 +67,46 @@ def test_context_attention( same_context_len: bool, ): torch.manual_seed(123) - - dtype = torch.float16 - device = get_current_device() - num_seqs = bsz - num_kv_heads = num_attn_heads // kv_group_num - assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads." - head_size = 32 - max_seq_len = max_num_blocks_per_seq * block_size - # It's necessary to clear cache here. torch.cuda.empty_cache() torch.cuda.synchronize() torch.cuda.reset_peak_memory_stats() + num_kv_heads = num_attn_heads // kv_group_num + assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads." + head_dim = 32 + max_seq_len = max_num_blocks_per_seq * block_size + dtype = torch.float16 + device = get_current_device() + if same_context_len: - context_lengths = torch.tensor([max_seq_len for _ in range(num_seqs)], dtype=torch.int32, device=device) + context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device) else: - context_lengths = torch.randint(low=1, high=max_seq_len, size=(num_seqs,), dtype=torch.int32, device=device) + context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device) num_tokens = torch.sum(context_lengths).item() - qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, head_size) + qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, head_dim) qkv = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) q, k, v = torch.split(qkv, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2) - cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_size, block_size) + cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size) k_cache_torch = torch.zeros(size=cache_shape, dtype=dtype, device=device) k_cache_triton = torch.zeros_like(k_cache_torch) v_cache_torch = torch.zeros(size=cache_shape, dtype=dtype, device=device) v_cache_triton = torch.zeros_like(v_cache_torch) # Mock allocation on block tables - block_id = 0 - block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32) - num_tokens_processed = 0 - for i, seq_len in enumerate(context_lengths.tolist()): - right_bound = (seq_len + block_size - 1) // block_size # open bound - block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32) - # Manually fill k_cache_torch and v_cache_torch by copying from k and v - for i in range(right_bound): - if i == right_bound - 1: - allocated_locs = seq_len % block_size or block_size - else: - allocated_locs = block_size - k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0) - v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0) - cur_block_size_occupied = k_block.shape[-1] - assert cur_block_size_occupied <= block_size, "Invalid occupied size of block during mock allocation" - k_cache_torch[block_id, :, :, :cur_block_size_occupied] = k_block - v_cache_torch[block_id, :, :, :cur_block_size_occupied] = v_block - - num_tokens_processed += allocated_locs - block_id += 1 - + block_tables = mock_alloc_block_table_and_kvcache( + k, v, k_cache_torch, v_cache_torch, context_lengths, bsz, max_num_blocks_per_seq, block_size + ) block_tables = block_tables.to(device=device) out_triton = context_attention_unpadded( q, k, v, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size ) - # For GQA and MQA, repeat k, v for torch attention calculation - # k/v won't change if provided `num_kv_group` is 1 - num_kv_group = num_attn_heads // num_kv_heads - k = repeat_kv(k, num_kv_group) - v = repeat_kv(v, num_kv_group) - out_torch = torch_attn_unpad(q, k, v, context_lengths) + out_torch = torch_attn_unpad(q, k, v, context_lengths, num_attn_heads, num_kv_heads) assert out_torch.shape == out_triton.shape - assert torch.allclose(out_torch, out_triton, atol=1e-2, rtol=1e-3) + assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4) assert torch.allclose(k_cache_torch, k_cache_triton) assert torch.allclose(v_cache_torch, v_cache_triton) diff --git a/tests/test_infer_ops/triton/test_decoding_attn.py b/tests/test_infer_ops/triton/test_decoding_attn.py new file mode 100644 index 000000000..58b8fe0cd --- /dev/null +++ b/tests/test_infer_ops/triton/test_decoding_attn.py @@ -0,0 +1,115 @@ +import pytest +import torch +from packaging import version + +from colossalai.kernel.triton import flash_decoding_fwd +from colossalai.utils import get_current_device +from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache, torch_attn_ref + +try: + import triton # noqa + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +def torch_decoding(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context_lengths: torch.Tensor): + assert context_lengths.dim() == 1, "context_lengths should be a 1D tensor" + assert q.size(1) == 1, "Only used for decoding" + assert k.shape == v.shape + + bsz, _, num_heads, head_dim = q.shape + _, kv_seq_len, num_kv_heads, _ = k.shape + assert num_heads % num_kv_heads == 0, "Invalid kv heads and attention heads." + padding_mask = torch.zeros((bsz, 1, 1, kv_seq_len), dtype=torch.float32, device=q.device) + for i in range(bsz): + cur_seq_len = context_lengths[i].item() + assert cur_seq_len <= kv_seq_len + padding_mask[i, :, :, : kv_seq_len - cur_seq_len] = float("-inf") + + out = torch_attn_ref(q, k, v, padding_mask, bsz, 1, kv_seq_len, num_heads, num_kv_heads, head_dim) + return out + + +@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") +@pytest.mark.parametrize("bsz", [4, 7, 32]) +@pytest.mark.parametrize("block_size", [16, 32, 64]) +@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32]) +@pytest.mark.parametrize("num_attn_heads", [16]) +@pytest.mark.parametrize("kv_group_num", [1, 2, 16]) +@pytest.mark.parametrize("same_context_len", [True, False]) +def test_flash_decoding( + bsz: int, + block_size: int, + max_num_blocks_per_seq: int, + num_attn_heads: int, + kv_group_num: int, + same_context_len: bool, +): + torch.manual_seed(123) + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + num_kv_heads = num_attn_heads // kv_group_num + assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads." + q_len = 1 + head_dim = 128 + max_seq_len = block_size * max_num_blocks_per_seq + dtype = torch.float16 + device = get_current_device() + + if same_context_len: + context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device) + else: + context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device) + num_tokens = torch.sum(context_lengths).item() + + q_size = (bsz, q_len, num_attn_heads, head_dim) + q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + kv_size = (num_tokens, 2 * num_kv_heads, head_dim) + kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2) + + cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size) + k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device) + v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device) + # Mock allocation on block tables as well as blocked kv caches + block_tables = mock_alloc_block_table_and_kvcache( + k, v, k_cache, v_cache, context_lengths, bsz, max_num_blocks_per_seq, block_size + ) + block_tables = block_tables.to(device=device) + + q = q.view(bsz, q_len, num_attn_heads, head_dim) + out_triton = flash_decoding_fwd( + q, + k_cache, + v_cache, + context_lengths, + block_tables, + block_size, + kv_group_num, + ) + out_triton = out_triton.unsqueeze(1) # [bsz, 1, num_heads, head_dim] + + # rebuild (batched) kv with padding for torch attention + # q [bsz, 1, num_heads, head_dim] + # k/v [num_tokens, num_kv_heads, head_dim] + max_seq_len = context_lengths.max().item() + k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k.dtype, device=k.device) + v_torch = torch.zeros_like(k_torch) + prev_len_sum = 0 + for i, seq_len in enumerate(context_lengths.tolist()): + # mock left-side padding + k_torch[i, -seq_len:, :, :] = k[prev_len_sum : prev_len_sum + seq_len] + v_torch[i, -seq_len:, :, :] = v[prev_len_sum : prev_len_sum + seq_len] + prev_len_sum += seq_len + # k/v [bsz, max_seq_len, num_kv_heads, head_dim] + out_torch = torch_decoding(q, k_torch, v_torch, context_lengths) + + assert out_torch.shape == out_triton.shape + assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)