[Inference/NFC] Clean outdated inference tests and deprecated kernels (#5159)

* [inference/nfc] remove outdated inference tests

* remove outdated kernel tests

* remove deprecated triton kernels

* remove imports from deprecated kernels
pull/5258/head
Yuanheng Zhao 2023-12-05 15:12:57 +08:00 committed by FrankLeeeee
parent 56e75eeb06
commit 2bb92243d4
18 changed files with 0 additions and 2543 deletions

View File

@ -8,24 +8,12 @@ except ImportError:
# There may exist import error even if we have triton installed.
if HAS_TRITON:
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm
from .gptq_triton import gptq_fused_linear_triton
from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd
from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd
from .softmax import softmax
from .token_attention_kernel import token_attention_fwd
__all__ = [
"llama_context_attn_fwd",
"bloom_context_attn_fwd",
"softmax",
"layer_norm",
"copy_kv_cache_to_dest",
"token_attention_fwd",
"gptq_fused_linear_triton",
"int8_rotary_embedding_fwd",
"smooth_llama_context_attn_fwd",
"smooth_token_attention_fwd",
]

View File

@ -1,393 +0,0 @@
import math
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
"""
this function is modified from
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
"""
if triton.__version__ < "2.1.0":
@triton.jit
def _context_flash_attention_kernel(
Q,
K,
V,
sm_scale,
B_Start_Loc,
B_Seqlen,
TMP,
alibi_ptr,
Out,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_tmp_b,
stride_tmp_h,
stride_tmp_s,
# suggtest set-up 64, 128, 256, 512
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
batch_id = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# get batch info
cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
block_start_loc = BLOCK_M * start_m
load_p_ptrs = (
Q
+ (cur_batch_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
if alibi_ptr is not None:
alibi_m = tl.load(alibi_ptr + cur_head)
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
k = tl.load(
k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
if alibi_ptr is not None:
alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
qk -= alibi_loc * alibi_m
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
tl.store(t_ptrs, acc_scale)
acc_scale = tl.load(t_ptrs)
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_o = (
(cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
else:
# this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11
@triton.jit
def _context_flash_attention_kernel_2(
Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,
Out,
kv_group_num,
stride_qbs, stride_qh, stride_qd,
stride_kbs, stride_kh, stride_kd,
stride_vbs, stride_vh, stride_vd,
stride_obs, stride_oh, stride_od,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
if kv_group_num is not None:
cur_kv_head = cur_head // kv_group_num
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
if kv_group_num is None or kv_group_num == 1:
off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
else:
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
k_ptrs = K + off_k
v_ptrs = V + off_v
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
if Alibi is not None:
alibi_m = tl.load(Alibi + cur_head)
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
if Alibi is not None:
alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
qk -= alibi_loc * alibi_m
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
@torch.no_grad()
def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None):
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk, "context process only supports equal query, key, value length"
assert Lk == Lv, "context process only supports equal query, key, value length"
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / math.sqrt(Lk)
batch, head = b_seq_len.shape[0], q.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
num_warps = 4 if Lk <= 64 else 8
if triton.__version__ < "2.1.0":
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
_context_flash_attention_kernel[grid](
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
tmp,
alibi,
o,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
tmp.stride(0),
tmp.stride(1),
tmp.stride(2),
# manually setting this blcok num, we can use tuning config to futher speed-up
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
else:
_context_flash_attention_kernel_2[grid](
q, k, v, sm_scale, alibi, b_start_loc, b_seq_len,
o,
None,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
@torch.no_grad()
def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk, "context process only supports equal query, key, value length"
assert Lk == Lv, "context process only supports equal query, key, value length"
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / math.sqrt(Lk)
batch, head = b_seq_len.shape[0], q.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
# num_warps = 4
if triton.__version__ < "2.1.0":
_context_flash_attention_kernel[grid](
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
tmp,
None,
o,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
tmp.stride(0),
tmp.stride(1),
tmp.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
else:
kv_group_num = q.shape[1] // k.shape[1]
_context_flash_attention_kernel_2[grid](
q,
k,
v,
sm_scale,
None,
b_start_loc,
b_seq_len,
o,
kv_group_num,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,)
return

View File

@ -1,71 +0,0 @@
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
# adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py
@triton.jit
def _fwd_copy_kv_cache_dest(
kv_cache_ptr,
dest_index_ptr,
out,
stride_k_bs,
stride_k_h,
stride_k_d,
stride_o_bs,
stride_o_h,
stride_o_d,
head_num,
BLOCK_DMODEL: tl.constexpr,
BLOCK_HEAD: tl.constexpr,
):
cur_index = tl.program_id(0)
offs_h = tl.arange(0, BLOCK_HEAD)
offs_d = tl.arange(0, BLOCK_DMODEL)
dest_index = tl.load(dest_index_ptr + cur_index)
cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets
o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
o_ptrs = out + dest_index * stride_o_bs + o_offsets
k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)
tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)
return
# adepted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py
@torch.no_grad()
def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out):
seq_len = dest_index_ptr.shape[0]
head_num = k_ptr.shape[1]
head_dim = k_ptr.shape[2]
assert head_num == out.shape[1], "head_num should be the same for k_ptr and out"
assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out"
num_warps = 2
_fwd_copy_kv_cache_dest[(seq_len,)](
k_ptr,
dest_index_ptr,
out,
k_ptr.stride(0),
k_ptr.stride(1),
k_ptr.stride(2),
out.stride(0),
out.stride(1),
out.stride(2),
head_num,
BLOCK_DMODEL=head_dim,
BLOCK_HEAD=triton.next_power_of_2(head_num),
num_warps=num_warps,
num_stages=2,
)
return

View File

@ -1,50 +0,0 @@
# adepted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
import torch
try:
from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2
HAS_LIGHTLLM_KERNEL = True
except:
print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
HAS_LIGHTLLM_KERNEL = False
if HAS_LIGHTLLM_KERNEL:
def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
BLOCK_SEQ = 256
batch_size = infer_state.batch_size
max_len_in_batch = infer_state.max_len_in_batch
calcu_shape1 = (batch_size, q_head_num, head_dim)
if getattr(infer_state, 'mid_o', None) is None:
infer_state.mid_o = torch.empty([batch_size,
q_head_num,
max_len_in_batch // BLOCK_SEQ + 1,
head_dim],
dtype=torch.float32,
device="cuda")
infer_state.mid_o_logexpsum = torch.empty([batch_size,
q_head_num,
max_len_in_batch // BLOCK_SEQ + 1],
dtype=torch.float32,
device="cuda")
mid_o = infer_state.mid_o
mid_o_logexpsum = infer_state.mid_o_logexpsum
flash_decode_stage1(q.view(calcu_shape1),
cache_k,
cache_v,
infer_state.block_loc,
infer_state.seq_len,
infer_state.max_len_in_batch,
mid_o,
mid_o_logexpsum,
BLOCK_SEQ)
flash_decode_stage2(mid_o,
mid_o_logexpsum,
infer_state.seq_len,
o_tensor.view(calcu_shape1),
BLOCK_SEQ)

View File

@ -1,117 +0,0 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import torch
import triton
import triton.language as tl
@triton.jit
def _rotary_kernel(
q,
input_scale,
output_scale,
Cos,
Sin,
q_bs_stride,
q_h_stride,
q_d_stride,
cos_bs_stride,
cos_d_stride,
total_len,
HEAD_NUM: tl.constexpr,
BLOCK_HEAD: tl.constexpr,
BLOCK_SEQ: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
current_head_index = tl.program_id(0)
current_seq_index = tl.program_id(1)
dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
off_q0 = (
current_seq_range[:, None, None] * q_bs_stride
+ current_head_range[None, :, None] * q_h_stride
+ dim_range0[None, None, :] * q_d_stride
)
off_q1 = (
current_seq_range[:, None, None] * q_bs_stride
+ current_head_range[None, :, None] * q_h_stride
+ dim_range1[None, None, :] * q_d_stride
)
off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride
q0 = tl.load(
q + off_q0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0,
)
q1 = tl.load(
q + off_q1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0,
)
cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
q0 = q0.to(tl.float32) * input_scale
q1 = q1.to(tl.float32) * input_scale
out0 = (q0 * cos - q1 * sin) / output_scale
out1 = (q0 * sin + q1 * cos) / output_scale
out0 = out0.to(tl.int8)
out1 = out1.to(tl.int8)
tl.store(
q + off_q0,
out0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
)
tl.store(
q + off_q1,
out1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
)
return
@torch.no_grad()
def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale):
total_len = q.shape[0]
head_num = q.shape[1]
head_dim = q.shape[2]
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
BLOCK_HEAD = 4
BLOCK_SEQ = 32
grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
if head_dim >= 128:
num_warps = 8
else:
num_warps = 4
_rotary_kernel[grid](
q,
input_scale,
output_scale,
cos,
sin,
q.stride(0),
q.stride(1),
q.stride(2),
cos.stride(0),
cos.stride(1),
total_len,
HEAD_NUM=head_num,
BLOCK_HEAD=BLOCK_HEAD,
BLOCK_SEQ=BLOCK_SEQ,
HEAD_DIM=head_dim,
num_warps=num_warps,
num_stages=1,
)
return

View File

@ -1,164 +0,0 @@
import torch
try:
import triton
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
from .qkv_matmul_kernel import qkv_gemm_4d_kernel
from .softmax import softmax_kernel
# adpeted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py#L312
def self_attention_forward_without_fusion(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float
):
r"""A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
Args:
q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
v (torch.Tensor): V embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
input_mask (torch.Tensor): mask for softmax layer, shape should be (batch, num_heads, seq_lem, seq_len)
scale: the float scale value which is used to multiply with Q*K^T before doing softmax
Return:
output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size)
"""
assert len(q.shape) == 4, "the shape of q val must be 4"
batches, M, H, K = q.shape
assert q.shape == k.shape, "the shape of q and the shape of k must be equal"
assert q.shape == v.shape, "the shape of q and the shape of v must be equal"
assert q.shape[-1] == k.shape[-1], "the last dimension of q and k must be equal"
N = k.shape[1]
# head_size * num_of_head
d_model = q.shape[-1] * q.shape[-2]
score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype)
grid = lambda meta: (
batches,
H,
triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
qkv_gemm_4d_kernel[grid](
q,
k,
score_output,
M,
N,
K,
q.stride(0),
q.stride(2),
q.stride(1),
q.stride(3),
k.stride(0),
k.stride(2),
k.stride(3),
k.stride(1),
score_output.stride(0),
score_output.stride(1),
score_output.stride(2),
score_output.stride(3),
scale=scale,
# currently manually setting, later on we can use auto-tune config to match best setting
BLOCK_SIZE_M=64,
BLOCK_SIZE_N=32,
BLOCK_SIZE_K=32,
GROUP_SIZE_M=8,
)
softmax_output = torch.empty(score_output.shape, device=score_output.device, dtype=score_output.dtype)
score_output_shape = score_output.shape
score_output = score_output.view(-1, score_output.shape[-1])
n_rows, n_cols = score_output.shape
if n_rows <= 350000:
block_size = max(triton.next_power_of_2(n_cols), 2)
num_warps = 4
if block_size >= 4096:
num_warps = 16
elif block_size >= 2048:
num_warps = 8
else:
num_warps = 4
softmax_kernel[(n_rows,)](
softmax_output,
score_output,
score_output.stride(0),
n_cols,
mask_ptr=input_mask,
num_warps=num_warps,
BLOCK_SIZE=block_size,
)
else:
# NOTE: change softmax kernel functions to make it suitable for large size dimension
softmax_output = torch.nn.functional.softmax(score_output, dim=-1)
softmax_output = softmax_output.view(*score_output_shape)
batches, H, M, K = softmax_output.shape
N = v.shape[-1]
output = torch.empty((batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)
grid = lambda meta: (
batches,
H,
triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
qkv_gemm_4d_kernel[grid](
softmax_output,
v,
output,
M,
N,
K,
softmax_output.stride(0),
softmax_output.stride(1),
softmax_output.stride(2),
softmax_output.stride(3),
v.stride(0),
v.stride(2),
v.stride(1),
v.stride(3),
output.stride(0),
output.stride(2),
output.stride(1),
output.stride(3),
BLOCK_SIZE_M=128,
BLOCK_SIZE_N=64,
BLOCK_SIZE_K=64,
GROUP_SIZE_M=8,
scale=-1,
)
return output.view(batches, -1, d_model)
# modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/attention.py#L212
def self_attention_compute_using_triton(
qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False
):
assert qkv.is_contiguous()
assert alibi is None, "current triton self-attention does not support alibi"
batches = qkv.shape[0]
d_model = qkv.shape[-1] // 3
num_of_heads = d_model // head_size
q = qkv[:, :, :d_model]
k = qkv[:, :, d_model : d_model * 2]
v = qkv[:, :, d_model * 2 :]
q = q.view(batches, -1, num_of_heads, head_size)
k = k.view(batches, -1, num_of_heads, head_size)
v = v.view(batches, -1, num_of_heads, head_size)
data_output_triton = self_attention_forward_without_fusion(q, k, v, input_mask, scale)
return data_output_triton

View File

@ -1,652 +0,0 @@
import math
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
"""
this functions are modified from https://github.com/ModelTC/lightllm
"""
# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
@triton.jit
def _context_flash_attention_kernel(
Q,
K,
V,
q_input_scale,
k_input_scale,
v_input_scale,
pv_output_scale,
sm_scale,
B_Start_Loc,
B_Seqlen,
TMP,
alibi_ptr,
Out,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_tmp_b,
stride_tmp_h,
stride_tmp_s,
# suggtest set-up 64, 128, 256, 512
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
batch_id = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# get batch info
cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
block_start_loc = BLOCK_M * start_m
load_p_ptrs = (
Q
+ (cur_batch_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
q = q.to(tl.float16) * q_input_scale.to(tl.float16)
k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
if alibi_ptr is not None:
alibi_m = tl.load(alibi_ptr + cur_head)
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
k = tl.load(
k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
other=0.0,
)
k = k.to(tl.float16) * k_input_scale.to(tl.float16)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
if alibi_ptr is not None:
alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
qk -= alibi_loc * alibi_m
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
tl.store(t_ptrs, acc_scale)
acc_scale = tl.load(t_ptrs)
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)
v = v.to(tl.float16) * v_input_scale.to(tl.float16)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8)
off_o = (
(cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
@torch.no_grad()
def smooth_llama_context_attn_fwd(
q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len
):
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk, "context process only supports equal query, key, value length"
assert Lk == Lv, "context process only supports equal query, key, value length"
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / math.sqrt(Lk)
batch, head = b_seq_len.shape[0], q.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
_context_flash_attention_kernel[grid](
q,
k,
v,
q_input_scale,
k_input_scale,
v_input_scale,
pv_output_scale,
sm_scale,
b_start_loc,
b_seq_len,
tmp,
None,
o,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
tmp.stride(0),
tmp.stride(1),
tmp.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@triton.jit
def _token_attn_1_kernel(
Q,
K,
q_input_scale,
k_input_scale,
sm_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
attn_out,
kv_cache_loc_b_stride,
kv_cache_loc_s_stride,
q_batch_stride,
q_head_stride,
q_head_dim_stride,
k_batch_stride,
k_head_stride,
k_head_dim_stride,
attn_head_stride,
attn_batch_stride,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
start_n = tl.program_id(2)
offs_d = tl.arange(0, HEAD_DIM)
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
current_batch_end_index = max_kv_cache_len
off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
block_stard_index = start_n * BLOCK_N
block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0)
for start_mark in range(0, block_mask, 1):
q = tl.load(Q + off_q + start_mark)
q = q.to(tl.float16) * q_input_scale.to(tl.float16)
offs_n_new = current_batch_start_index + offs_n
k_loc = tl.load(
kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
mask=offs_n_new < current_batch_end_index,
other=0,
)
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
k = k.to(tl.float16) * k_input_scale.to(tl.float16)
att_value = tl.sum(q[None, :] * k, 1)
att_value *= sm_scale
off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
return
# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@triton.jit
def _token_attn_1_alibi_kernel(
Q,
K,
q_input_scale,
k_input_scale,
sm_scale,
alibi,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
attn_out,
kv_cache_loc_b_stride,
kv_cache_loc_s_stride,
q_batch_stride,
q_head_stride,
q_head_dim_stride,
k_batch_stride,
k_head_stride,
k_head_dim_stride,
attn_head_stride,
attn_batch_stride,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
start_n = tl.program_id(2)
offs_d = tl.arange(0, HEAD_DIM)
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
current_batch_end_index = max_kv_cache_len
off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
block_stard_index = start_n * BLOCK_N
block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0)
for start_mark in range(0, block_mask, 1):
alibi_m = tl.load(alibi + current_head)
q = tl.load(Q + off_q + start_mark)
q = q.to(tl.float16) * q_input_scale.to(tl.float16)
offs_n_new = current_batch_start_index + offs_n
k_loc = tl.load(
kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
mask=offs_n_new < current_batch_end_index,
other=0,
)
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
k = k.to(tl.float16) * k_input_scale.to(tl.float16)
att_value = tl.sum(q[None, :] * k, 1)
att_value *= sm_scale
att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n)
off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
return
@torch.no_grad()
def token_attn_fwd_1(
q,
k,
attn_out,
q_input_scale,
k_input_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
alibi=None,
):
BLOCK = 32
# shape constraints
q_head_dim, k_head_dim = q.shape[-1], k.shape[-1]
assert q_head_dim == k_head_dim
assert k_head_dim in {16, 32, 64, 128}
sm_scale = 1.0 / (k_head_dim**0.5)
batch, head_num = kv_cache_loc.shape[0], q.shape[1]
grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK))
num_warps = 4 if k_head_dim <= 64 else 8
num_warps = 2
if alibi is not None:
_token_attn_1_alibi_kernel[grid](
q,
k,
q_input_scale,
k_input_scale,
sm_scale,
alibi,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
attn_out,
kv_cache_loc.stride(0),
kv_cache_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
attn_out.stride(0),
attn_out.stride(1),
HEAD_DIM=k_head_dim,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
else:
_token_attn_1_kernel[grid](
q,
k,
q_input_scale,
k_input_scale,
sm_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
attn_out,
kv_cache_loc.stride(0),
kv_cache_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
attn_out.stride(0),
attn_out.stride(1),
HEAD_DIM=k_head_dim,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
@triton.jit
def _token_attn_softmax_fwd(
softmax_logics,
kv_cache_start_loc,
kv_cache_seqlen,
softmax_prob_out,
logics_head_dim_stride,
logics_batch_stride,
prob_head_dim_stride,
prob_batch_stride,
BLOCK_SIZE: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
col_offsets = tl.arange(0, BLOCK_SIZE)
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
row = tl.load(
softmax_logics
+ current_head * logics_head_dim_stride
+ (current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
mask=col_offsets < current_batch_seq_len,
other=-float("inf"),
).to(tl.float32)
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
tl.store(
softmax_prob_out
+ current_head * prob_head_dim_stride
+ (current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
softmax_output,
mask=col_offsets < current_batch_seq_len,
)
return
@torch.no_grad()
def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len):
BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len)
batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0]
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
_token_attn_softmax_fwd[(batch, head_num)](
softmax_logics,
kv_cache_start_loc,
kv_cache_seqlen,
softmax_prob_out,
softmax_logics.stride(0),
softmax_logics.stride(1),
softmax_prob_out.stride(0),
softmax_prob_out.stride(1),
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)
return
# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@triton.jit
def _token_attn_2_kernel(
Prob,
V,
attn_out,
v_input_scale,
pv_output_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
kv_cache_loc_b_stride,
kv_cache_loc_s_stride,
prob_head_dim_stride,
prob_batch_stride,
v_batch_stride,
v_head_stride,
v_head_dim_stride,
attn_out_batch_stride,
attn_out_head_stride,
attn_out_head_dim_stride,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, HEAD_DIM)
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride
p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride
v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
for start_n in range(0, current_batch_seq_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
p_value = tl.load(
Prob + p_offs + start_n * kv_cache_loc_s_stride,
mask=(start_n + offs_n) < current_batch_seq_len,
other=0.0,
)
v_loc = tl.load(
kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
mask=(start_n + offs_n) < current_batch_seq_len,
other=0.0,
)
v_value = tl.load(
V + v_offs + v_loc[:, None] * v_batch_stride,
mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
other=0.0,
)
v_value = v_value.to(tl.float16) * v_input_scale.to(tl.float16)
acc += tl.sum(p_value[:, None] * v_value, 0)
acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8)
off_o = (
current_batch * attn_out_batch_stride
+ current_head * attn_out_head_stride
+ offs_d * attn_out_head_dim_stride
)
out_ptrs = attn_out + off_o
tl.store(out_ptrs, acc)
return
@torch.no_grad()
def token_attn_fwd_2(
prob,
v,
attn_out,
v_input_scale,
pv_output_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
):
if triton.__version__ >= "2.1.0":
BLOCK = 128
else:
BLOCK = 64
batch, head = kv_cache_loc.shape[0], v.shape[1]
grid = (batch, head)
num_warps = 4
dim = v.shape[-1]
_token_attn_2_kernel[grid](
prob,
v,
attn_out,
v_input_scale,
pv_output_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
kv_cache_loc.stride(0),
kv_cache_loc.stride(1),
prob.stride(0),
prob.stride(1),
v.stride(0),
v.stride(1),
v.stride(2),
attn_out.stride(0),
attn_out.stride(1),
attn_out.stride(2),
HEAD_DIM=dim,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
@torch.no_grad()
def smooth_token_attention_fwd(
q,
k,
v,
attn_out,
q_input_scale,
k_input_scale,
v_input_scale,
pv_output_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
alibi=None,
):
head_num = k.shape[1]
batch_size = kv_cache_seq_len.shape[0]
calcu_shape1 = (batch_size, head_num, k.shape[2])
total_token_num = k.shape[0]
att_m_tensor = torch.empty((head_num, total_token_num), dtype=torch.float32, device="cuda")
token_attn_fwd_1(
q.view(calcu_shape1),
k,
att_m_tensor,
q_input_scale,
k_input_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
alibi=alibi,
)
prob = torch.empty_like(att_m_tensor)
token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
att_m_tensor = None
token_attn_fwd_2(
prob,
v,
attn_out.view(calcu_shape1),
v_input_scale,
pv_output_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
)
prob = None
return

View File

@ -1,238 +0,0 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
try:
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fwd2
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd
from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd
HAS_TRITON_TOKEN_ATTENTION = True
except ImportError:
print("unable to import lightllm kernels")
HAS_TRITON_TOKEN_ATTENTION = False
if HAS_TRITON:
@torch.no_grad()
def token_attention_fwd(
q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None
):
head_num = k.shape[1]
batch_size = kv_cache_seq_len.shape[0]
calcu_shape1 = (batch_size, head_num, k.shape[2])
total_token_num = k.shape[0]
att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")
if alibi is None:
lightllm_llama_token_att_fwd(
q.view(calcu_shape1),
k,
att_m_tensor,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
)
else:
lightllm_bloom_token_att_fwd(
q.view(calcu_shape1),
k,
att_m_tensor,
alibi,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
)
prob = torch.empty_like(att_m_tensor)
lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
att_m_tensor = None
lightllm_llama_token_att_fwd2(
prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch
)
prob = None
return
class Llama2TokenAttentionForwards:
@staticmethod
@triton.jit
# this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L8
def _fwd_kernel(
Logics,
V,
Out,
B_Loc,
B_Start_Loc,
B_Seqlen,
max_input_len,
stride_logic_h,
stride_logic_bs,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_b_loc_b,
stride_b_loc_s,
other_kv_index, # avoid nan information
kv_group_num,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_kv_head = cur_head // kv_group_num
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
off_b_loc = cur_batch * stride_b_loc_b + (max_input_len - cur_batch_seq_len) * stride_b_loc_s
v_ptrs = V + off_v
e_max = float("-inf")
e_sum = 0.0
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_seq_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
v_index = tl.load(
B_Loc + off_b_loc + (start_n + offs_n) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_seq_len,
other=other_kv_index,
)
qk = tl.load(
Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs,
mask=start_n + offs_n < cur_batch_seq_len,
other=float("-inf"),
)
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
old_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max)
e_sum = e_sum * old_scale + tl.sum(p, 0)
v = tl.load(v_ptrs + v_index[:, None] * stride_vbs)
acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
e_max = n_e_max
acc = acc / e_sum
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
return
# this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L36
@staticmethod
@torch.no_grad()
def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index):
BLOCK = 64
batch, head = b_seq_len.shape[0], logics.shape[0]
grid = (batch, head)
kv_group_num = logics.shape[0] // v.shape[1]
num_warps = 1
Llama2TokenAttentionForwards._fwd_kernel[grid](
logics,
v,
o,
b_loc,
b_start_loc,
b_seq_len,
max_input_len,
logics.stride(0),
logics.stride(1),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
b_loc.stride(0),
b_loc.stride(1),
other_kv_index,
kv_group_num,
BLOCK_DMODEL=v.shape[-1],
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=3,
)
return
# this is the interface of llama2 attn forward
@staticmethod
@torch.no_grad()
def token_attn(
q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, other_kv_index
):
total_token_num = k.shape[0]
batch_size, head_num, head_dim = q.shape
calcu_shape1 = (batch_size, head_num, head_dim)
att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")
lightllm_llama_token_att_fwd(
q,
k,
att_m_tensor,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
)
if triton.__version__ == "2.0.0":
prob = torch.empty_like(att_m_tensor)
lightllm_llama_token_softmax_fwd(
att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch
)
att_m_tensor = None
lightllm_llama_token_att_fwd2(
prob,
v,
attn_out.view(calcu_shape1),
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
)
prob = None
return
elif triton.__version__ >= "2.1.0":
Llama2TokenAttentionForwards.token_softmax_reducev_fwd(
att_m_tensor,
v,
attn_out.view(calcu_shape1),
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
other_kv_index,
)
else:
raise Exception("not support triton version")

View File

@ -1,121 +0,0 @@
import importlib.util
import pytest
import torch
import torch.distributed as dist
import transformers
from packaging import version
import colossalai
from colossalai.inference import InferenceEngine
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
HAS_LIGHTLLM_KERNEL = True
if importlib.util.find_spec("lightllm") is None:
HAS_LIGHTLLM_KERNEL = False
def data_gen():
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
inputs = data_gen()
for k, v in inputs.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 16
inputs[k] = v.to("cuda").repeat(*new_shape)
def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
model = transformers.BloomForCausalLM(
transformers.BloomConfig(vocab_size=20000, hidden_size=512, n_head=4, n_layer=4)
)
engine = InferenceEngine(
tp_size=tp_size,
pp_size=pp_size,
model=model,
max_output_len=max_output_len,
micro_batch_size=micro_batch_size,
)
output = engine.generate(inputs)
if dist.get_rank() == 0:
assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
@parameterize("tp_size", [1])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [1])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
def check_tp_pp_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_pipeline_inference_test()
def check_tp_or_pp_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_inference_test()
run_pipeline_inference_test()
def check_single_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_single_inference_test
@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
spawn(check_tp_pp_inference, nprocs=4)
spawn(check_tp_or_pp_inference, nprocs=2)
spawn(check_single_inference, nprocs=1)
if __name__ == "__main__":
test_pipeline_inference()

View File

@ -1,129 +0,0 @@
import importlib.util
import pytest
import torch
import torch.distributed as dist
from packaging import version
import colossalai
from colossalai.inference import InferenceEngine
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
HAS_LIGHTLLM_KERNEL = True
if importlib.util.find_spec("lightllm") is None:
HAS_LIGHTLLM_KERNEL = False
def data_gen():
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
inputs = data_gen()
for k, v in inputs.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 16
inputs[k] = v.to("cuda").repeat(*new_shape)
def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
chatglm_config = ChatGLMConfig(
num_layers=2,
vocab_size=20000,
use_cache=True,
multi_query_attention=True,
multi_query_group_num=2,
num_attention_heads=8,
hidden_size=1024,
)
model = ChatGLMForConditionalGeneration(chatglm_config)
engine = InferenceEngine(
tp_size=tp_size,
pp_size=pp_size,
model=model,
max_output_len=max_output_len,
micro_batch_size=micro_batch_size,
)
output = engine.generate(inputs)
if dist.get_rank() == 0:
assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
@parameterize("tp_size", [1])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [1])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
def check_tp_pp_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_pipeline_inference_test()
def check_tp_or_pp_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_inference_test()
run_pipeline_inference_test()
def check_single_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_single_inference_test
@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
spawn(check_tp_pp_inference, nprocs=4)
spawn(check_tp_or_pp_inference, nprocs=2)
spawn(check_single_inference, nprocs=1)
if __name__ == "__main__":
test_pipeline_inference()

View File

@ -1,126 +0,0 @@
import importlib.util
import pytest
import torch
import torch.distributed as dist
import transformers
from packaging import version
import colossalai
from colossalai.inference import InferenceEngine
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
import importlib.util
HAS_LIGHTLLM_KERNEL = True
if importlib.util.find_spec("lightllm") is None:
HAS_LIGHTLLM_KERNEL = False
def data_gen():
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
inputs = data_gen()
for k, v in inputs.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 16
inputs[k] = v.to("cuda").repeat(*new_shape)
def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
model = transformers.LlamaForCausalLM(
transformers.LlamaConfig(
vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
)
)
engine = InferenceEngine(
tp_size=tp_size,
pp_size=pp_size,
model=model,
max_output_len=max_output_len,
micro_batch_size=micro_batch_size,
)
output = engine.generate(inputs)
if dist.get_rank() == 0:
assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
@parameterize("tp_size", [1])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [2])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
@parameterize("tp_size", [1])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()
def check_tp_pp_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_pipeline_inference_test()
def check_tp_or_pp_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_inference_test()
run_pipeline_inference_test()
def check_single_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_single_inference_test
@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
spawn(check_tp_pp_inference, nprocs=4)
spawn(check_tp_or_pp_inference, nprocs=2)
spawn(check_single_inference, nprocs=1)
if __name__ == "__main__":
test_pipeline_inference()

View File

@ -1,66 +0,0 @@
import os
import pytest
import torch
from packaging import version
from colossalai.inference.kv_cache import MemoryManager
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
BATCH_SIZE = 4
INPUT_LEN = 16
OUTPUT_LEN = 8
LAYER_NUM = 4
HEAD_NUM = 32
HEAD_DIM = 128
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim):
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(port)
disable_existing_loggers()
size = batch_size * (input_len + output_len)
kvcache_manager = MemoryManager(size, torch.float16, head_num // world_size, head_dim, layer_num, rank)
key_buffers = kvcache_manager.key_buffer
value_buffers = kvcache_manager.value_buffer
assert len(key_buffers) == len(value_buffers) == layer_num
assert key_buffers[0].shape == value_buffers[0].shape
# required size exceeds the maximum allocated size
invalid_locs = kvcache_manager.alloc_contiguous(size + 1)
assert invalid_locs is None
# for prefill stage, allocation via alloc and alloc_contiguous should be the same
total_token_prefill = batch_size * input_len
prefill_locs = kvcache_manager.alloc(total_token_prefill)
kvcache_manager.free_all()
prefill_locs_contiguous = kvcache_manager.alloc_contiguous(total_token_prefill)[0]
assert torch.equal(prefill_locs, prefill_locs_contiguous)
assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill
kvcache_manager.alloc_contiguous(batch_size)
assert torch.all(kvcache_manager.mem_state[: total_token_prefill + batch_size] == False)
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cache_manager_dist():
spawn(
create_cache_manager,
4,
batch_size=BATCH_SIZE,
input_len=INPUT_LEN,
output_len=OUTPUT_LEN,
layer_num=LAYER_NUM,
head_num=HEAD_NUM,
head_dim=HEAD_DIM,
)
if __name__ == "__main__":
test_cache_manager_dist()

View File

@ -1,52 +0,0 @@
import pytest
import torch
from packaging import version
try:
pass
from colossalai.kernel.triton import bloom_context_attn_fwd
from tests.test_infer_ops.triton.kernel_utils import torch_context_attention
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_bloom_context_attention():
bs = 4
head_num = 8
seq_len = 1024
head_dim = 64
query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
max_input_len = seq_len
b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)
for i in range(bs):
b_start[i] = i * seq_len
b_len[i] = seq_len
o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi)
torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
assert torch.allclose(
torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2
), "outputs from triton and torch are not matched"
if __name__ == "__main__":
test_bloom_context_attention()

View File

@ -1,39 +0,0 @@
import pytest
import torch
from packaging import version
try:
pass
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_kv_cache_copy_op():
B_NTX = 32 * 2048
head_num = 8
head_dim = 64
cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32)
dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
copy_kv_cache_to_dest(cache, dest_index, dest_data)
assert torch.allclose(
cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3
), "copy_kv_cache_to_dest outputs from triton and torch are not matched"
if __name__ == "__main__":
test_kv_cache_copy_op()

View File

@ -1,50 +0,0 @@
import pytest
import torch
from packaging import version
try:
pass
from colossalai.kernel.triton import llama_context_attn_fwd
from tests.test_infer_ops.triton.kernel_utils import torch_context_attention
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_llama_context_attention():
bs = 4
head_num = 8
seq_len = 1024
head_dim = 64
query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
max_input_len = seq_len
b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)
for i in range(bs):
b_start[i] = i * seq_len
b_len[i] = seq_len
o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)
torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
assert torch.allclose(
torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3
), "outputs from triton and torch are not matched"
if __name__ == "__main__":
test_llama_context_attention()

View File

@ -1,143 +0,0 @@
import pytest
import torch
import torch.nn.functional as F
from packaging import version
try:
import triton
from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_qkv_matmul():
qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
scale = 1.2
head_size = 32
batches = qkv.shape[0]
d_model = qkv.shape[-1] // 3
num_of_heads = d_model // head_size
q = qkv[:, :, :d_model]
k = qkv[:, :, d_model : d_model * 2]
q = q.view(batches, -1, num_of_heads, head_size)
k = k.view(batches, -1, num_of_heads, head_size)
q_copy = q.clone()
k_copy = k.clone()
q = torch.transpose(q, 1, 2).contiguous()
k = torch.transpose(k, 1, 2).contiguous()
k = torch.transpose(k, 2, 3).contiguous()
torch_ouput = torch.einsum("bnij,bnjk->bnik", q, k)
torch_ouput *= 1.2
q, k = q_copy, k_copy
batches, M, H, K = q.shape
N = k.shape[1]
score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype)
grid = lambda meta: (
batches,
H,
triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
K = q.shape[3]
qkv_gemm_4d_kernel[grid](
q,
k,
score_output,
M,
N,
K,
q.stride(0),
q.stride(2),
q.stride(1),
q.stride(3),
k.stride(0),
k.stride(2),
k.stride(3),
k.stride(1),
score_output.stride(0),
score_output.stride(1),
score_output.stride(2),
score_output.stride(3),
scale=scale,
# currently manually setting, later on we can use auto-tune config to match best setting
BLOCK_SIZE_M=64,
BLOCK_SIZE_N=32,
BLOCK_SIZE_K=32,
GROUP_SIZE_M=8,
)
check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5)
assert check is True, "the outputs of triton and torch are not matched"
def self_attention_compute_using_torch(qkv, input_mask, scale, head_size):
batches = qkv.shape[0]
d_model = qkv.shape[-1] // 3
num_of_heads = d_model // head_size
q = qkv[:, :, :d_model]
k = qkv[:, :, d_model : d_model * 2]
v = qkv[:, :, d_model * 2 :]
q = q.view(batches, -1, num_of_heads, head_size)
k = k.view(batches, -1, num_of_heads, head_size)
v = v.view(batches, -1, num_of_heads, head_size)
q = torch.transpose(q, 1, 2).contiguous()
k = torch.transpose(k, 1, 2).contiguous()
v = torch.transpose(v, 1, 2).contiguous()
k = torch.transpose(k, -1, -2).contiguous()
score_output = torch.einsum("bnij,bnjk->bnik", q, k)
score_output *= scale
softmax_output = F.softmax(score_output, dim=-1)
res = torch.einsum("bnij,bnjk->bnik", softmax_output, v)
res = torch.transpose(res, 1, 2)
res = res.contiguous()
return res.view(batches, -1, d_model), score_output, softmax_output
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_self_atttention_test():
qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch(
qkv.clone(), input_mask=None, scale=1.2, head_size=32
)
data_output_triton = self_attention_compute_using_triton(
qkv.clone(),
alibi=None,
head_size=32,
scale=1.2,
input_mask=None,
layer_past=None,
use_flash=False,
triangular=True,
)
check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2)
assert check is True, "the triton output is not matched with torch output"
if __name__ == "__main__":
test_qkv_matmul()
test_self_atttention_test()

View File

@ -1,72 +0,0 @@
import pytest
import torch
from packaging import version
try:
from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
import importlib.util
HAS_LIGHTLLM_KERNEL = True
if importlib.util.find_spec("lightllm") is None:
HAS_LIGHTLLM_KERNEL = False
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6")
def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
xq = xq.view(bs, 1, num_head, head_dim)
xk = xk.view(bs, seqlen, num_head, head_dim)
xv = xv.view(bs, seqlen, num_head, head_dim)
logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5)
prob = torch.softmax(logics, dim=1)
prob = prob.view(bs, seqlen, num_head, 1)
return torch.sum(prob * xv, dim=1, keepdim=False)
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_LIGHTLLM_KERNEL,
reason="triton requires cuda version to be higher than 11.4 or not install lightllm",
)
def test():
Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128
dtype = torch.float16
q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
max_kv_cache_len = seq_len
kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda")
kv_cache_seq_len[:] = seq_len
kv_cache_start_loc[0] = 0
kv_cache_start_loc[1] = seq_len
kv_cache_start_loc[2] = 2 * seq_len
kv_cache_start_loc[3] = 3 * seq_len
for i in range(Z):
kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")
token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi)
torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)
print("max ", torch.max(torch.abs(torch_out - o)))
print("mean ", torch.mean(torch.abs(torch_out - o)))
assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
if __name__ == "__main__":
test()

View File

@ -1,48 +0,0 @@
import pytest
import torch
from packaging import version
try:
pass
from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_softmax():
import torch
batch_size, seq_len, head_num, head_dim = 4, 1025, 12, 128
dtype = torch.float16
Logics = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.1, std=10)
ProbOut = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
for i in range(batch_size):
kv_cache_start_loc[i] = i * seq_len
kv_cache_seq_len[i] = seq_len
token_attn_softmax_fwd(Logics, kv_cache_start_loc, kv_cache_seq_len, ProbOut, seq_len)
torch_out = Logics.reshape(head_num * batch_size, -1).softmax(-1).reshape(head_num, batch_size * seq_len)
o = ProbOut
print("max ", torch.max(torch.abs(torch_out - o)))
print("mean ", torch.mean(torch.abs(torch_out - o)))
assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
if __name__ == "__main__":
test_softmax()