diff --git a/colossalai/kernel/triton/no_pad_rotary_embedding.py b/colossalai/kernel/triton/no_pad_rotary_embedding.py index 5c799897a..89bd40b40 100644 --- a/colossalai/kernel/triton/no_pad_rotary_embedding.py +++ b/colossalai/kernel/triton/no_pad_rotary_embedding.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch import triton import triton.language as tl @@ -126,12 +128,161 @@ def rotary_embedding_kernel( ) +@triton.jit +def fused_rotary_embedding_kernel( + q, + k, + cos, + sin, + kv_cache, + BLOCK_TABLES, + context_lengths, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_stride, + cacheb_stride, + cacheh_stride, + cachebs_stride, + cached_stride, + bts_stride, + btb_stride, + block_size, + q_total_tokens, + Q_HEAD_NUM: tl.constexpr, + K_HEAD_NUM: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + BLOCK_TOKENS: tl.constexpr, +): + block_head_index = tl.program_id(0) + block_token_index = tl.program_id(1) + + tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS) + head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + + off_q0 = ( + tokens_range[:, None, None] * q_token_stride + + head_range[None, :, None] * q_head_stride + + dim_range0[None, None, :] * head_dim_stride + ) + off_q1 = ( + tokens_range[:, None, None] * q_token_stride + + head_range[None, :, None] * q_head_stride + + dim_range1[None, None, :] * head_dim_stride + ) + off_k0 = ( + tokens_range[:, None, None] * k_token_stride + + head_range[None, :, None] * k_head_stride + + dim_range0[None, None, :] * head_dim_stride + ) + off_k1 = ( + tokens_range[:, None, None] * k_token_stride + + head_range[None, :, None] * k_head_stride + + dim_range1[None, None, :] * head_dim_stride + ) + + loaded_q0 = tl.load( + q + off_q0, + mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + other=0.0, + ) + loaded_q1 = tl.load( + q + off_q1, + mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + other=0.0, + ) + + loaded_k0 = tl.load( + k + off_k0, + mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + other=0.0, + ) + + loaded_k1 = tl.load( + k + off_k1, + mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + other=0.0, + ) + + off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride + + loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0) + loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0) + + out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :] + out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :] + + out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :] + out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :] # total_tokens, head_num, head_dim + + past_kv_seq_len = tl.load(context_lengths + tokens_range) - 1 + + last_block_idx = past_kv_seq_len // block_size + block_table_ptr = BLOCK_TABLES + tokens_range * bts_stride + block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride) + offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride + + kv_range0 = ( + block_ids[:, None, None, None] * cacheb_stride + + head_range[None, :, None, None] * cacheh_stride + + offsets_in_last_block[:, None, None, None] + + dim_range0[None, None, None, :] * cached_stride + ) + kv_range1 = ( + block_ids[:, None, None, None] * cacheb_stride + + head_range[None, :, None, None] * cacheh_stride + + offsets_in_last_block[:, None, None, None] + + dim_range1[None, None, None, :] * cached_stride + ) + + tl.store( + kv_cache + kv_range0, + out_k0[:, :, None, :], + ) + tl.store( + kv_cache + kv_range1, + out_k1[:, :, None, :], + ) + + # concat + tl.store( + q + off_q0, + out_q0, + mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + ) + tl.store( + q + off_q1, + out_q1, + mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + ) + tl.store( + k + off_k0, + out_k0, + mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + ) + tl.store( + k + off_k1, + out_k1, + mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + ) + + @torch.no_grad() def rotary_embedding( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, + k_cache: Optional[torch.Tensor] = None, + block_tables: Optional[torch.Tensor] = None, + kv_lengths: Optional[torch.Tensor] = None, ): """ Args: @@ -139,7 +290,9 @@ def rotary_embedding( k: key tensor, [total_tokens, head_num, head_dim] cos: cosine for rotary embedding, [max_position_len, head_dim] sin: sine for rotary embedding, [max_position_len, head_dim] - lengths [num_seqs] + k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim] + kv_lengths, Past key/value sequence lengths plus current sequence length for each sequence. [bsz] + block_tables: Block tables for each sequence. [bsz, max_blocks_per_sequence] """ q_total_tokens, q_head_num, head_dim = q.shape assert q.size(0) == k.size(0) @@ -165,26 +318,56 @@ def rotary_embedding( cos_token_stride = cos.stride(0) cos_stride = cos.stride(1) - - rotary_embedding_kernel[grid]( - q, - k, - cos, - sin, - q_token_stride, - q_head_stride, - k_token_stride, - k_head_stride, - head_dim_stride, - cos_token_stride, - cos_stride, - q_total_tokens, - Q_HEAD_NUM=q_head_num, - K_HEAD_NUM=k_head_num, - HEAD_DIM=head_dim, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_TOKENS=BLOCK_TOKENS, - num_warps=num_warps, - ) - + if k_cache == None: + rotary_embedding_kernel[grid]( + q, + k, + cos, + sin, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_stride, + q_total_tokens, + Q_HEAD_NUM=q_head_num, + K_HEAD_NUM=k_head_num, + HEAD_DIM=head_dim, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_TOKENS=BLOCK_TOKENS, + num_warps=num_warps, + ) + else: + fused_rotary_embedding_kernel[grid]( + q, + k, + cos, + sin, + k_cache, + block_tables, + kv_lengths, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_stride, + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + k_cache.size(-2), + q_total_tokens, + Q_HEAD_NUM=q_head_num, + K_HEAD_NUM=k_head_num, + HEAD_DIM=head_dim, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_TOKENS=BLOCK_TOKENS, + num_warps=num_warps, + ) return diff --git a/tests/test_infer_ops/triton/test_rotary_embdding_unpad.py b/tests/test_infer_ops/triton/test_rotary_embdding_unpad.py index d611234f0..529c9fb2f 100644 --- a/tests/test_infer_ops/triton/test_rotary_embdding_unpad.py +++ b/tests/test_infer_ops/triton/test_rotary_embdding_unpad.py @@ -4,6 +4,7 @@ from packaging import version from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb from colossalai.kernel.triton import rotary_embedding +from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2 try: import triton # noqa @@ -47,6 +48,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype): assert torch.allclose(embd_x0, embd_stimulated_x) # create data + block_size = 32 + max_num_blocks_per_seq = 4 q_shape = (TOTAL_TOKENS, H, D) q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda") k_shape = (TOTAL_TOKENS, H, D) @@ -54,13 +57,35 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype): cos_shape = (TOTAL_TOKENS, D // 2) cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D) + k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda") + v = torch.randn_like(k) + v_cache = torch.zeros_like(k_cache) + past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda") + block_tables = mock_alloc_block_table_and_kvcache_v2( + k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size + ) + new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda") + new_q = torch.randn_like(new_k) + kv_seq_lengths = past_kv_seq_lengths + 1 + block_tables = block_tables.to(device="cuda") + q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE]) + k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE]) - q_ref = torch_rotary_emb(q, cos, sin) - k_ref = torch_rotary_emb(k, cos, sin) - rotary_embedding(q, k, cos, sin) + rotary_embedding(new_q, new_k, cos, sin, k_cache, block_tables, kv_seq_lengths) + assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4) + assert torch.allclose(new_k, k_ref, atol=1e-4, rtol=1e-4) - assert torch.allclose(q, q_ref, atol=1e-4, rtol=1e-4) - assert torch.allclose(k, k_ref, atol=1e-4, rtol=1e-4) + # check one by one + for seq_i in range(BATCH_SIZE): + ki = new_k[seq_i] + ki = ki.squeeze() + past_kv_seq_len = kv_seq_lengths[seq_i] - 1 + target_block_id = block_tables[seq_i, past_kv_seq_len // block_size] + offsets_in_block = past_kv_seq_len % block_size + target = k_cache[target_block_id, :, offsets_in_block, :] + orig = new_k[seq_i].squeeze(dim=0) + assert torch.equal(orig, target) BATCH = 16 diff --git a/tests/test_infer_ops/triton/test_xine_copy.py b/tests/test_infer_ops/triton/test_xine_copy.py index c19be5abe..efa7d74e5 100644 --- a/tests/test_infer_ops/triton/test_xine_copy.py +++ b/tests/test_infer_ops/triton/test_xine_copy.py @@ -53,10 +53,10 @@ def test_get_xine_cache(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype): assert torch.allclose(cos, cos_ref) assert torch.allclose(sin, sin_ref) # decoding - ncos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype) + ncos_ref, nsin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype) cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=False) assert torch.allclose(cos, ncos_ref) - assert torch.allclose(sin, sin_ref) + assert torch.allclose(sin, nsin_ref) configs = [