diff --git a/LICENSE b/LICENSE index c7a5bb168..06629068f 100644 --- a/LICENSE +++ b/LICENSE @@ -396,3 +396,35 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------- LICENSE FOR VLLM TEAM ---------------- + + from VLLM TEAM: + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://github.com/vllm-project/vllm/blob/main/LICENSE + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ---------------- LICENSE FOR LIGHTLLM TEAM ---------------- + + from LIGHTLLM TEAM: + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://github.com/ModelTC/lightllm/blob/main/LICENSE + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py new file mode 100644 index 000000000..38db2048c --- /dev/null +++ b/colossalai/kernel/triton/context_attention.py @@ -0,0 +1,184 @@ +import torch +import math +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + + +if HAS_TRITON: + ''' + this function is modified from + https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 + ''' + @triton.jit + def _context_flash_attention_kernel( + Q, K, V, sm_scale, + B_Start_Loc, B_Seqlen, + TMP, + alibi_ptr, + Out, + stride_qbs, stride_qh, stride_qd, + stride_kbs, stride_kh, stride_kd, + stride_vbs, stride_vh, stride_vd, + stride_obs, stride_oh, stride_od, + stride_tmp_b, stride_tmp_h, stride_tmp_s, + # suggtest set-up 64, 128, 256, 512 + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + + batch_id = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + # get batch info + cur_batch_seq_len = tl.load(B_Seqlen + batch_id) + cur_batch_start_index = tl.load(B_Start_Loc + batch_id) + block_start_loc = BLOCK_M * start_m + + load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + if alibi_ptr is not None: + alibi_m = tl.load(alibi_ptr + cur_head) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + if alibi_ptr is not None: + alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) + qk -= alibi_loc * alibi_m + + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_o = (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + + @torch.no_grad() + def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk, "context process only supports equal query, key, value length" + assert Lk == Lv, "context process only supports equal query, key, value length" + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / math.sqrt(Lk) + batch, head = b_seq_len.shape[0], q.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + + num_warps = 4 if Lk <= 64 else 8 + + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) + + _context_flash_attention_kernel[grid]( + q, k, v, sm_scale, + b_start_loc, b_seq_len, + tmp, + alibi, + o, + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + o.stride(0), o.stride(1), o.stride(2), + tmp.stride(0), tmp.stride(1), tmp.stride(2), + # manually setting this blcok num, we can use tuning config to futher speed-up + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @torch.no_grad() + def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk, "context process only supports equal query, key, value length" + assert Lk == Lv, "context process only supports equal query, key, value length" + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / math.sqrt(Lk) + batch, head = b_seq_len.shape[0], q.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + # num_warps = 4 + _context_flash_attention_kernel[grid]( + q, k, v, sm_scale, b_start_loc, b_seq_len, + tmp, + None, + o, + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + o.stride(0), o.stride(1), o.stride(2), + tmp.stride(0), tmp.stride(1), tmp.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return \ No newline at end of file diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py new file mode 100644 index 000000000..c1eaa8a10 --- /dev/null +++ b/colossalai/kernel/triton/copy_kv_cache_dest.py @@ -0,0 +1,69 @@ +import torch + +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + @triton.jit + def _fwd_copy_kv_cache_dest( + kv_cache_ptr, dest_index_ptr, + out, + stride_k_bs, + stride_k_h, + stride_k_d, + stride_o_bs, + stride_o_h, + stride_o_d, + head_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_HEAD: tl.constexpr + ): + cur_index = tl.program_id(0) + offs_h = tl.arange(0, BLOCK_HEAD) + offs_d = tl.arange(0, BLOCK_DMODEL) + + dest_index = tl.load(dest_index_ptr + cur_index) + + cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :] + k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets + + o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] + o_ptrs = out + dest_index * stride_o_bs + o_offsets + + k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0) + tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num) + return + + + @torch.no_grad() + def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): + seq_len = dest_index_ptr.shape[0] + head_num = k_ptr.shape[1] + head_dim = k_ptr.shape[2] + assert head_num == out.shape[1], "head_num should be the same for k_ptr and out" + assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out" + + num_warps = 2 + + _fwd_copy_kv_cache_dest[(seq_len,)]( + k_ptr, dest_index_ptr, out, + k_ptr.stride(0), + k_ptr.stride(1), + k_ptr.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + head_num, + BLOCK_DMODEL=head_dim, + BLOCK_HEAD=triton.next_power_of_2(head_num), + num_warps=num_warps, + num_stages=2, + ) + return + + diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/self_attention_nofusion.py similarity index 74% rename from colossalai/kernel/triton/ops.py rename to colossalai/kernel/triton/self_attention_nofusion.py index 5e8d4ba3e..a6c9bdfbd 100644 --- a/colossalai/kernel/triton/ops.py +++ b/colossalai/kernel/triton/self_attention_nofusion.py @@ -11,7 +11,7 @@ except ImportError: if HAS_TRITON: from .qkv_matmul_kernel import qkv_gemm_4d_kernel - from .softmax_kernel import softmax_kernel + from .softmax import softmax_kernel def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float): r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels @@ -155,55 +155,4 @@ if HAS_TRITON: data_output_triton = self_attention_forward_without_fusion( q, k, v, input_mask, scale) - return data_output_triton - - - def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor: - if mask is not None: - assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask" - assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention" - - hidden_dim = input.shape[-1] - output = torch.empty_like(input) - input = input.view(-1, hidden_dim) - if mask is not None: - mask = mask.view(-1, hidden_dim) - assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same" - - num_rows, num_cols = input.shape - block_size = max(triton.next_power_of_2(num_cols), 2) - num_warps = 16 - if block_size >= 4096: - num_warps = 16 - elif block_size >= 2048: - num_warps = 8 - else: - num_warps = 4 - - if num_rows <= 350000: - grid = (num_rows,) - softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps) - else: - grid = lambda meta: () - - grid = lambda meta: ( - triton.cdiv(num_rows, meta["BLOCK_M"]), - ) - - BLOCK_M = 32 - if block_size >= 4096: - BLOCK_M = 4 - elif block_size >= 2048: - BLOCK_M = 8 - - softmax_kernel_2[grid](output_ptr = output, - input_ptr = input, - row_stride = input.stride(0), - n_rows = num_rows, - n_cols = num_cols, - mask_ptr = mask, - # currently manually setting up size - BLOCK_M = 32, - BLOCK_SIZE = block_size) - - return output \ No newline at end of file + return data_output_triton \ No newline at end of file diff --git a/colossalai/kernel/triton/softmax.py b/colossalai/kernel/triton/softmax.py new file mode 100644 index 000000000..c65adaf40 --- /dev/null +++ b/colossalai/kernel/triton/softmax.py @@ -0,0 +1,96 @@ +import torch +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + ''' + softmax kernel is modified based on + https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py + ''' + @triton.jit + def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr): + r""" the kernel function for implementing softmax operator + Args: + output_ptr: the output after finishing softmax operation, (N, hidden_dim) + input_ptr: the tensor of input, shape should be (N, hidden_dim) + n_cols(tl.constexpr): the number of cols of input + BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim + """ + row_idx = tl.program_id(0) + row_start_ptr = input_ptr + row_idx * row_stride + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32) + row_minus_max = row - tl.max(row, axis=0) + + if mask_ptr is not None: + # load mask into SRAM + mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets + mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32) + + # update + row_minus_max = row_minus_max + mask + + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + output_row_start_ptr = output_ptr + row_idx * row_stride + output_ptrs = output_row_start_ptr + col_offsets + # Write back output to DRAM + tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) + + + def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor: + if mask is not None: + assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask" + assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention" + + hidden_dim = input.shape[-1] + output = torch.empty_like(input) + input = input.view(-1, hidden_dim) + if mask is not None: + mask = mask.view(-1, hidden_dim) + assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same" + + num_rows, num_cols = input.shape + block_size = max(triton.next_power_of_2(num_cols), 2) + num_warps = 16 + if block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + else: + num_warps = 4 + + if num_rows <= 350000: + grid = (num_rows,) + softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps) + else: + grid = lambda meta: () + + grid = lambda meta: ( + triton.cdiv(num_rows, meta["BLOCK_M"]), + ) + + BLOCK_M = 32 + if block_size >= 4096: + BLOCK_M = 4 + elif block_size >= 2048: + BLOCK_M = 8 + + softmax_kernel[grid](output_ptr = output, + input_ptr = input, + row_stride = input.stride(0), + n_rows = num_rows, + n_cols = num_cols, + mask_ptr = mask, + # currently manually setting up size + BLOCK_M = 32, + BLOCK_SIZE = block_size) + + return output \ No newline at end of file diff --git a/colossalai/kernel/triton/softmax_kernel.py b/colossalai/kernel/triton/softmax_kernel.py deleted file mode 100644 index c215890ba..000000000 --- a/colossalai/kernel/triton/softmax_kernel.py +++ /dev/null @@ -1,44 +0,0 @@ -try: - import triton - import triton.language as tl - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -if HAS_TRITON: - ''' - softmax kernel is modified based on - https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py - ''' - @triton.jit - def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr): - r""" the kernel function for implementing softmax operator - Args: - output_ptr: the output after finishing softmax operation, (N, hidden_dim) - input_ptr: the tensor of input, shape should be (N, hidden_dim) - n_cols(tl.constexpr): the number of cols of input - BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim - """ - row_idx = tl.program_id(0) - row_start_ptr = input_ptr + row_idx * row_stride - col_offsets = tl.arange(0, BLOCK_SIZE) - input_ptrs = row_start_ptr + col_offsets - row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32) - row_minus_max = row - tl.max(row, axis=0) - - if mask_ptr is not None: - # load mask into SRAM - mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets - mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32) - - # update - row_minus_max = row_minus_max + mask - - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - output_row_start_ptr = output_ptr + row_idx * row_stride - output_ptrs = output_row_start_ptr + col_offsets - # Write back output to DRAM - tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) \ No newline at end of file diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 57afce70e..a18d700f9 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -7,7 +7,7 @@ from transformers.modeling_outputs import ( CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaRMSNorm from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -472,8 +472,18 @@ class LlamaInferenceForwards: def get_llama_flash_attention_forward(): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + + try: + from vllm import pos_encoding_ops + rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox + HAS_VLLM_KERNERL = True + except: + print("fall back to original rotary_embedding_neox of huggingface") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + print("if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch") + HAS_VLLM_KERNERL = False + def forward( self: LlamaAttention, @@ -496,7 +506,12 @@ def get_llama_flash_attention_forward(): kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if HAS_VLLM_KERNERL: + cos_sin_cache = torch.cat((cos, sin), dim=-1) + rotary_embedding_neox(position_ids, query_states, key_states, self.head_dim, cos_sin_cache) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # reuse k, v, self_attention @@ -531,3 +546,32 @@ def get_llama_flash_attention_forward(): return attn_output, None, past_key_value return forward + + +def get_llama_vllm_rmsnorm_forward(): + try: + from vllm import layernorm_ops + rms_norm = layernorm_ops.rms_norm + HAS_VLLM_KERNERL = True + except: + print("please install vllm kernels to install rmsnorm") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + print("if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch") + HAS_VLLM_KERNERL = False + + if HAS_VLLM_KERNERL: + def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + x = hidden_states + out = torch.empty_like(x) + rms_norm( + out, + x, + self.weight.data, + self.variance_epsilon, + ) + + return out + + return _vllm_rmsnorm_forward + else: + return None diff --git a/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py b/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py new file mode 100644 index 000000000..cb12faf62 --- /dev/null +++ b/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import os +import pytest +import numpy as np +from packaging import version + +import torch +from torch import nn +from torch.nn import functional as F + +try: + from vllm import layernorm_ops + rms_norm = layernorm_ops.rms_norm + HAS_VLLM_KERNERL = True +except: + print("please install vllm kernels to install rmsnorm") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + HAS_VLLM_KERNERL = False + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon): + x = hidden_states + out = torch.empty_like(x) + rms_norm( + out, + x, + weight, + variance_epsilon, + ) + return out + +@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") +def test_rmsnorm(): + data = torch.randn((1024, 64), dtype=torch.float16, device="cuda") + hg_rms = LlamaRMSNorm(64) + hg_rms = hg_rms.half().cuda() + out_torch = hg_rms(data) + out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon) + + check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5) + assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward" + +if __name__ == "__main__": + test_rmsnorm() \ No newline at end of file diff --git a/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py b/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py new file mode 100644 index 000000000..2a85566c6 --- /dev/null +++ b/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import pytest +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half + +try: + from vllm import pos_encoding_ops + rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox + HAS_VLLM_KERNERL = True +except: + print("fall back to original rotary_embedding_neox of huggingface") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + HAS_VLLM_KERNERL = False + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class RefRotaryEmbeddingNeox(nn.Module): + """Reference implementation of the GPT-NeoX style rotary embedding.""" + + def __init__( + self, + dim: int, + max_position_embeddings: int = 2048, + base: int = 10000, + ) -> None: + super().__init__() + self.rotary_dim = dim + self.max_position_embeddings = max_position_embeddings + + # Create cos and sin embeddings. + inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim)) + t = torch.arange(max_position_embeddings).float() + freqs = torch.einsum("i,j->ij", t, inv_freq.float()) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos().to(dtype=inv_freq.dtype) + sin = emb.sin().to(dtype=inv_freq.dtype) + self.register_buffer("cos_cached", cos, persistent=False) + self.register_buffer("sin_cached", sin, persistent=False) + + def forward( + self, + positions: torch.Tensor, # [num_tokens] + query: torch.Tensor, # [num_tokens, num_heads, head_size] + key: torch.Tensor, # [num_tokens, num_heads, head_size] + ) -> Tuple[torch.Tensor, torch.Tensor]: + + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + + query_rot = query_rot.transpose(0, 1) + key_rot = key_rot.transpose(0, 1) + cos = F.embedding(positions, self.cos_cached) + sin = F.embedding(positions, self.sin_cached) + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + query_rot = query_rot.transpose(0, 1).contiguous() + key_rot = key_rot.transpose(0, 1).contiguous() + + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + + # Output query/key shape: [num_tokens, num_tokens, head_size] + return query, key + +def run_rotary_embedding_neox( + num_tokens: int, + num_heads: int, + head_size: int, + max_position: int, + rotary_dim: int, + dtype: torch.dtype, + base: int = 10000, +) -> None: + positions = torch.randint(0, max_position, (num_tokens, ), device='cuda') + query = torch.randn(num_tokens, + num_heads * head_size, + dtype=dtype, + device='cuda') + key = torch.randn(num_tokens, + num_heads * head_size, + dtype=dtype, + device='cuda') + + # Create the rotary embedding. + inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim)) + t = torch.arange(max_position).float() + freqs = torch.einsum('i,j -> ij', t, inv_freq.float()) + cos = freqs.cos() + sin = freqs.sin() + cos_sin_cache = torch.cat((cos, sin), dim=-1) + cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda') + + # Run the kernel. The kernel is in-place, so we need to clone the inputs. + out_query = query.clone() + out_key = key.clone() + rotary_embedding_neox( + positions, + out_query, + out_key, + head_size, + cos_sin_cache, + ) + + # Run the reference implementation. + ref_rotary_embedding = RefRotaryEmbeddingNeox( + dim=rotary_dim, + max_position_embeddings=max_position, + base=base, + ).to(dtype=dtype, device='cuda') + ref_query, ref_key = ref_rotary_embedding( + positions, + query.view(num_tokens, num_heads, head_size), + key.view(num_tokens, num_heads, head_size), + ) + ref_query = ref_query.view(num_tokens, num_heads * head_size) + ref_key = ref_key.view(num_tokens, num_heads * head_size) + + # Compare the results. + assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5) + assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5) + +@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") +def test_rotary_embedding(): + run_rotary_embedding_neox( + num_tokens=1024, + num_heads=8, + head_size=64, + max_position=8192, + rotary_dim=64, + dtype=torch.float16, + ) + +if __name__ == "__main__": + test_rotary_embedding() \ No newline at end of file diff --git a/tests/test_infer_ops/triton/test_bloom_context_attention.py b/tests/test_infer_ops/triton/test_bloom_context_attention.py new file mode 100644 index 000000000..6c10ee3ff --- /dev/null +++ b/tests/test_infer_ops/triton/test_bloom_context_attention.py @@ -0,0 +1,57 @@ +import pytest +import math +from packaging import version + +import torch +from torch import nn +from torch.nn import functional as F + +try: + import triton + import triton.language as tl + from tests.test_kernels.triton.utils import benchmark, torch_context_attention + from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") +def test_bloom_context_attention(): + bs = 4 + head_num = 8 + seq_len = 1024 + head_dim = 64 + + query = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + k = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + v = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + + + max_input_len = seq_len + b_start = torch.zeros((bs, ), device="cuda", dtype=torch.int32) + b_len = torch.zeros((bs, ), device="cuda", dtype=torch.int32) + + for i in range(bs): + b_start[i] = i * seq_len + b_len[i] = seq_len + + o = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda") + bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi) + + torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) + + assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched" + + latency_1 = benchmark(bloom_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len, alibi) + latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim) + + print("the triton op latency is {} ms".format(str(latency_1))) + print("the torch op latency is {} ms".format(str(latency_2))) + + +if __name__ == "__main__": + test_bloom_context_attention() \ No newline at end of file diff --git a/tests/test_infer_ops/triton/test_copy_kv_dest.py b/tests/test_infer_ops/triton/test_copy_kv_dest.py new file mode 100644 index 000000000..068295a0e --- /dev/null +++ b/tests/test_infer_ops/triton/test_copy_kv_dest.py @@ -0,0 +1,41 @@ +import pytest +from packaging import version + +import torch +from torch import nn + +try: + import triton + import triton.language as tl + from tests.test_kernels.triton.utils import benchmark + from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") +def test_kv_cache_copy_op(): + + B_NTX = 32 * 2048 + head_num = 8 + head_dim = 64 + + cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16) + dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32) + + dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16) + + copy_kv_cache_to_dest(cache, dest_index, dest_data) + + assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch are not matched" + + latency = benchmark(copy_kv_cache_to_dest, cache, dest_index, dest_data) + print("the average latency is {} ms".format(str(latency))) + + +if __name__ == "__main__": + test_kv_cache_copy_op() + diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py new file mode 100644 index 000000000..04d081408 --- /dev/null +++ b/tests/test_infer_ops/triton/test_llama_context_attention.py @@ -0,0 +1,57 @@ +import pytest +import math +from packaging import version + +import torch +from torch import nn +from torch.nn import functional as F + +try: + import triton + import triton.language as tl + from tests.test_kernels.triton.utils import benchmark, torch_context_attention + from colossalai.kernel.triton.context_attention import llama_context_attn_fwd + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") +def test_llama_context_attention(): + bs = 4 + head_num = 8 + seq_len = 1024 + head_dim = 64 + + query = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + k = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + v = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + + + max_input_len = seq_len + b_start = torch.zeros((bs, ), device="cuda", dtype=torch.int32) + b_len = torch.zeros((bs, ), device="cuda", dtype=torch.int32) + + for i in range(bs): + b_start[i] = i * seq_len + b_len[i] = seq_len + + o = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len) + + torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) + + assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched" + + latency_1 = benchmark(llama_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len) + latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim) + + print("the triton op latency is {} ms".format(str(latency_1))) + print("the torch op latency is {} ms".format(str(latency_2))) + + +if __name__ == "__main__": + test_llama_context_attention() \ No newline at end of file diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_infer_ops/triton/test_self_attention_nonfusion.py similarity index 91% rename from tests/test_kernels/test_self_attention.py rename to tests/test_infer_ops/triton/test_self_attention_nonfusion.py index b316404a5..9692737a0 100644 --- a/tests/test_kernels/test_self_attention.py +++ b/tests/test_infer_ops/triton/test_self_attention_nonfusion.py @@ -4,12 +4,11 @@ import torch from torch import nn import torch.nn.functional as F -from colossalai.kernel.triton.ops import self_attention_compute_using_triton -from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel - try: import triton import triton.language as tl + from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton + from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -17,7 +16,7 @@ except ImportError: TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") def test_qkv_matmul(): qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) scale = 1.2 @@ -106,7 +105,7 @@ def self_attention_compute_using_torch(qkv, return res.view(batches, -1, d_model), score_output, softmax_output -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") def test_self_atttention_test(): qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) diff --git a/tests/test_kernels/test_softmax.py b/tests/test_infer_ops/triton/test_softmax.py similarity index 68% rename from tests/test_kernels/test_softmax.py rename to tests/test_infer_ops/triton/test_softmax.py index 843d811d0..6a244608c 100644 --- a/tests/test_kernels/test_softmax.py +++ b/tests/test_infer_ops/triton/test_softmax.py @@ -3,11 +3,19 @@ from packaging import version import torch from torch import nn -from colossalai.kernel.triton.ops import softmax + +try: + import triton + import triton.language as tl + from colossalai.kernel.triton.softmax import softmax + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") def test_softmax_op(): data_samples = [ torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32), diff --git a/tests/test_infer_ops/triton/utils.py b/tests/test_infer_ops/triton/utils.py new file mode 100644 index 000000000..940d277cf --- /dev/null +++ b/tests/test_infer_ops/triton/utils.py @@ -0,0 +1,50 @@ +import numpy as np +import math + +import torch +from torch.nn import functional as F + + +def benchmark(func, *args): + starter, ender = torch.cuda.Event( + enable_timing=True), torch.cuda.Event(enable_timing=True) + repetitions = 300 + + for i in range(10): + func(*args) + + timings = np.zeros((repetitions, 1)) + with torch.no_grad(): + for rep in range(repetitions): + starter.record() + func(*args) + ender.record() + # WAIT FOR GPU SYNC + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + timings[rep] = curr_time + + mean_syn = np.sum(timings) / repetitions + return mean_syn + +def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): + ''' + adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 + ''' + xq = xq.view(bs, seqlen, num_head, head_dim) + xk = xk.view(bs, seqlen, num_head, head_dim) + xv = xv.view(bs, seqlen, num_head, head_dim) + mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() + mask[mask == 0.] = -100000000.0 + mask = mask.repeat(bs, num_head, 1, 1) + keys = xk + values = xv + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + sm_scale = 1/math.sqrt(head_dim) + scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale + scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16) + + output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) + return output \ No newline at end of file