From c647e00e3c092d3d6219f7686f260f2932a0c27d Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Wed, 24 Jan 2024 16:20:42 +0800 Subject: [PATCH] [Inference]Add fused rotary kernel and get cos cache kernel (#5302) * add fused rotary and get cos cache func * staged * fix bugs * fix bugs --- colossalai/kernel/triton/__init__.py | 7 +- .../kernel/triton/fused_rotary_embedding.py | 182 ++++++++++++++++++ .../kernel/triton/no_pad_rotary_embedding.py | 7 +- colossalai/kernel/triton/rotary_cache_copy.py | 110 +++++++++++ .../triton/test_fused_rotary_embedding.py | 93 +++++++++ tests/test_infer_ops/triton/test_xine_copy.py | 83 ++++++++ 6 files changed, 477 insertions(+), 5 deletions(-) create mode 100644 colossalai/kernel/triton/fused_rotary_embedding.py create mode 100644 colossalai/kernel/triton/rotary_cache_copy.py create mode 100644 tests/test_infer_ops/triton/test_fused_rotary_embedding.py create mode 100644 tests/test_infer_ops/triton/test_xine_copy.py diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index b814b142b..fb8b3339b 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -11,11 +11,12 @@ if HAS_TRITON: from .context_attn_unpad import context_attention_unpadded from .flash_decoding import flash_decoding_attention from .flash_decoding_utils import FDIntermTensors - - from .rms_layernorm import rms_layernorm + from .fused_rotary_embedding import fused_rotary_embedding from .gptq_triton import gptq_fused_linear_triton from .kvcache_copy import copy_kv_to_blocked_cache from .no_pad_rotary_embedding import rotary_embedding + from .rms_layernorm import rms_layernorm + from .rotary_cache_copy import get_xine_cache from .softmax import softmax __all__ = [ @@ -27,4 +28,6 @@ if HAS_TRITON: "gptq_fused_linear_triton", "rotary_embedding", "FDIntermTensors", + "fused_rotary_embedding", + "get_xine_cache", ] diff --git a/colossalai/kernel/triton/fused_rotary_embedding.py b/colossalai/kernel/triton/fused_rotary_embedding.py new file mode 100644 index 000000000..133aa4adb --- /dev/null +++ b/colossalai/kernel/triton/fused_rotary_embedding.py @@ -0,0 +1,182 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def fused_rotary_emb( + q, + k, + cos_cache, + sin_cache, + cumsum_lengths, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_dim_stride, + q_total_tokens, + Q_HEAD_NUM: tl.constexpr, + K_HEAD_NUM: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + N_ELEMENTS: tl.constexpr, +): + block_head_index = tl.program_id(0) + block_group_index = tl.program_id(1) + group_token_index = tl.program_id(2) + idx = block_group_index * BLOCK_SIZE + group_token_index + + # original seq_idx and pos + cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS)) + ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0)) + cos = tl.load( + cos_cache + ori_seq_idx * cos_token_stride + tl.arange(0, HEAD_DIM // 2) * cos_dim_stride + ) # [1,HEAD_DIM//2] + sin = tl.load(sin_cache + ori_seq_idx * cos_token_stride + tl.arange(0, HEAD_DIM // 2) * cos_dim_stride) + + cur_head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + + off_q0 = ( + idx * q_token_stride + + cur_head_range[None, :, None] * q_head_stride + + dim_range0[None, None, :] * head_dim_stride + ) + off_q1 = ( + idx * q_token_stride + + cur_head_range[None, :, None] * q_head_stride + + dim_range1[None, None, :] * head_dim_stride + ) + + off_k0 = ( + idx * k_token_stride + + cur_head_range[None, :, None] * k_head_stride + + dim_range0[None, None, :] * head_dim_stride + ) + off_k1 = ( + idx * q_token_stride + + cur_head_range[None, :, None] * k_head_stride + + dim_range1[None, None, :] * head_dim_stride + ) + + q_0 = tl.load( + q + off_q0, + mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)), + other=0.0, + ) + + q_1 = tl.load( + q + off_q1, + mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)), + other=0.0, + ) + + k_0 = tl.load( + k + off_k0, + mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)), + other=0.0, + ) + + k_1 = tl.load( + k + off_k1, + mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)), + other=0.0, + ) + + out_q0 = q_0 * cos - q_1 * sin + out_q1 = k_0 * sin + k_1 * cos + + out_k0 = q_0 * cos - q_1 * sin + out_k1 = k_0 * sin + k_1 * cos + # concat + tl.store( + q + off_q0, + out_q0, + mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)), + ) + tl.store( + q + off_q1, + out_q1, + mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)), + ) + + tl.store( + k + off_k0, + out_k0, + mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)), + ) + tl.store( + k + off_k1, + out_k1, + mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)), + ) + + +@torch.no_grad() +def fused_rotary_embedding( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + lengths, +): + """ + Args: + q: query tensor, [total_tokens, head_num, head_dim] + k: key tensor, [total_tokens, head_num, head_dim] + cos: cosine for rotary embedding, [max_position_len, head_dim] + sin: sine for rotary embedding, [max_position_len, head_dim] + lengths [num_seqs] + """ + q_total_tokens, q_head_num, head_dim = q.shape + assert q.size(0) == k.size(0) + BLOCK_HEAD = 4 + BLOCK_SIZE = 16 + cumsum_lens = torch.cumsum(lengths, dim=0) + + grid = (triton.cdiv(q_head_num, BLOCK_HEAD), triton.cdiv(q_total_tokens, BLOCK_SIZE), BLOCK_SIZE) + + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + q_token_stride = q.stride(0) + q_head_stride = q.stride(1) + head_dim_stride = q.stride(2) + + k_token_stride = k.stride(0) + k_head_stride = k.stride(1) + + k_head_num = q.shape[1] + + cos_token_stride = cos.stride(0) + cos_dim_stride = cos.stride(1) + + fused_rotary_emb[grid]( + q, + k, + cos, + sin, + cumsum_lens, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_dim_stride, + q_total_tokens, + Q_HEAD_NUM=q_head_num, + K_HEAD_NUM=k_head_num, + HEAD_DIM=head_dim, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SIZE=BLOCK_SIZE, + N_ELEMENTS=triton.next_power_of_2(q_total_tokens), + num_warps=num_warps, + ) diff --git a/colossalai/kernel/triton/no_pad_rotary_embedding.py b/colossalai/kernel/triton/no_pad_rotary_embedding.py index e4bab18eb..40ac6b53b 100644 --- a/colossalai/kernel/triton/no_pad_rotary_embedding.py +++ b/colossalai/kernel/triton/no_pad_rotary_embedding.py @@ -98,11 +98,12 @@ def rotary_embedding( Args: q: query tensor, [total_tokens, head_num, head_dim] k: key tensor, [total_tokens, head_num, head_dim] - cos: cosine for rotary embedding, [total_tokens, head_dim] - sin: sine for rotary embedding, [total_tokens, head_dim] + cos: cosine for rotary embedding, [max_position_len, head_dim] + sin: sine for rotary embedding, [max_position_len, head_dim] + lengths [num_seqs] """ q_total_tokens, q_head_num, head_dim = q.shape - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" + assert q.size(0) == k.size(0) BLOCK_HEAD = 4 BLOCK_TOKENS = 8 grid = (triton.cdiv(q_head_num, BLOCK_HEAD), 2 * triton.cdiv(q_total_tokens, BLOCK_TOKENS)) diff --git a/colossalai/kernel/triton/rotary_cache_copy.py b/colossalai/kernel/triton/rotary_cache_copy.py new file mode 100644 index 000000000..771dedac5 --- /dev/null +++ b/colossalai/kernel/triton/rotary_cache_copy.py @@ -0,0 +1,110 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def prefill_cache_kernel( + CaChe, + cumsum_lengths, + output, + cache_stride, + hidden_stride, + total_length, + HIDDEN_DIM: tl.constexpr, + N_ELEMENTS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + idx0 = tl.program_id(axis=0) + idx1 = tl.program_id(axis=1) + idx = idx0 * BLOCK_SIZE + idx1 + + # original seq_idx and pos + cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS)) + ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0)) + _cache = tl.load(CaChe + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride) + tl.store(output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, _cache, mask=idx < total_length) + + +@triton.jit +def decoding_cache_kernel( + CaChe, + lengths, + output, + cache_stride, + hidden_stride, + HIDDEN_DIM: tl.constexpr, + NUM_SEQS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + ori_seq_idx = tl.load(lengths + idx, mask=(idx < NUM_SEQS), other=None) # [BLOCK_SIZE,] + _cache = tl.load(CaChe + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride) + tl.store( + output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride), + _cache, + mask=idx[:, None] < NUM_SEQS, + ) + + +@torch.no_grad() +def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool = False): + """ + Transform cos/sin cache into no pad sequence, with two different modes. + Args: + lengths: shape(num_seqs,), stores lenghth of each sequence. + cache: shape(max_rotary_position(e.g.2048), head_dim), cos/sin cache constrcuted in model. + is_prompts: bool, mark if in prefill mode. + For prefill mode: + cos/sin cache for each sequence is equal to its length. + For decoding mode: + cos/sin cache is only needed for the last token. + """ + + _, hidden_dim = cache.shape + num_seqs = lengths.numel() + + BLOCK_SIZE = 16 + if hidden_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + cache_stride = cache.stride(0) + hidden_stride = cache.stride(1) + + if is_prompts: + total_length = lengths.sum().item() + cumsum_lens = torch.cumsum(lengths, dim=0) + output = torch.empty((total_length, hidden_dim), dtype=cache.dtype, device=cache.device) + grid = (triton.cdiv(total_length, BLOCK_SIZE), BLOCK_SIZE) + prefill_cache_kernel[grid]( + cache, + cumsum_lens, + output, + cache_stride, + hidden_stride, + total_length, + HIDDEN_DIM=hidden_dim, + N_ELEMENTS=triton.next_power_of_2(num_seqs), + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + else: + # BUG: get memory access error whe using a deepcopy lengths to replace lengths + nlengths = torch.as_tensor(lengths) - 1 + output = torch.empty((num_seqs, hidden_dim), dtype=cache.dtype, device=cache.device) + grid = (triton.cdiv(num_seqs, BLOCK_SIZE),) + decoding_cache_kernel[grid]( + cache, + nlengths, + output, + cache_stride, + hidden_stride, + HIDDEN_DIM=hidden_dim, + NUM_SEQS=num_seqs, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + + return output diff --git a/tests/test_infer_ops/triton/test_fused_rotary_embedding.py b/tests/test_infer_ops/triton/test_fused_rotary_embedding.py new file mode 100644 index 000000000..658bc872f --- /dev/null +++ b/tests/test_infer_ops/triton/test_fused_rotary_embedding.py @@ -0,0 +1,93 @@ +from copy import deepcopy + +import torch +import triton + +from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding +from colossalai.kernel.triton.no_pad_rotary_embedding import rotary_embedding +from colossalai.kernel.triton.rotary_cache_copy import get_xine_cache + +BATCH = 16 +configs = [ + triton.testing.Benchmark( + x_names=["num_tokens"], + x_vals=[2**i for i in range(4, 12)], + line_arg="provider", + line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"], + line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"], + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"rotary_emb-batch-{BATCH}", + args={"num_kv_heads": 16}, + ) +] + + +def torch_rotary_emb(x, cos, sin): + seq_len, h, dim = x.shape + x0 = x[:, :, 0 : dim // 2] + x1 = x[:, :, dim // 2 : dim] + cos = cos.view((seq_len, 1, dim // 2)) + sin = sin.view((seq_len, 1, dim // 2)) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + return torch.cat((o0, o1), dim=-1) + + +@triton.testing.perf_report(configs) +def benchmark_rotary_emb( + provider: str, + num_tokens: int, + num_kv_heads: int, +): + warmup = 10 + rep = 100 + + head_dim = 128 + dtype = torch.float16 + q_shape = (num_tokens, num_kv_heads, head_dim) + q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda") + k_shape = (num_tokens, num_kv_heads, head_dim) + k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda") + cos_shape = (4096, head_dim // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + + if provider == "torch_rotary_emb_func": + fn = lambda: torch_rotary_emb(q, cos[:num_tokens], sin[:num_tokens]) + elif provider == "triton_rotary_emb_func": + fn = lambda: fused_rotary_embedding(q, k, cos, sin, lengths) + else: + raise ValueError("Undefined provider") + + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + +if __name__ == "__main__": + num_tokens = 20 + num_kv_heads = 32 + head_dim = 64 + dtype = torch.float32 + q_shape = (num_tokens, num_kv_heads, head_dim) + q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda") + q_copy = deepcopy(q) + + k_shape = (num_tokens, num_kv_heads, head_dim) + k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda") + k_copy = deepcopy(k) + + cos_shape = (1024, head_dim) + lengths = torch.tensor([3, 4, 6, 7], device="cuda") + cos_cache = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin_cache = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + + cos = get_xine_cache(lengths, cos_cache[:, : head_dim // 2]) + sin = get_xine_cache(lengths, sin_cache[:, : head_dim // 2]) + + rotary_embedding(q, k, cos, sin) + fused_rotary_embedding(q_copy, k_copy, cos_cache, sin_cache, lengths) + torch.allclose(q, q_copy) + torch.allclose(k, k_copy) + + # benchmark_rotary_emb.run(save_path=".",print_data=True) diff --git a/tests/test_infer_ops/triton/test_xine_copy.py b/tests/test_infer_ops/triton/test_xine_copy.py new file mode 100644 index 000000000..0e63a7012 --- /dev/null +++ b/tests/test_infer_ops/triton/test_xine_copy.py @@ -0,0 +1,83 @@ +import pytest +import torch +from packaging import version + +from colossalai.inference.modeling.models.llama import get_cos_sin +from colossalai.kernel.triton import get_xine_cache + +try: + import triton # noqa + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +@pytest.mark.parametrize("BATCH_SIZE", [4]) +@pytest.mark.parametrize("MAX_SEQ_LEN", [64]) +@pytest.mark.parametrize("HEAD_DIM", [64]) +@pytest.mark.parametrize("dtype", [torch.float32]) +def test_get_xine_cache(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype): + MAX_TOTAL_TOKENS = BATCH_SIZE * MAX_SEQ_LEN + cos_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda") + lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda") + # prefill + cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, cos_cache, is_prompts=True, dtype=dtype) + cos = get_xine_cache(lengths, cos_cache, is_prompts=True) + assert torch.allclose(cos, cos_ref) + # decoding + ncos_ref, sin_ref = get_cos_sin(lengths, cos_cache, cos_cache, is_prompts=False, dtype=dtype) + cos = get_xine_cache(lengths, cos_cache, is_prompts=False) + assert torch.allclose(cos, ncos_ref) + + +configs = [ + triton.testing.Benchmark( + x_names=["max_num_tokens"], + x_vals=[2**i for i in range(6, 12)], + line_arg="provider", + line_vals=["torch_get_cos_sin_func", "triton_get_xine_func"], + line_names=["torch_get_cos_sin_func", "triton_get_xine_func"], + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name="Get_cos-sin_func", + args={"batch_size": 16, "head_dim": 256}, + ) +] + + +@triton.testing.perf_report(configs) +def benchmark_get_xine_cache( + provider: str, + max_num_tokens: int, + batch_size: int, + head_dim: int, +): + warmup = 10 + rep = 1000 + max_token_per_seq = max_num_tokens // batch_size + dtype = torch.float16 + cos_cache = torch.randn((max_num_tokens, head_dim), dtype=dtype, device="cuda") + sin_cache = torch.randn((max_num_tokens, head_dim), dtype=dtype, device="cuda") + lengths = torch.randint(2, max_token_per_seq, (batch_size,), device="cuda") + + if provider == "torch_get_cos_sin_func": + fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype) + elif provider == "triton_get_xine_func": + fn = lambda: [ + get_xine_cache(lengths, cos_cache, is_prompts=False), + get_xine_cache(lengths, sin_cache, is_prompts=False), + ] + else: + raise ValueError("Undefined provider") + + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + +if __name__ == "__main__": + test_get_xine_cache(4, 64, 256, torch.float32) + # benchmark_get_xine_cache.run(save_path=".",print_data=True)