diff --git a/colossalai/kernel/triton/fused_rotary_embedding.py b/colossalai/kernel/triton/fused_rotary_embedding.py index 133aa4adb..237b088a4 100644 --- a/colossalai/kernel/triton/fused_rotary_embedding.py +++ b/colossalai/kernel/triton/fused_rotary_embedding.py @@ -136,7 +136,7 @@ def fused_rotary_embedding( q_total_tokens, q_head_num, head_dim = q.shape assert q.size(0) == k.size(0) BLOCK_HEAD = 4 - BLOCK_SIZE = 16 + BLOCK_SIZE = 8 cumsum_lens = torch.cumsum(lengths, dim=0) grid = (triton.cdiv(q_head_num, BLOCK_HEAD), triton.cdiv(q_total_tokens, BLOCK_SIZE), BLOCK_SIZE) diff --git a/colossalai/kernel/triton/no_pad_rotary_embedding.py b/colossalai/kernel/triton/no_pad_rotary_embedding.py index 40ac6b53b..5c799897a 100644 --- a/colossalai/kernel/triton/no_pad_rotary_embedding.py +++ b/colossalai/kernel/triton/no_pad_rotary_embedding.py @@ -2,6 +2,22 @@ import torch import triton import triton.language as tl +""" +# Base autotune if needed +@triton.autotune( + configs=[ + triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":4,},num_warps=4), + triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":8,},num_warps=8), + triton.Config({'BLOCK_HEAD':8,"BLOCK_TOKENS":8,},num_warps=8), + triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":4,},num_warps=16), + triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":4,},num_warps=32), + triton.Config({'BLOCK_HEAD':16,"BLOCK_TOKENS":16,},num_warps=4), + triton.Config({'BLOCK_HEAD':8,"BLOCK_TOKENS":16,},num_warps=8), + ], + key=['HEAD_DIM','q_total_tokens','Q_HEAD_NUM'] +) +""" + @triton.jit def rotary_embedding_kernel( @@ -26,43 +42,53 @@ def rotary_embedding_kernel( block_head_index = tl.program_id(0) block_token_index = tl.program_id(1) - rotary_data = q - HEAD_NUM = Q_HEAD_NUM - head_stride = q_head_stride - token_stride = q_token_stride - - if block_token_index * BLOCK_TOKENS >= q_total_tokens: - block_token_index = block_token_index - tl.cdiv(q_total_tokens, BLOCK_TOKENS) - rotary_data = k - HEAD_NUM = K_HEAD_NUM - head_stride = k_head_stride - token_stride = k_token_stride - tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS) head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) dim_range0 = tl.arange(0, HEAD_DIM // 2) dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) - off_data0 = ( - tokens_range[:, None, None] * token_stride - + head_range[None, :, None] * head_stride + off_q0 = ( + tokens_range[:, None, None] * q_token_stride + + head_range[None, :, None] * q_head_stride + dim_range0[None, None, :] * head_dim_stride ) - off_data1 = ( - tokens_range[:, None, None] * token_stride - + head_range[None, :, None] * head_stride + off_q1 = ( + tokens_range[:, None, None] * q_token_stride + + head_range[None, :, None] * q_head_stride + + dim_range1[None, None, :] * head_dim_stride + ) + off_k0 = ( + tokens_range[:, None, None] * k_token_stride + + head_range[None, :, None] * k_head_stride + + dim_range0[None, None, :] * head_dim_stride + ) + off_k1 = ( + tokens_range[:, None, None] * k_token_stride + + head_range[None, :, None] * k_head_stride + dim_range1[None, None, :] * head_dim_stride ) - loaded_data0 = tl.load( - rotary_data + off_data0, - mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + loaded_q0 = tl.load( + q + off_q0, + mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), other=0.0, ) - loaded_data1 = tl.load( - rotary_data + off_data1, - mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + 
loaded_q1 = tl.load( + q + off_q1, + mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + other=0.0, + ) + + loaded_k0 = tl.load( + k + off_k0, + mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + other=0.0, + ) + + loaded_k1 = tl.load( + k + off_k1, + mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), other=0.0, ) @@ -71,19 +97,32 @@ def rotary_embedding_kernel( loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0) loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0) - out0 = loaded_data0 * loaded_cos[:, None, :] - loaded_data1 * loaded_sin[:, None, :] - out1 = loaded_data0 * loaded_sin[:, None, :] + loaded_data1 * loaded_cos[:, None, :] + out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :] + out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :] + + out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :] + out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :] # concat tl.store( - rotary_data + off_data0, - out0, - mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + q + off_q0, + out_q0, + mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), ) tl.store( - rotary_data + off_data1, - out1, - mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + q + off_q1, + out_q1, + mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + ) + tl.store( + k + off_k0, + out_k0, + mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + ) + tl.store( + k + off_k1, + out_k1, + mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), ) @@ -105,11 +144,13 @@ def rotary_embedding( q_total_tokens, q_head_num, head_dim = q.shape assert q.size(0) == k.size(0) BLOCK_HEAD = 4 - BLOCK_TOKENS = 8 - grid = (triton.cdiv(q_head_num, BLOCK_HEAD), 2 * triton.cdiv(q_total_tokens, BLOCK_TOKENS)) + BLOCK_TOKENS = 4 + grid = lambda META: (triton.cdiv(q_head_num, META["BLOCK_HEAD"]), triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"])) - if head_dim >= 128: - num_warps = 8 + if head_dim >= 256: + num_warps = 32 + elif head_dim >= 128: + num_warps = 16 else: num_warps = 4 @@ -144,7 +185,6 @@ def rotary_embedding( BLOCK_HEAD=BLOCK_HEAD, BLOCK_TOKENS=BLOCK_TOKENS, num_warps=num_warps, - num_stages=1, ) return diff --git a/colossalai/kernel/triton/rotary_cache_copy.py b/colossalai/kernel/triton/rotary_cache_copy.py index 771dedac5..6b064ed4a 100644 --- a/colossalai/kernel/triton/rotary_cache_copy.py +++ b/colossalai/kernel/triton/rotary_cache_copy.py @@ -5,9 +5,11 @@ import triton.language as tl @triton.jit def prefill_cache_kernel( - CaChe, + cos_cache, + sin_cache, cumsum_lengths, - output, + cos_output, + sin_output, cache_stride, hidden_stride, total_length, @@ -22,15 +24,31 @@ def prefill_cache_kernel( # original seq_idx and pos cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS)) ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0)) - _cache = tl.load(CaChe + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride) - tl.store(output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * 
hidden_stride, _cache, mask=idx < total_length) + cos_cache_part = tl.load( + cos_cache + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, mask=idx < total_length + ) + sin_cache_part = tl.load( + sin_cache + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, mask=idx < total_length + ) + tl.store( + cos_output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, + cos_cache_part, + mask=idx < total_length, + ) + tl.store( + sin_output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, + sin_cache_part, + mask=idx < total_length, + ) @triton.jit def decoding_cache_kernel( - CaChe, + cos_cache, + sin_cache, lengths, - output, + cos_output, + sin_output, cache_stride, hidden_stride, HIDDEN_DIM: tl.constexpr, @@ -39,16 +57,28 @@ def decoding_cache_kernel( ): idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) ori_seq_idx = tl.load(lengths + idx, mask=(idx < NUM_SEQS), other=None) # [BLOCK_SIZE,] - _cache = tl.load(CaChe + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride) + cos_cache_part = tl.load( + cos_cache + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride, + mask=idx[:, None] < NUM_SEQS, + ) + sin_cache_part = tl.load( + sin_cache + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride, + mask=idx[:, None] < NUM_SEQS, + ) tl.store( - output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride), - _cache, + cos_output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride), + cos_cache_part, + mask=idx[:, None] < NUM_SEQS, + ) + tl.store( + sin_output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride), + sin_cache_part, mask=idx[:, None] < NUM_SEQS, ) @torch.no_grad() -def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool = False): +def get_xine_cache(lengths: torch.Tensor, cos_cache: torch.Tensor, sin_cache: torch.Tensor, is_prompts: bool = False): """ Transform cos/sin cache into no pad sequence, with two different modes. Args: @@ -60,28 +90,33 @@ def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool For decoding mode: cos/sin cache is only needed for the last token. 
""" - - _, hidden_dim = cache.shape + assert cos_cache.shape[1] == sin_cache.shape[1] + _, hidden_dim = cos_cache.shape num_seqs = lengths.numel() - BLOCK_SIZE = 16 - if hidden_dim >= 128: + if hidden_dim >= 256: + num_warps = 16 + elif hidden_dim >= 128: num_warps = 8 else: num_warps = 4 - cache_stride = cache.stride(0) - hidden_stride = cache.stride(1) + cache_stride = cos_cache.stride(0) + hidden_stride = cos_cache.stride(1) if is_prompts: + BLOCK_SIZE = 16 total_length = lengths.sum().item() cumsum_lens = torch.cumsum(lengths, dim=0) - output = torch.empty((total_length, hidden_dim), dtype=cache.dtype, device=cache.device) + cos_output = torch.empty((total_length, hidden_dim), dtype=cos_cache.dtype, device=cos_cache.device) + sin_output = torch.empty((total_length, hidden_dim), dtype=sin_cache.dtype, device=sin_cache.device) grid = (triton.cdiv(total_length, BLOCK_SIZE), BLOCK_SIZE) prefill_cache_kernel[grid]( - cache, + cos_cache, + sin_cache, cumsum_lens, - output, + cos_output, + sin_output, cache_stride, hidden_stride, total_length, @@ -91,14 +126,17 @@ def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool num_warps=num_warps, ) else: - # BUG: get memory access error whe using a deepcopy lengths to replace lengths + BLOCK_SIZE = 4 nlengths = torch.as_tensor(lengths) - 1 - output = torch.empty((num_seqs, hidden_dim), dtype=cache.dtype, device=cache.device) + cos_output = torch.empty((num_seqs, hidden_dim), dtype=cos_cache.dtype, device=cos_cache.device) + sin_output = torch.empty((num_seqs, hidden_dim), dtype=sin_cache.dtype, device=sin_cache.device) grid = (triton.cdiv(num_seqs, BLOCK_SIZE),) decoding_cache_kernel[grid]( - cache, + cos_cache, + sin_cache, nlengths, - output, + cos_output, + sin_output, cache_stride, hidden_stride, HIDDEN_DIM=hidden_dim, @@ -107,4 +145,4 @@ def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool num_warps=num_warps, ) - return output + return cos_output, sin_output diff --git a/tests/test_infer_ops/triton/test_xine_copy.py b/tests/test_infer_ops/triton/test_xine_copy.py index 0e63a7012..da2720659 100644 --- a/tests/test_infer_ops/triton/test_xine_copy.py +++ b/tests/test_infer_ops/triton/test_xine_copy.py @@ -39,8 +39,8 @@ configs = [ x_names=["max_num_tokens"], x_vals=[2**i for i in range(6, 12)], line_arg="provider", - line_vals=["torch_get_cos_sin_func", "triton_get_xine_func"], - line_names=["torch_get_cos_sin_func", "triton_get_xine_func"], + line_vals=["torch_get_cos_sin", "triton_get_cos_sin"], + line_names=["torch_get_cos_sin", "triton_get_cos_sin"], styles=[("red", "-"), ("blue", "-")], ylabel="ms", plot_name="Get_cos-sin_func", @@ -58,19 +58,15 @@ def benchmark_get_xine_cache( ): warmup = 10 rep = 1000 - max_token_per_seq = max_num_tokens // batch_size dtype = torch.float16 - cos_cache = torch.randn((max_num_tokens, head_dim), dtype=dtype, device="cuda") - sin_cache = torch.randn((max_num_tokens, head_dim), dtype=dtype, device="cuda") - lengths = torch.randint(2, max_token_per_seq, (batch_size,), device="cuda") + cos_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda") + sin_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda") + lengths = torch.randint(2, max_num_tokens, (batch_size,), device="cuda") - if provider == "torch_get_cos_sin_func": - fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype) - elif provider == "triton_get_xine_func": - fn = lambda: [ - get_xine_cache(lengths, cos_cache, is_prompts=False), - 
get_xine_cache(lengths, sin_cache, is_prompts=False), - ] + if provider == "torch_get_cos_sin": + fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype) + elif provider == "triton_get_cos_sin": + fn = lambda: get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True) else: raise ValueError("Undefined provider")
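Usage sketch for the updated get_xine_cache, which now takes both caches and returns a (cos_output, sin_output) pair, mirroring the benchmark above. The import path and tensor sizes here are assumptions for illustration only:

import torch
from colossalai.kernel.triton import get_xine_cache  # assumed import path

dtype = torch.float16
cos_cache = torch.randn((4096, 128), dtype=dtype, device="cuda")  # (max_positions, head_dim); illustrative sizes
sin_cache = torch.randn((4096, 128), dtype=dtype, device="cuda")
lengths = torch.randint(2, 1024, (4,), device="cuda")  # per-sequence token counts

# Prefill: one cos/sin row per token of every sequence -> (lengths.sum(), head_dim)
cos_p, sin_p = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)

# Decoding: one cos/sin row per sequence, read at position (length - 1) -> (num_seqs, head_dim)
cos_d, sin_d = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=False)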
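For reference, the rotation written back by rotary_embedding_kernel is the standard half-rotation (out0 = x0*cos - x1*sin, out1 = x0*sin + x1*cos). A minimal PyTorch sketch for sanity-checking the in-place Triton result, assuming the wrapper is invoked as rotary_embedding(q, k, cos, sin) with q/k of shape (total_tokens, num_heads, head_dim) and cos/sin of shape (total_tokens, head_dim // 2):

import torch

def rotary_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (total_tokens, num_heads, head_dim); cos/sin: (total_tokens, head_dim // 2)
    x0, x1 = x.chunk(2, dim=-1)  # first/second half of the head dim, as in dim_range0/dim_range1
    out0 = x0 * cos[:, None, :] - x1 * sin[:, None, :]
    out1 = x0 * sin[:, None, :] + x1 * cos[:, None, :]
    return torch.cat((out0, out1), dim=-1)  # same layout the kernel stores back

# q_ref = rotary_reference(q.float(), cos.float(), sin.float())
# rotary_embedding(q, k, cos, sin)  # assumed wrapper signature; rotates q and k in place
# assert torch.allclose(q.float(), q_ref, atol=1e-3, rtol=1e-3)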