From df0aa49585d2dd19d7397dfbd3b5f136abac609b Mon Sep 17 00:00:00 2001
From: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Date: Wed, 31 Jan 2024 16:31:29 +0800
Subject: [PATCH] [Inference] Kernel fusion: fuse the KV cache copy into the
 rotary embedding kernel (#5336)

* revise rotary embedding

* remove useless print

* adapt
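
A minimal usage sketch of the new call path (values are illustrative; the
shapes follow the updated test in this patch):

    import torch

    from colossalai.kernel.triton import rotary_embedding

    bsz, num_heads, head_dim = 4, 8, 128
    block_size, max_blocks_per_seq = 16, 8

    # one new token per sequence (decoding stage)
    q = torch.randn(bsz, num_heads, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn(bsz, num_heads, head_dim, dtype=torch.float16, device="cuda")
    cos = torch.randn(bsz, head_dim // 2, dtype=torch.float16, device="cuda")
    sin = torch.randn(bsz, head_dim // 2, dtype=torch.float16, device="cuda")

    k_cache = torch.zeros(
        bsz * max_blocks_per_seq, num_heads, block_size, head_dim, dtype=torch.float16, device="cuda"
    )
    block_tables = torch.arange(bsz * max_blocks_per_seq, dtype=torch.int32, device="cuda").reshape(
        bsz, max_blocks_per_seq
    )
    kv_lengths = torch.full((bsz,), 5, dtype=torch.int32, device="cuda")  # past length 4 + the new token

    # without the three cache arguments the original kernel is launched (rotation only);
    # with them, the rotation and the copy of the rotated k into k_cache run in one kernel
    rotary_embedding(q, k, cos, sin, k_cache, block_tables, kv_lengths)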
---
 .../kernel/triton/no_pad_rotary_embedding.py  | 229 ++++++++++++++++--
 .../triton/test_rotary_embdding_unpad.py      |  35 ++-
 tests/test_infer_ops/triton/test_xine_copy.py |   4 +-
 3 files changed, 238 insertions(+), 30 deletions(-)

diff --git a/colossalai/kernel/triton/no_pad_rotary_embedding.py b/colossalai/kernel/triton/no_pad_rotary_embedding.py
index 5c799897a..89bd40b40 100644
--- a/colossalai/kernel/triton/no_pad_rotary_embedding.py
+++ b/colossalai/kernel/triton/no_pad_rotary_embedding.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 import triton
 import triton.language as tl
@@ -126,12 +128,161 @@ def rotary_embedding_kernel(
     )
 
 
+@triton.jit
+def fused_rotary_embedding_kernel(
+    q,
+    k,
+    cos,
+    sin,
+    kv_cache,
+    BLOCK_TABLES,
+    context_lengths,
+    q_token_stride,
+    q_head_stride,
+    k_token_stride,
+    k_head_stride,
+    head_dim_stride,
+    cos_token_stride,
+    cos_stride,
+    cacheb_stride,
+    cacheh_stride,
+    cachebs_stride,
+    cached_stride,
+    bts_stride,
+    btb_stride,
+    block_size,
+    q_total_tokens,
+    Q_HEAD_NUM: tl.constexpr,
+    K_HEAD_NUM: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
+    BLOCK_HEAD: tl.constexpr,
+    BLOCK_TOKENS: tl.constexpr,
+):
+    block_head_index = tl.program_id(0)
+    block_token_index = tl.program_id(1)
+
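+    # tile of tokens and heads handled by this program instance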
+    tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
+    head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
+
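+    # split the head dimension into two halves that are rotated against each other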
+    dim_range0 = tl.arange(0, HEAD_DIM // 2)
+    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
+
+    off_q0 = (
+        tokens_range[:, None, None] * q_token_stride
+        + head_range[None, :, None] * q_head_stride
+        + dim_range0[None, None, :] * head_dim_stride
+    )
+    off_q1 = (
+        tokens_range[:, None, None] * q_token_stride
+        + head_range[None, :, None] * q_head_stride
+        + dim_range1[None, None, :] * head_dim_stride
+    )
+    off_k0 = (
+        tokens_range[:, None, None] * k_token_stride
+        + head_range[None, :, None] * k_head_stride
+        + dim_range0[None, None, :] * head_dim_stride
+    )
+    off_k1 = (
+        tokens_range[:, None, None] * k_token_stride
+        + head_range[None, :, None] * k_head_stride
+        + dim_range1[None, None, :] * head_dim_stride
+    )
+
+    loaded_q0 = tl.load(
+        q + off_q0,
+        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+        other=0.0,
+    )
+    loaded_q1 = tl.load(
+        q + off_q1,
+        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+        other=0.0,
+    )
+
+    loaded_k0 = tl.load(
+        k + off_k0,
+        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+        other=0.0,
+    )
+
+    loaded_k1 = tl.load(
+        k + off_k1,
+        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+        other=0.0,
+    )
+
+    off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride
+
+    loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
+    loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
+
+    out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :]
+    out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :]
+
+    out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
+    out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]  # [BLOCK_TOKENS, BLOCK_HEAD, HEAD_DIM // 2]
+
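+    # locate the kv cache slot of each sequence's latest token (one new token per sequence)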
+    past_kv_seq_len = tl.load(context_lengths + tokens_range, mask=(tokens_range < q_total_tokens), other=1) - 1
+
+    last_block_idx = past_kv_seq_len // block_size
+    block_table_ptr = BLOCK_TABLES + tokens_range * bts_stride
+    block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(tokens_range < q_total_tokens), other=0)
+    offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride
+
+    kv_range0 = (
+        block_ids[:, None, None, None] * cacheb_stride
+        + head_range[None, :, None, None] * cacheh_stride
+        + offsets_in_last_block[:, None, None, None]
+        + dim_range0[None, None, None, :] * cached_stride
+    )
+    kv_range1 = (
+        block_ids[:, None, None, None] * cacheb_stride
+        + head_range[None, :, None, None] * cacheh_stride
+        + offsets_in_last_block[:, None, None, None]
+        + dim_range1[None, None, None, :] * cached_stride
+    )
+
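+    # copy the rotated key halves into the blocked key cache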
+    tl.store(
+        kv_cache + kv_range0,
+        out_k0[:, :, None, :],
+        mask=((head_range[None, :, None, None] < K_HEAD_NUM) & (tokens_range[:, None, None, None] < q_total_tokens)),
+    )
+    tl.store(
+        kv_cache + kv_range1,
+        out_k1[:, :, None, :],
+        mask=((head_range[None, :, None, None] < K_HEAD_NUM) & (tokens_range[:, None, None, None] < q_total_tokens)),
+    )
+
+    # write the rotated q and k back in place
+    tl.store(
+        q + off_q0,
+        out_q0,
+        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+    )
+    tl.store(
+        q + off_q1,
+        out_q1,
+        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+    )
+    tl.store(
+        k + off_k0,
+        out_k0,
+        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+    )
+    tl.store(
+        k + off_k1,
+        out_k1,
+        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+    )
+
+
 @torch.no_grad()
 def rotary_embedding(
     q: torch.Tensor,
     k: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
+    k_cache: Optional[torch.Tensor] = None,
+    block_tables: Optional[torch.Tensor] = None,
+    kv_lengths: Optional[torch.Tensor] = None,
 ):
     """
     Args:
@@ -139,7 +290,9 @@ def rotary_embedding(
         k: key tensor, [total_tokens, head_num, head_dim]
         cos: cosine for rotary embedding, [max_position_len, head_dim]
         sin: sine for rotary embedding, [max_position_len, head_dim]
-        lengths [num_seqs]
+        k_cache (torch.Tensor, optional): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim]
+        block_tables (torch.Tensor, optional): Block tables recording the cache blocks of each sequence. [bsz, max_blocks_per_sequence]
+        kv_lengths (torch.Tensor, optional): Past key/value sequence length plus the current length for each sequence. [bsz]
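+
+        When k_cache, block_tables, and kv_lengths are all provided, the rotated key of each
+        sequence's latest token is also written into its slot in k_cache by a fused kernel;
+        this path assumes one new token per sequence, as in the decoding stage.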
     """
     q_total_tokens, q_head_num, head_dim = q.shape
     assert q.size(0) == k.size(0)
@@ -165,26 +318,56 @@ def rotary_embedding(
 
     cos_token_stride = cos.stride(0)
     cos_stride = cos.stride(1)
-
-    rotary_embedding_kernel[grid](
-        q,
-        k,
-        cos,
-        sin,
-        q_token_stride,
-        q_head_stride,
-        k_token_stride,
-        k_head_stride,
-        head_dim_stride,
-        cos_token_stride,
-        cos_stride,
-        q_total_tokens,
-        Q_HEAD_NUM=q_head_num,
-        K_HEAD_NUM=k_head_num,
-        HEAD_DIM=head_dim,
-        BLOCK_HEAD=BLOCK_HEAD,
-        BLOCK_TOKENS=BLOCK_TOKENS,
-        num_warps=num_warps,
-    )
-
+    if k_cache is None:
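+        # no kv cache given: plain rotary embedding with the original kernel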
+        rotary_embedding_kernel[grid](
+            q,
+            k,
+            cos,
+            sin,
+            q_token_stride,
+            q_head_stride,
+            k_token_stride,
+            k_head_stride,
+            head_dim_stride,
+            cos_token_stride,
+            cos_stride,
+            q_total_tokens,
+            Q_HEAD_NUM=q_head_num,
+            K_HEAD_NUM=k_head_num,
+            HEAD_DIM=head_dim,
+            BLOCK_HEAD=BLOCK_HEAD,
+            BLOCK_TOKENS=BLOCK_TOKENS,
+            num_warps=num_warps,
+        )
+    else:
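+        # fused kernel: rotate q/k and copy each sequence's rotated key into its k_cache block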
+        fused_rotary_embedding_kernel[grid](
+            q,
+            k,
+            cos,
+            sin,
+            k_cache,
+            block_tables,
+            kv_lengths,
+            q_token_stride,
+            q_head_stride,
+            k_token_stride,
+            k_head_stride,
+            head_dim_stride,
+            cos_token_stride,
+            cos_stride,
+            k_cache.stride(0),
+            k_cache.stride(1),
+            k_cache.stride(2),
+            k_cache.stride(3),
+            block_tables.stride(0),
+            block_tables.stride(1),
+            k_cache.size(-2),
+            q_total_tokens,
+            Q_HEAD_NUM=q_head_num,
+            K_HEAD_NUM=k_head_num,
+            HEAD_DIM=head_dim,
+            BLOCK_HEAD=BLOCK_HEAD,
+            BLOCK_TOKENS=BLOCK_TOKENS,
+            num_warps=num_warps,
+        )
     return
diff --git a/tests/test_infer_ops/triton/test_rotary_embdding_unpad.py b/tests/test_infer_ops/triton/test_rotary_embdding_unpad.py
index d611234f0..529c9fb2f 100644
--- a/tests/test_infer_ops/triton/test_rotary_embdding_unpad.py
+++ b/tests/test_infer_ops/triton/test_rotary_embdding_unpad.py
@@ -4,6 +4,7 @@ from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
 
 from colossalai.kernel.triton import rotary_embedding
+from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
 
 try:
     import triton  # noqa
@@ -47,6 +48,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     assert torch.allclose(embd_x0, embd_stimulated_x)
 
     # create data
+    block_size = 32
+    max_num_blocks_per_seq = 4
     q_shape = (TOTAL_TOKENS, H, D)
     q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
     k_shape = (TOTAL_TOKENS, H, D)
@@ -54,13 +57,35 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     cos_shape = (TOTAL_TOKENS, D // 2)
     cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
     sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D)
+    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
+    v = torch.randn_like(k)
+    v_cache = torch.zeros_like(k_cache)
+    past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
+    block_tables = mock_alloc_block_table_and_kvcache_v2(
+        k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+    )
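+    # q/k for the single new decoding token of each sequence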
+    new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
+    new_q = torch.randn_like(new_k)
+    kv_seq_lengths = past_kv_seq_lengths + 1
+    block_tables = block_tables.to(device="cuda")
+    q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
+    k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
 
-    q_ref = torch_rotary_emb(q, cos, sin)
-    k_ref = torch_rotary_emb(k, cos, sin)
-    rotary_embedding(q, k, cos, sin)
+    rotary_embedding(new_q, new_k, cos, sin, k_cache, block_tables, kv_seq_lengths)
+    assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4)
+    assert torch.allclose(new_k, k_ref, atol=1e-4, rtol=1e-4)
 
-    assert torch.allclose(q, q_ref, atol=1e-4, rtol=1e-4)
-    assert torch.allclose(k, k_ref, atol=1e-4, rtol=1e-4)
+    # check that the rotated key of each sequence was written to its kv cache slot
+    for seq_i in range(BATCH_SIZE):
+        past_kv_seq_len = kv_seq_lengths[seq_i] - 1
+        target_block_id = block_tables[seq_i, past_kv_seq_len // block_size]
+        offsets_in_block = past_kv_seq_len % block_size
+        target = k_cache[target_block_id, :, offsets_in_block, :]
+        orig = new_k[seq_i]
+        assert torch.equal(orig, target)
 
 
 BATCH = 16
diff --git a/tests/test_infer_ops/triton/test_xine_copy.py b/tests/test_infer_ops/triton/test_xine_copy.py
index c19be5abe..efa7d74e5 100644
--- a/tests/test_infer_ops/triton/test_xine_copy.py
+++ b/tests/test_infer_ops/triton/test_xine_copy.py
@@ -53,10 +53,10 @@ def test_get_xine_cache(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype):
     assert torch.allclose(cos, cos_ref)
     assert torch.allclose(sin, sin_ref)
     # decoding
-    ncos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
+    ncos_ref, nsin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
     cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=False)
     assert torch.allclose(cos, ncos_ref)
-    assert torch.allclose(sin, sin_ref)
+    assert torch.allclose(sin, nsin_ref)
 
 
 configs = [