[Kernels] add necessary kernels (llama & bloom) for attention forward and kv-cache manager (#4485)

* added _vllm_rms_norm

* change place

* added tests

* added tests

* modify

* adding kernels

* added tests:

* adding kernels

* modify

* added

* updating kernels

* adding tests

* added tests

* kernel change

* submit

* modify

* added

* edit comments

* change name

* change comments and fix import

* add

* added
pull/4509/head
Cuiqing Li 2023-08-24 16:30:02 +08:00 committed by GitHub
parent 222953a399
commit 7d7ea2ef41
15 changed files with 865 additions and 107 deletions

32
LICENSE
View File

@@ -396,3 +396,35 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
---------------- LICENSE FOR VLLM TEAM ----------------
from VLLM TEAM:
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://github.com/vllm-project/vllm/blob/main/LICENSE
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
---------------- LICENSE FOR LIGHTLLM TEAM ----------------
from LIGHTLLM TEAM:
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://github.com/ModelTC/lightllm/blob/main/LICENSE
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -0,0 +1,184 @@
import torch
import math
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
'''
this function is modified from
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
'''
@triton.jit
def _context_flash_attention_kernel(
Q, K, V, sm_scale,
B_Start_Loc, B_Seqlen,
TMP,
alibi_ptr,
Out,
stride_qbs, stride_qh, stride_qd,
stride_kbs, stride_kh, stride_kd,
stride_vbs, stride_vh, stride_vd,
stride_obs, stride_oh, stride_od,
stride_tmp_b, stride_tmp_h, stride_tmp_s,
        # suggested block sizes: 64, 128, 256, 512
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
batch_id = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# get batch info
cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
block_start_loc = BLOCK_M * start_m
load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
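        # online-softmax (flash-attention style) running statistics:
        # m_i is the running row-wise max of the attention scores, l_i the running
        # normalizer sum(exp(qk - m_i)), and acc the running softmax-weighted sum of V;
        # every iteration of the block loop below renormalizes the previous partial
        # results via alpha = exp(m_i - m_i_new) and beta = exp(m_ij - m_i_new).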
if alibi_ptr is not None:
alibi_m = tl.load(alibi_ptr + cur_head)
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
if alibi_ptr is not None:
alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
qk -= alibi_loc * alibi_m
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
tl.store(t_ptrs, acc_scale)
acc_scale = tl.load(t_ptrs)
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_o = (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
@torch.no_grad()
def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None):
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk, "context attention requires query and key to have the same head dim"
        assert Lk == Lv, "context attention requires key and value to have the same head dim"
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / math.sqrt(Lk)
batch, head = b_seq_len.shape[0], q.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
num_warps = 4 if Lk <= 64 else 8
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
_context_flash_attention_kernel[grid](
q, k, v, sm_scale,
b_start_loc, b_seq_len,
tmp,
alibi,
o,
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
o.stride(0), o.stride(1), o.stride(2),
tmp.stride(0), tmp.stride(1), tmp.stride(2),
            # manually chosen block size; an autotuning config could speed this up further
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
@torch.no_grad()
def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk, "context attention requires query and key to have the same head dim"
        assert Lk == Lv, "context attention requires key and value to have the same head dim"
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / math.sqrt(Lk)
batch, head = b_seq_len.shape[0], q.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
_context_flash_attention_kernel[grid](
q, k, v, sm_scale, b_start_loc, b_seq_len,
tmp,
None,
o,
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
o.stride(0), o.stride(1), o.stride(2),
tmp.stride(0), tmp.stride(1), tmp.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
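A minimal usage sketch of the packed-sequence entry points above, assuming triton and a CUDA device are available; the layout (q/k/v of shape (total_tokens, head_num, head_dim), with per-sequence offsets in b_start_loc and lengths in b_seq_len) mirrors the tests added in this PR:

import torch
from colossalai.kernel.triton.context_attention import llama_context_attn_fwd

bs, seq_len, head_num, head_dim = 2, 128, 8, 64
q = torch.randn(bs * seq_len, head_num, head_dim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
out = torch.empty_like(q)
b_start_loc = torch.arange(0, bs * seq_len, seq_len, device="cuda", dtype=torch.int32)
b_seq_len = torch.full((bs,), seq_len, device="cuda", dtype=torch.int32)
llama_context_attn_fwd(q, k, v, out, b_start_loc, b_seq_len, max_input_len=seq_len)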

View File

@@ -0,0 +1,69 @@
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
@triton.jit
def _fwd_copy_kv_cache_dest(
kv_cache_ptr, dest_index_ptr,
out,
stride_k_bs,
stride_k_h,
stride_k_d,
stride_o_bs,
stride_o_h,
stride_o_d,
head_num,
BLOCK_DMODEL: tl.constexpr,
BLOCK_HEAD: tl.constexpr
):
cur_index = tl.program_id(0)
offs_h = tl.arange(0, BLOCK_HEAD)
offs_d = tl.arange(0, BLOCK_DMODEL)
dest_index = tl.load(dest_index_ptr + cur_index)
cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets
o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
o_ptrs = out + dest_index * stride_o_bs + o_offsets
k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)
tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)
return
@torch.no_grad()
def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out):
seq_len = dest_index_ptr.shape[0]
head_num = k_ptr.shape[1]
head_dim = k_ptr.shape[2]
assert head_num == out.shape[1], "head_num should be the same for k_ptr and out"
assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out"
num_warps = 2
_fwd_copy_kv_cache_dest[(seq_len,)](
k_ptr, dest_index_ptr, out,
k_ptr.stride(0),
k_ptr.stride(1),
k_ptr.stride(2),
out.stride(0),
out.stride(1),
out.stride(2),
head_num,
BLOCK_DMODEL=head_dim,
BLOCK_HEAD=triton.next_power_of_2(head_num),
num_warps=num_warps,
num_stages=2,
)
return
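A short sketch of the intended use in a KV-cache manager: the key or value states produced in the current step are scattered into pre-allocated cache rows selected by dest_index. The shapes below are illustrative and mirror the unit test added in this PR:

import torch
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest

num_tokens, head_num, head_dim = 16, 8, 64
new_k = torch.randn(num_tokens, head_num, head_dim, dtype=torch.float16, device="cuda")
kv_cache = torch.zeros(1024, head_num, head_dim, dtype=torch.float16, device="cuda")
# cache slots assigned to the current tokens by the (external) cache manager
dest_index = torch.arange(100, 100 + num_tokens, device="cuda", dtype=torch.int32)
copy_kv_cache_to_dest(new_k, dest_index, kv_cache)  # kv_cache[dest_index] now holds new_k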

View File

@@ -11,7 +11,7 @@ except ImportError:
if HAS_TRITON:
from .qkv_matmul_kernel import qkv_gemm_4d_kernel
from .softmax_kernel import softmax_kernel
from .softmax import softmax_kernel
def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float):
r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
@@ -156,54 +156,3 @@ if HAS_TRITON:
q, k, v, input_mask, scale)
return data_output_triton
def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
if mask is not None:
assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask"
assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention"
hidden_dim = input.shape[-1]
output = torch.empty_like(input)
input = input.view(-1, hidden_dim)
if mask is not None:
mask = mask.view(-1, hidden_dim)
assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same"
num_rows, num_cols = input.shape
block_size = max(triton.next_power_of_2(num_cols), 2)
num_warps = 16
if block_size >= 4096:
num_warps = 16
elif block_size >= 2048:
num_warps = 8
else:
num_warps = 4
if num_rows <= 350000:
grid = (num_rows,)
softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps)
else:
grid = lambda meta: ()
grid = lambda meta: (
triton.cdiv(num_rows, meta["BLOCK_M"]),
)
BLOCK_M = 32
if block_size >= 4096:
BLOCK_M = 4
elif block_size >= 2048:
BLOCK_M = 8
softmax_kernel_2[grid](output_ptr = output,
input_ptr = input,
row_stride = input.stride(0),
n_rows = num_rows,
n_cols = num_cols,
mask_ptr = mask,
# currently manually setting up size
BLOCK_M = 32,
BLOCK_SIZE = block_size)
return output

View File

@@ -0,0 +1,96 @@
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
'''
softmax kernel is modified based on
https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
'''
@triton.jit
def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
r""" the kernel function for implementing softmax operator
Args:
output_ptr: the output after finishing softmax operation, (N, hidden_dim)
input_ptr: the tensor of input, shape should be (N, hidden_dim)
n_cols(tl.constexpr): the number of cols of input
BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
"""
row_idx = tl.program_id(0)
row_start_ptr = input_ptr + row_idx * row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
row_minus_max = row - tl.max(row, axis=0)
if mask_ptr is not None:
# load mask into SRAM
            mask_ptrs = mask_ptr + row_idx * row_stride + col_offsets
mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
# update
row_minus_max = row_minus_max + mask
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
output_row_start_ptr = output_ptr + row_idx * row_stride
output_ptrs = output_row_start_ptr + col_offsets
# Write back output to DRAM
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
if mask is not None:
            assert input.shape[-1] == mask.shape[-1], "the last dimensions of input and mask should be the same"
        assert dim == -1 or dim == len(input.shape) - 1, "currently the softmax layer only supports the last dimension"
hidden_dim = input.shape[-1]
output = torch.empty_like(input)
input = input.view(-1, hidden_dim)
if mask is not None:
mask = mask.view(-1, hidden_dim)
            assert input.shape[0] == mask.shape[0], "the first dimensions of mask and input should be the same"
num_rows, num_cols = input.shape
block_size = max(triton.next_power_of_2(num_cols), 2)
        if block_size >= 4096:
            num_warps = 16
        elif block_size >= 2048:
            num_warps = 8
        else:
            num_warps = 4
        # one program instance handles one row of the flattened input; a BLOCK_M-tiled
        # variant processing several rows per program could be added for very large row counts
        grid = (num_rows,)
        softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE=block_size, num_warps=num_warps)
return output
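A brief usage sketch of the wrapper above; the additive mask, when provided, is assumed to have the same shape as the input, matching the assertions in the function:

import torch
from colossalai.kernel.triton.softmax import softmax

scores = torch.randn(4, 8, 128, 128, dtype=torch.float16, device="cuda")
mask = torch.zeros_like(scores)
mask[..., 64:] = -10000.0               # additive mask: large negative values suppress positions
probs = softmax(scores, mask, dim=-1)   # same shape as scores, rows sum to 1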

View File

@@ -1,44 +0,0 @@
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
'''
softmax kernel is modified based on
https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
'''
@triton.jit
def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
r""" the kernel function for implementing softmax operator
Args:
output_ptr: the output after finishing softmax operation, (N, hidden_dim)
input_ptr: the tensor of input, shape should be (N, hidden_dim)
n_cols(tl.constexpr): the number of cols of input
BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
"""
row_idx = tl.program_id(0)
row_start_ptr = input_ptr + row_idx * row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
row_minus_max = row - tl.max(row, axis=0)
if mask_ptr is not None:
# load mask into SRAM
mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets
mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
# update
row_minus_max = row_minus_max + mask
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
output_row_start_ptr = output_ptr + row_idx * row_stride
output_ptrs = output_row_start_ptr + col_offsets
# Write back output to DRAM
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)

View File

@@ -7,7 +7,7 @@ from transformers.modeling_outputs import (
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaRMSNorm
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -472,9 +472,19 @@ class LlamaInferenceForwards:
def get_llama_flash_attention_forward():
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
try:
from vllm import pos_encoding_ops
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
    except ImportError:
        print("falling back to the original huggingface rotary_embedding_neox")
        print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
        print("if the vllm installation fails, use this branch instead: https://github.com/tiandiao123/vllm/tree/setup_branch")
HAS_VLLM_KERNERL = False
def forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
@@ -496,6 +506,11 @@ def get_llama_flash_attention_forward():
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
if HAS_VLLM_KERNERL:
cos_sin_cache = torch.cat((cos, sin), dim=-1)
rotary_embedding_neox(position_ids, query_states, key_states, self.head_dim, cos_sin_cache)
else:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
@@ -531,3 +546,32 @@ def get_llama_flash_attention_forward():
return attn_output, None, past_key_value
return forward
def get_llama_vllm_rmsnorm_forward():
try:
from vllm import layernorm_ops
rms_norm = layernorm_ops.rms_norm
HAS_VLLM_KERNERL = True
    except ImportError:
        print("please install the vllm kernels to use the fused rmsnorm forward")
        print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
        print("if the vllm installation fails, use this branch instead: https://github.com/tiandiao123/vllm/tree/setup_branch")
HAS_VLLM_KERNERL = False
if HAS_VLLM_KERNERL:
def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
x = hidden_states
out = torch.empty_like(x)
rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)
return out
return _vllm_rmsnorm_forward
else:
return None
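How these factory functions are attached to a model is not shown in this diff; purely as an illustration, and assuming the vllm kernels are installed, the returned forwards could be bound onto the HuggingFace modules roughly like this:

from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm

LlamaAttention.forward = get_llama_flash_attention_forward()
vllm_rmsnorm_forward = get_llama_vllm_rmsnorm_forward()
if vllm_rmsnorm_forward is not None:    # None when the vllm kernels are unavailable
    LlamaRMSNorm.forward = vllm_rmsnorm_forward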

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import pytest
import numpy as np
from packaging import version
import torch
from torch import nn
from torch.nn import functional as F
try:
from vllm import layernorm_ops
rms_norm = layernorm_ops.rms_norm
HAS_VLLM_KERNERL = True
except ImportError:
    print("please install the vllm kernels to run this test")
    print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
HAS_VLLM_KERNERL = False
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
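        # reference computation: y = weight * x / sqrt(mean(x^2, dim=-1) + eps),
        # with the mean taken in fp32 and the normalized activation cast back to
        # the input dtype before scaling by weight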
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon):
x = hidden_states
out = torch.empty_like(x)
rms_norm(
out,
x,
weight,
variance_epsilon,
)
return out
@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install the vllm CUDA kernels to run this test")
def test_rmsnorm():
data = torch.randn((1024, 64), dtype=torch.float16, device="cuda")
hg_rms = LlamaRMSNorm(64)
hg_rms = hg_rms.half().cuda()
out_torch = hg_rms(data)
out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon)
check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5)
    assert check, "the cuda rmsnorm forward does not match the torch rmsnorm forward"
if __name__ == "__main__":
test_rmsnorm()

View File

@@ -0,0 +1,156 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pytest
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from vllm import pos_encoding_ops
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
except ImportError:
    print("falling back to the original huggingface rotary_embedding_neox")
    print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
HAS_VLLM_KERNERL = False
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
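# NeoX-style rotary embedding as defined above: with x = [x1, x2] split along the
# last dimension, rotate_half(x) = [-x2, x1] and x_embed = x * cos + rotate_half(x) * sin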
class RefRotaryEmbeddingNeox(nn.Module):
"""Reference implementation of the GPT-NeoX style rotary embedding."""
def __init__(
self,
dim: int,
max_position_embeddings: int = 2048,
base: int = 10000,
) -> None:
super().__init__()
self.rotary_dim = dim
self.max_position_embeddings = max_position_embeddings
# Create cos and sin embeddings.
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
t = torch.arange(max_position_embeddings).float()
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype=inv_freq.dtype)
sin = emb.sin().to(dtype=inv_freq.dtype)
self.register_buffer("cos_cached", cos, persistent=False)
self.register_buffer("sin_cached", sin, persistent=False)
def forward(
self,
positions: torch.Tensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads, head_size]
key: torch.Tensor, # [num_tokens, num_heads, head_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
query_rot = query_rot.transpose(0, 1)
key_rot = key_rot.transpose(0, 1)
cos = F.embedding(positions, self.cos_cached)
sin = F.embedding(positions, self.sin_cached)
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
query_rot = query_rot.transpose(0, 1).contiguous()
key_rot = key_rot.transpose(0, 1).contiguous()
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
        # Output query/key shape: [num_tokens, num_heads, head_size]
return query, key
def run_rotary_embedding_neox(
num_tokens: int,
num_heads: int,
head_size: int,
max_position: int,
rotary_dim: int,
dtype: torch.dtype,
base: int = 10000,
) -> None:
positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
query = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device='cuda')
key = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device='cuda')
# Create the rotary embedding.
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
cos = freqs.cos()
sin = freqs.sin()
cos_sin_cache = torch.cat((cos, sin), dim=-1)
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
out_query = query.clone()
out_key = key.clone()
rotary_embedding_neox(
positions,
out_query,
out_key,
head_size,
cos_sin_cache,
)
# Run the reference implementation.
ref_rotary_embedding = RefRotaryEmbeddingNeox(
dim=rotary_dim,
max_position_embeddings=max_position,
base=base,
).to(dtype=dtype, device='cuda')
ref_query, ref_key = ref_rotary_embedding(
positions,
query.view(num_tokens, num_heads, head_size),
key.view(num_tokens, num_heads, head_size),
)
ref_query = ref_query.view(num_tokens, num_heads * head_size)
ref_key = ref_key.view(num_tokens, num_heads * head_size)
# Compare the results.
assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install the vllm CUDA kernels to run this test")
def test_rotary_embedding():
run_rotary_embedding_neox(
num_tokens=1024,
num_heads=8,
head_size=64,
max_position=8192,
rotary_dim=64,
dtype=torch.float16,
)
if __name__ == "__main__":
test_rotary_embedding()

View File

@@ -0,0 +1,57 @@
import pytest
import math
from packaging import version
import torch
from torch import nn
from torch.nn import functional as F
try:
import triton
import triton.language as tl
from tests.test_kernels.triton.utils import benchmark, torch_context_attention
from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_bloom_context_attention():
bs = 4
head_num = 8
seq_len = 1024
head_dim = 64
query = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
k = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
v = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
max_input_len = seq_len
b_start = torch.zeros((bs, ), device="cuda", dtype=torch.int32)
b_len = torch.zeros((bs, ), device="cuda", dtype=torch.int32)
for i in range(bs):
b_start[i] = i * seq_len
b_len[i] = seq_len
o = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi)
torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch do not match"
latency_1 = benchmark(bloom_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len, alibi)
latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim)
print("the triton op latency is {} ms".format(str(latency_1)))
print("the torch op latency is {} ms".format(str(latency_2)))
if __name__ == "__main__":
test_bloom_context_attention()

View File

@@ -0,0 +1,41 @@
import pytest
from packaging import version
import torch
from torch import nn
try:
import triton
import triton.language as tl
from tests.test_kernels.triton.utils import benchmark
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_kv_cache_copy_op():
B_NTX = 32 * 2048
head_num = 8
head_dim = 64
cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32)
dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
copy_kv_cache_to_dest(cache, dest_index, dest_data)
    assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch do not match"
latency = benchmark(copy_kv_cache_to_dest, cache, dest_index, dest_data)
print("the average latency is {} ms".format(str(latency)))
if __name__ == "__main__":
test_kv_cache_copy_op()

View File

@@ -0,0 +1,57 @@
import pytest
import math
from packaging import version
import torch
from torch import nn
from torch.nn import functional as F
try:
import triton
import triton.language as tl
from tests.test_kernels.triton.utils import benchmark, torch_context_attention
from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_llama_context_attention():
bs = 4
head_num = 8
seq_len = 1024
head_dim = 64
query = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
k = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
v = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
max_input_len = seq_len
b_start = torch.zeros((bs, ), device="cuda", dtype=torch.int32)
b_len = torch.zeros((bs, ), device="cuda", dtype=torch.int32)
for i in range(bs):
b_start[i] = i * seq_len
b_len[i] = seq_len
o = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)
torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch do not match"
latency_1 = benchmark(llama_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len)
latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim)
print("the triton op latency is {} ms".format(str(latency_1)))
print("the torch op latency is {} ms".format(str(latency_2)))
if __name__ == "__main__":
test_llama_context_attention()

View File

@@ -4,12 +4,11 @@ import torch
from torch import nn
import torch.nn.functional as F
from colossalai.kernel.triton.ops import self_attention_compute_using_triton
from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
try:
import triton
import triton.language as tl
from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton
from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
@@ -17,7 +16,7 @@ except ImportError:
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_qkv_matmul():
qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)
scale = 1.2
@@ -106,7 +105,7 @@ def self_attention_compute_using_torch(qkv,
return res.view(batches, -1, d_model), score_output, softmax_output
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_self_atttention_test():
qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)

View File

@@ -3,11 +3,19 @@ from packaging import version
import torch
from torch import nn
from colossalai.kernel.triton.ops import softmax
try:
import triton
import triton.language as tl
from colossalai.kernel.triton.softmax import softmax
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_softmax_op():
data_samples = [
torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32),

View File

@@ -0,0 +1,50 @@
import numpy as np
import math
import torch
from torch.nn import functional as F
def benchmark(func, *args):
starter, ender = torch.cuda.Event(
enable_timing=True), torch.cuda.Event(enable_timing=True)
repetitions = 300
for i in range(10):
func(*args)
timings = np.zeros((repetitions, 1))
with torch.no_grad():
for rep in range(repetitions):
starter.record()
func(*args)
ender.record()
# WAIT FOR GPU SYNC
torch.cuda.synchronize()
curr_time = starter.elapsed_time(ender)
timings[rep] = curr_time
mean_syn = np.sum(timings) / repetitions
return mean_syn
def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
'''
    adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
'''
xq = xq.view(bs, seqlen, num_head, head_dim)
xk = xk.view(bs, seqlen, num_head, head_dim)
xv = xv.view(bs, seqlen, num_head, head_dim)
mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
mask[mask == 0.] = -100000000.0
mask = mask.repeat(bs, num_head, 1, 1)
keys = xk
values = xv
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
sm_scale = 1/math.sqrt(head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale
scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16)
output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
return output