mirror of https://github.com/hpcaitech/ColossalAI

[Kernels] add necessary kernels (llama & bloom) for attention forward and kv-cache manager (#4485)

* added _vllm_rms_norm
* change place
* added tests
* added tests
* modify
* adding kernels
* added tests
* adding kernels
* modify
* added
* updating kernels
* adding tests
* added tests
* kernel change
* submit
* modify
* added
* edit comments
* change name
* change comments and fix import
* add
* added

pull/4509/head
parent 222953a399
commit 7d7ea2ef41

LICENSE: +32
@@ -396,3 +396,35 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

---------------- LICENSE FOR VLLM TEAM ----------------

from VLLM TEAM:

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/vllm-project/vllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

---------------- LICENSE FOR LIGHTLLM TEAM ----------------

from LIGHTLLM TEAM:

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/ModelTC/lightllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
@@ -0,0 +1,184 @@
import torch
import math

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")


if HAS_TRITON:
    '''
    this function is modified from
    https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
    '''
    @triton.jit
    def _context_flash_attention_kernel(
        Q, K, V, sm_scale,
        B_Start_Loc, B_Seqlen,
        TMP,
        alibi_ptr,
        Out,
        stride_qbs, stride_qh, stride_qd,
        stride_kbs, stride_kh, stride_kd,
        stride_vbs, stride_vh, stride_vd,
        stride_obs, stride_oh, stride_od,
        stride_tmp_b, stride_tmp_h, stride_tmp_s,
        # suggested values: 64, 128, 256, 512
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        batch_id = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)

        # get batch info
        cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
        cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
        block_start_loc = BLOCK_M * start_m

        load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
        q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)

        k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
        v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
        t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s

        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        if alibi_ptr is not None:
            alibi_m = tl.load(alibi_ptr + cur_head)

        block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale

            if alibi_ptr is not None:
                alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
                qk -= alibi_loc * alibi_m

            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))

            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i --
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            tl.store(t_ptrs, acc_scale)
            acc_scale = tl.load(t_ptrs)
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_o = (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
        out_ptrs = Out + off_o
        tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
        return


    @torch.no_grad()
    def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None):
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk, "the head dimensions of q and k must match"
        assert Lk == Lv, "the head dimensions of k and v must match"
        assert Lk in {16, 32, 64, 128}

        sm_scale = 1.0 / math.sqrt(Lk)
        batch, head = b_seq_len.shape[0], q.shape[1]

        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))

        num_warps = 4 if Lk <= 64 else 8

        tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)

        _context_flash_attention_kernel[grid](
            q, k, v, sm_scale,
            b_start_loc, b_seq_len,
            tmp,
            alibi,
            o,
            q.stride(0), q.stride(1), q.stride(2),
            k.stride(0), k.stride(1), k.stride(2),
            v.stride(0), v.stride(1), v.stride(2),
            o.stride(0), o.stride(1), o.stride(2),
            tmp.stride(0), tmp.stride(1), tmp.stride(2),
            # block sizes are set manually here; a tuning config could speed this up further
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_N=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
        return

    @torch.no_grad()
    def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk, "the head dimensions of q and k must match"
        assert Lk == Lv, "the head dimensions of k and v must match"
        assert Lk in {16, 32, 64, 128}

        sm_scale = 1.0 / math.sqrt(Lk)
        batch, head = b_seq_len.shape[0], q.shape[1]

        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))

        tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8
        # num_warps = 4
        _context_flash_attention_kernel[grid](
            q, k, v, sm_scale, b_start_loc, b_seq_len,
            tmp,
            None,
            o,
            q.stride(0), q.stride(1), q.stride(2),
            k.stride(0), k.stride(1), k.stride(2),
            v.stride(0), v.stride(1), v.stride(2),
            o.stride(0), o.stride(1), o.stride(2),
            tmp.stride(0), tmp.stride(1), tmp.stride(2),
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_N=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
        return
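A minimal usage sketch for the packed-sequence layout these entry points expect (the sizes and variable names below are illustrative assumptions; the test files added later in this commit exercise the same call):

import torch
from colossalai.kernel.triton.context_attention import llama_context_attn_fwd

head_num, head_dim = 8, 64                      # head_dim must be one of {16, 32, 64, 128}
seq_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
starts = torch.tensor([0, 3], dtype=torch.int32, device="cuda")   # cumulative start offsets
total = int(seq_lens.sum())                     # sequences are packed along dim 0

q = torch.randn((total, head_num, head_dim), dtype=torch.float16, device="cuda")
k, v, o = torch.randn_like(q), torch.randn_like(q), torch.empty_like(q)

# writes the causal attention output for every packed token into `o` in place
llama_context_attn_fwd(q, k, v, o, starts, seq_lens, int(seq_lens.max()))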
@@ -0,0 +1,69 @@
import torch

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
    @triton.jit
    def _fwd_copy_kv_cache_dest(
        kv_cache_ptr, dest_index_ptr,
        out,
        stride_k_bs,
        stride_k_h,
        stride_k_d,
        stride_o_bs,
        stride_o_h,
        stride_o_d,
        head_num,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_HEAD: tl.constexpr
    ):
        cur_index = tl.program_id(0)
        offs_h = tl.arange(0, BLOCK_HEAD)
        offs_d = tl.arange(0, BLOCK_DMODEL)

        dest_index = tl.load(dest_index_ptr + cur_index)

        cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
        k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets

        o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
        o_ptrs = out + dest_index * stride_o_bs + o_offsets

        k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)
        tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)
        return


    @torch.no_grad()
    def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out):
        seq_len = dest_index_ptr.shape[0]
        head_num = k_ptr.shape[1]
        head_dim = k_ptr.shape[2]
        assert head_num == out.shape[1], "head_num should be the same for k_ptr and out"
        assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out"

        num_warps = 2

        _fwd_copy_kv_cache_dest[(seq_len,)](
            k_ptr, dest_index_ptr, out,
            k_ptr.stride(0),
            k_ptr.stride(1),
            k_ptr.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            head_num,
            BLOCK_DMODEL=head_dim,
            BLOCK_HEAD=triton.next_power_of_2(head_num),
            num_warps=num_warps,
            num_stages=2,
        )
        return
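The kernel above is a token-wise scatter copy; a rough PyTorch equivalent, for intuition only (not how the cache manager invokes it):

# copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out) amounts to
# out[dest_index_ptr[i]] = k_ptr[i] for every token i:
out[dest_index_ptr.long()] = k_ptr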
@@ -11,7 +11,7 @@ except ImportError:

if HAS_TRITON:
    from .qkv_matmul_kernel import qkv_gemm_4d_kernel
    from .softmax_kernel import softmax_kernel
    from .softmax import softmax_kernel

    def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float):
        r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
@@ -156,54 +156,3 @@ if HAS_TRITON:
                q, k, v, input_mask, scale)

        return data_output_triton


    def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
        if mask is not None:
            assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask"
        assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention"

        hidden_dim = input.shape[-1]
        output = torch.empty_like(input)
        input = input.view(-1, hidden_dim)
        if mask is not None:
            mask = mask.view(-1, hidden_dim)
            assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same"

        num_rows, num_cols = input.shape
        block_size = max(triton.next_power_of_2(num_cols), 2)
        num_warps = 16
        if block_size >= 4096:
            num_warps = 16
        elif block_size >= 2048:
            num_warps = 8
        else:
            num_warps = 4

        if num_rows <= 350000:
            grid = (num_rows,)
            softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE=block_size, num_warps=num_warps)
        else:
            grid = lambda meta: ()

            grid = lambda meta: (
                triton.cdiv(num_rows, meta["BLOCK_M"]),
            )

            BLOCK_M = 32
            if block_size >= 4096:
                BLOCK_M = 4
            elif block_size >= 2048:
                BLOCK_M = 8

            softmax_kernel_2[grid](output_ptr=output,
                                   input_ptr=input,
                                   row_stride=input.stride(0),
                                   n_rows=num_rows,
                                   n_cols=num_cols,
                                   mask_ptr=mask,
                                   # currently manually setting up size
                                   BLOCK_M=32,
                                   BLOCK_SIZE=block_size)

        return output
@@ -0,0 +1,96 @@
import torch
try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
    '''
    softmax kernel is modified based on
    https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
    '''
    @triton.jit
    def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
        r""" the kernel function for implementing softmax operator
        Args:
            output_ptr: the output after finishing softmax operation, (N, hidden_dim)
            input_ptr: the tensor of input, shape should be (N, hidden_dim)
            n_cols(tl.constexpr): the number of cols of input
            BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
        """
        row_idx = tl.program_id(0)
        row_start_ptr = input_ptr + row_idx * row_stride
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
        row_minus_max = row - tl.max(row, axis=0)

        if mask_ptr is not None:
            # load the additive mask for this row into SRAM
            mask_ptrs = (mask_ptr + (row_idx * row_stride)) + col_offsets
            mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
            row_minus_max = row_minus_max + mask

        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        output_row_start_ptr = output_ptr + row_idx * row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        # write the result back to DRAM
        tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)


    def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
        if mask is not None:
            assert input.shape[-1] == mask.shape[-1], "the last dimensions of input and mask should be the same"
        assert dim == -1 or dim == len(input.shape) - 1, "currently the softmax layer only supports the last dimension"

        hidden_dim = input.shape[-1]
        output = torch.empty_like(input)
        input = input.view(-1, hidden_dim)
        if mask is not None:
            mask = mask.view(-1, hidden_dim)
            assert input.shape[0] == mask.shape[0], "the first dimensions of mask and input should be the same"

        num_rows, num_cols = input.shape
        block_size = max(triton.next_power_of_2(num_cols), 2)
        if block_size >= 4096:
            num_warps = 16
        elif block_size >= 2048:
            num_warps = 8
        else:
            num_warps = 4

        # launch one program per row; a tiled variant that processes several rows
        # per program would amortize launch overhead for very large numbers of
        # rows, but only the row-wise kernel is defined in this file
        grid = (num_rows,)
        softmax_kernel[grid](output, input, input.stride(0), num_cols, mask,
                             BLOCK_SIZE=block_size, num_warps=num_warps)

        return output
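A short usage sketch of the wrapper above; the mask is additive, so the call should behave like torch.softmax(input + mask, dim=-1) (the shapes here are assumptions):

import torch
from colossalai.kernel.triton.softmax import softmax

x = torch.randn((4, 16, 128), device="cuda", dtype=torch.float16)
mask = torch.zeros_like(x)
mask[..., 64:] = float("-inf")   # additive mask: -inf removes a position
probs = softmax(x, mask)         # ~ torch.softmax(x + mask, dim=-1)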
@@ -1,44 +0,0 @@
try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
    '''
    softmax kernel is modified based on
    https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
    '''
    @triton.jit
    def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
        r""" the kernel function for implementing softmax operator
        Args:
            output_ptr: the output after finishing softmax operation, (N, hidden_dim)
            input_ptr: the tensor of input, shape should be (N, hidden_dim)
            n_cols(tl.constexpr): the number of cols of input
            BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
        """
        row_idx = tl.program_id(0)
        row_start_ptr = input_ptr + row_idx * row_stride
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
        row_minus_max = row - tl.max(row, axis=0)

        if mask_ptr is not None:
            # load mask into SRAM
            mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets
            mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)

            # update
            row_minus_max = row_minus_max + mask

        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        output_row_start_ptr = output_ptr + row_idx * row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        # Write back output to DRAM
        tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
@@ -7,7 +7,7 @@ from transformers.modeling_outputs import (
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaRMSNorm
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -472,9 +472,19 @@ class LlamaInferenceForwards:
def get_llama_flash_attention_forward():

    from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

    try:
        from vllm import pos_encoding_ops
        rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
        HAS_VLLM_KERNERL = True
    except:
        print("fall back to original rotary_embedding_neox of huggingface")
        print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
        print("if you failed to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch")
        HAS_VLLM_KERNERL = False


    def forward(
        self: LlamaAttention,
        hidden_states: torch.Tensor,
@@ -496,6 +506,11 @@ def get_llama_flash_attention_forward():
            kv_seq_len += past_key_value[0].shape[-2]

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        if HAS_VLLM_KERNERL:
            cos_sin_cache = torch.cat((cos, sin), dim=-1)
            rotary_embedding_neox(position_ids, query_states, key_states, self.head_dim, cos_sin_cache)
        else:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
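Note that the two branches differ in calling convention: the vLLM kernel rotates query_states and key_states in place and returns nothing (the rotary test added below clones its inputs for exactly this reason), while the HuggingFace fallback returns fresh tensors.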
@@ -531,3 +546,32 @@ def get_llama_flash_attention_forward():
        return attn_output, None, past_key_value

    return forward


def get_llama_vllm_rmsnorm_forward():
    try:
        from vllm import layernorm_ops
        rms_norm = layernorm_ops.rms_norm
        HAS_VLLM_KERNERL = True
    except:
        print("please install vllm kernels to use the fused rmsnorm")
        print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
        print("if you failed to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch")
        HAS_VLLM_KERNERL = False

    if HAS_VLLM_KERNERL:
        def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
            x = hidden_states
            out = torch.empty_like(x)
            rms_norm(
                out,
                x,
                self.weight.data,
                self.variance_epsilon,
            )

            return out

        return _vllm_rmsnorm_forward
    else:
        return None
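A sketch of how the returned forward could be bound onto existing modules (`model` here is an assumed LlamaForCausalLM instance; ColossalAI itself presumably wires this up through its own policy mechanism):

import types

vllm_rmsnorm_forward = get_llama_vllm_rmsnorm_forward()
if vllm_rmsnorm_forward is not None:
    for module in model.modules():
        if isinstance(module, LlamaRMSNorm):
            # rebind the fused kernel forward as a bound method of this module
            module.forward = types.MethodType(vllm_rmsnorm_forward, module)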
@@ -0,0 +1,60 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import pytest
import numpy as np
from packaging import version

import torch
from torch import nn
from torch.nn import functional as F

try:
    from vllm import layernorm_ops
    rms_norm = layernorm_ops.rms_norm
    HAS_VLLM_KERNERL = True
except:
    print("please install vllm kernels to use the fused rmsnorm")
    print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
    HAS_VLLM_KERNERL = False

class LlamaRMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon):
    x = hidden_states
    out = torch.empty_like(x)
    rms_norm(
        out,
        x,
        weight,
        variance_epsilon,
    )
    return out

@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test")
def test_rmsnorm():
    data = torch.randn((1024, 64), dtype=torch.float16, device="cuda")
    hg_rms = LlamaRMSNorm(64)
    hg_rms = hg_rms.half().cuda()
    out_torch = hg_rms(data)
    out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon)

    check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5)
    assert check, "cuda rmsnorm forward does not match torch rmsnorm forward"

if __name__ == "__main__":
    test_rmsnorm()
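For reference, both paths compute the same normalization, out = weight * x / sqrt(mean(x^2) + eps); the test only checks that the fused CUDA kernel agrees with the fp32-upcasting PyTorch reference within fp16 tolerance.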
@@ -0,0 +1,156 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pytest
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from vllm import pos_encoding_ops
    rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
    HAS_VLLM_KERNERL = True
except:
    print("fall back to original rotary_embedding_neox of huggingface")
    print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
    HAS_VLLM_KERNERL = False


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class RefRotaryEmbeddingNeox(nn.Module):
    """Reference implementation of the GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 2048,
        base: int = 10000,
    ) -> None:
        super().__init__()
        self.rotary_dim = dim
        self.max_position_embeddings = max_position_embeddings

        # Create cos and sin embeddings.
        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype=inv_freq.dtype)
        sin = emb.sin().to(dtype=inv_freq.dtype)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def forward(
        self,
        positions: torch.Tensor,  # [num_tokens]
        query: torch.Tensor,  # [num_tokens, num_heads, head_size]
        key: torch.Tensor,  # [num_tokens, num_heads, head_size]
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]

        query_rot = query_rot.transpose(0, 1)
        key_rot = key_rot.transpose(0, 1)
        cos = F.embedding(positions, self.cos_cached)
        sin = F.embedding(positions, self.sin_cached)
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
        query_rot = query_rot.transpose(0, 1).contiguous()
        key_rot = key_rot.transpose(0, 1).contiguous()

        query = torch.cat((query_rot, query_pass), dim=-1)
        key = torch.cat((key_rot, key_pass), dim=-1)

        # Output query/key shape: [num_tokens, num_heads, head_size]
        return query, key

def run_rotary_embedding_neox(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    max_position: int,
    rotary_dim: int,
    dtype: torch.dtype,
    base: int = 10000,
) -> None:
    positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
    query = torch.randn(num_tokens,
                        num_heads * head_size,
                        dtype=dtype,
                        device='cuda')
    key = torch.randn(num_tokens,
                      num_heads * head_size,
                      dtype=dtype,
                      device='cuda')

    # Create the rotary embedding.
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
    t = torch.arange(max_position).float()
    freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
    cos = freqs.cos()
    sin = freqs.sin()
    cos_sin_cache = torch.cat((cos, sin), dim=-1)
    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')

    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
    out_query = query.clone()
    out_key = key.clone()
    rotary_embedding_neox(
        positions,
        out_query,
        out_key,
        head_size,
        cos_sin_cache,
    )

    # Run the reference implementation.
    ref_rotary_embedding = RefRotaryEmbeddingNeox(
        dim=rotary_dim,
        max_position_embeddings=max_position,
        base=base,
    ).to(dtype=dtype, device='cuda')
    ref_query, ref_key = ref_rotary_embedding(
        positions,
        query.view(num_tokens, num_heads, head_size),
        key.view(num_tokens, num_heads, head_size),
    )
    ref_query = ref_query.view(num_tokens, num_heads * head_size)
    ref_key = ref_key.view(num_tokens, num_heads * head_size)

    # Compare the results.
    assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
    assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)

@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test")
def test_rotary_embedding():
    run_rotary_embedding_neox(
        num_tokens=1024,
        num_heads=8,
        head_size=64,
        max_position=8192,
        rotary_dim=64,
        dtype=torch.float16,
    )

if __name__ == "__main__":
    test_rotary_embedding()
@@ -0,0 +1,57 @@
import pytest
import math
from packaging import version

import torch
from torch import nn
from torch.nn import functional as F

try:
    import triton
    import triton.language as tl
    from tests.test_kernels.triton.utils import benchmark, torch_context_attention
    from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')

@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_bloom_context_attention():
    bs = 4
    head_num = 8
    seq_len = 1024
    head_dim = 64

    query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")

    max_input_len = seq_len
    b_start = torch.zeros((bs, ), device="cuda", dtype=torch.int32)
    b_len = torch.zeros((bs, ), device="cuda", dtype=torch.int32)

    for i in range(bs):
        b_start[i] = i * seq_len
        b_len[i] = seq_len

    o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
    bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi)

    torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)

    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched"

    latency_1 = benchmark(bloom_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len, alibi)
    latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim)

    print("the triton op latency is {} ms".format(str(latency_1)))
    print("the torch op latency is {} ms".format(str(latency_2)))


if __name__ == "__main__":
    test_bloom_context_attention()
@@ -0,0 +1,41 @@
import pytest
from packaging import version

import torch
from torch import nn

try:
    import triton
    import triton.language as tl
    from tests.test_kernels.triton.utils import benchmark
    from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')

@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_kv_cache_copy_op():

    B_NTX = 32 * 2048
    head_num = 8
    head_dim = 64

    cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
    dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32)

    dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)

    copy_kv_cache_to_dest(cache, dest_index, dest_data)

    assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch are not matched"

    latency = benchmark(copy_kv_cache_to_dest, cache, dest_index, dest_data)
    print("the average latency is {} ms".format(str(latency)))


if __name__ == "__main__":
    test_kv_cache_copy_op()
@@ -0,0 +1,57 @@
import pytest
import math
from packaging import version

import torch
from torch import nn
from torch.nn import functional as F

try:
    import triton
    import triton.language as tl
    from tests.test_kernels.triton.utils import benchmark, torch_context_attention
    from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_llama_context_attention():
    bs = 4
    head_num = 8
    seq_len = 1024
    head_dim = 64

    query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")

    max_input_len = seq_len
    b_start = torch.zeros((bs, ), device="cuda", dtype=torch.int32)
    b_len = torch.zeros((bs, ), device="cuda", dtype=torch.int32)

    for i in range(bs):
        b_start[i] = i * seq_len
        b_len[i] = seq_len

    o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)

    torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)

    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched"

    latency_1 = benchmark(llama_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len)
    latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim)

    print("the triton op latency is {} ms".format(str(latency_1)))
    print("the torch op latency is {} ms".format(str(latency_2)))


if __name__ == "__main__":
    test_llama_context_attention()
@@ -4,12 +4,11 @@ import torch
from torch import nn
import torch.nn.functional as F

from colossalai.kernel.triton.ops import self_attention_compute_using_triton
from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel

try:
    import triton
    import triton.language as tl
    from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton
    from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False

@@ -17,7 +16,7 @@ except ImportError:

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')

@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_qkv_matmul():
    qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
    scale = 1.2

@@ -106,7 +105,7 @@ def self_attention_compute_using_torch(qkv,

    return res.view(batches, -1, d_model), score_output, softmax_output

@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_self_atttention_test():

    qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
@@ -3,11 +3,19 @@ from packaging import version
import torch
from torch import nn

from colossalai.kernel.triton.ops import softmax

try:
    import triton
    import triton.language as tl
    from colossalai.kernel.triton.softmax import softmax
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')

@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
def test_softmax_op():
    data_samples = [
        torch.randn((3, 4, 5, 32), device="cuda", dtype=torch.float32),
@@ -0,0 +1,50 @@
import numpy as np
import math

import torch
from torch.nn import functional as F


def benchmark(func, *args):
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    repetitions = 300

    # warm-up runs so that kernel compilation and caching do not skew the timings
    for i in range(10):
        func(*args)

    timings = np.zeros((repetitions, 1))
    with torch.no_grad():
        for rep in range(repetitions):
            starter.record()
            func(*args)
            ender.record()
            # wait for the GPU to finish before reading the elapsed time
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender)
            timings[rep] = curr_time

    mean_syn = np.sum(timings) / repetitions
    return mean_syn
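A quick usage sketch of the helper (the matmul operands are arbitrary placeholders):

a = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
b = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
# returns the mean latency in milliseconds over 300 timed runs
print("matmul mean latency: {} ms".format(benchmark(torch.matmul, a, b)))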

def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
    '''
    adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
    '''
    xq = xq.view(bs, seqlen, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    xv = xv.view(bs, seqlen, num_head, head_dim)
    mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
    mask[mask == 0.] = -100000000.0
    mask = mask.repeat(bs, num_head, 1, 1)
    keys = xk
    values = xv
    xq = xq.transpose(1, 2)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)
    sm_scale = 1 / math.sqrt(head_dim)
    scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale
    scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16)

    output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
    return output