mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
660 lines
20 KiB
660 lines
20 KiB
import warnings
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
"""
|
|
# Base autotune if needed
|
|
@triton.autotune(
|
|
configs=[
|
|
triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":4,},num_warps=4),
|
|
triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":8,},num_warps=8),
|
|
triton.Config({'BLOCK_HEAD':8,"BLOCK_TOKENS":8,},num_warps=8),
|
|
triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":4,},num_warps=16),
|
|
triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":4,},num_warps=32),
|
|
triton.Config({'BLOCK_HEAD':16,"BLOCK_TOKENS":16,},num_warps=4),
|
|
triton.Config({'BLOCK_HEAD':8,"BLOCK_TOKENS":16,},num_warps=8),
|
|
],
|
|
key=['HEAD_DIM','q_total_tokens','Q_HEAD_NUM']
|
|
)
|
|
"""
|
|
|
|
|
|
@triton.jit
def rotary_embedding_kernel(
    q,
    k,
    cos,
    sin,
    q_token_stride,
    q_head_stride,
    k_token_stride,
    k_head_stride,
    head_dim_stride,
    cos_token_stride,
    cos_stride,
    q_total_tokens,
    Q_HEAD_NUM: tl.constexpr,
    KV_GROUP_NUM: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,  # token range length
):
    # In-place rotary embedding for one query head (program axis 0) over a tile of BLOCK_TOKENS
    # tokens (program axis 1). Each head vector is split into a low half [0, HEAD_DIM // 2) and a
    # high half [HEAD_DIM // 2, HEAD_DIM) and rotated as
    #     lo' = lo * cos - hi * sin
    #     hi' = lo * sin + hi * cos
    # where row `t` of cos/sin is used for token `t`. The first query head of every group of
    # KV_GROUP_NUM heads also rotates the key head it maps to (head // KV_GROUP_NUM), so every key
    # head is processed exactly once. Tokens >= q_total_tokens are masked out of all loads/stores.
    head_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    token_ids = tile_id * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
    lo_dims = tl.arange(0, HEAD_DIM // 2)
    hi_dims = tl.arange(HEAD_DIM // 2, HEAD_DIM)

    # cos/sin tiles of shape [BLOCK_TOKENS, HEAD_DIM // 2], given a singleton head axis so they
    # broadcast against the [BLOCK_TOKENS, 1, HEAD_DIM // 2] q/k tiles below.
    trig_off = token_ids[:, None] * cos_token_stride + lo_dims[None, :] * cos_stride
    trig_valid = token_ids[:, None] < q_total_tokens
    cos_vals = tl.load(cos + trig_off, mask=trig_valid, other=0.0)
    sin_vals = tl.load(sin + trig_off, mask=trig_valid, other=0.0)
    cos_b = cos_vals[:, None, :]
    sin_b = sin_vals[:, None, :]

    token_col = token_ids[:, None, None]
    token_valid = token_col < q_total_tokens

    # Query head handled by this program.
    q_valid = (head_id < Q_HEAD_NUM) & token_valid
    q_base = token_col * q_token_stride + head_id * q_head_stride
    q_lo_off = q_base + lo_dims[None, None, :] * head_dim_stride
    q_hi_off = q_base + hi_dims[None, None, :] * head_dim_stride
    q_lo = tl.load(q + q_lo_off, mask=q_valid, other=0.0)
    q_hi = tl.load(q + q_hi_off, mask=q_valid, other=0.0)
    rot_q_lo = q_lo * cos_b - q_hi * sin_b
    rot_q_hi = q_lo * sin_b + q_hi * cos_b
    tl.store(q + q_lo_off, rot_q_lo, mask=q_valid)
    tl.store(q + q_hi_off, rot_q_hi, mask=q_valid)

    # Key head shared by this query-head group; only the group's first head touches it.
    if head_id % KV_GROUP_NUM == 0:
        kv_head = head_id // KV_GROUP_NUM
        k_base = token_col * k_token_stride + kv_head * k_head_stride
        k_lo_off = k_base + lo_dims[None, None, :] * head_dim_stride
        k_hi_off = k_base + hi_dims[None, None, :] * head_dim_stride
        k_lo = tl.load(k + k_lo_off, mask=token_valid, other=0.0)
        k_hi = tl.load(k + k_hi_off, mask=token_valid, other=0.0)
        rot_k_lo = k_lo * cos_b - k_hi * sin_b
        rot_k_hi = k_lo * sin_b + k_hi * cos_b
        tl.store(k + k_lo_off, rot_k_lo, mask=token_valid)
        tl.store(k + k_hi_off, rot_k_hi, mask=token_valid)
|
|
|
|
|
|
@triton.jit
def fused_rotary_embedding_kernel(
    q,
    k,
    cos,
    sin,
    kv_cache,
    BLOCK_TABLES,
    context_lengths,
    q_token_stride,
    q_head_stride,
    k_token_stride,
    k_head_stride,
    head_dim_stride,
    cos_token_stride,
    cos_stride,
    cacheb_stride,
    cacheh_stride,
    cachebs_stride,
    cached_stride,
    bts_stride,
    btb_stride,
    block_size,
    q_total_tokens,
    Q_HEAD_NUM: tl.constexpr,
    K_HEAD_NUM: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
):
    # Rotary embedding over a [BLOCK_TOKENS, BLOCK_HEAD] tile of (tokens x heads).
    # Rotates q and k in place (lo' = lo*cos - hi*sin, hi' = lo*sin + hi*cos on the two halves of
    # each head vector). It also writes the rotated key of token i into kv_cache at position
    # context_lengths[i] - 1 of sequence i, so the token index doubles as the sequence index.
    # Heads >= Q_HEAD_NUM (q) / >= K_HEAD_NUM (k, kv_cache) and tokens >= q_total_tokens are
    # masked out of every load AND store.
    block_head_index = tl.program_id(0)
    block_token_index = tl.program_id(1)

    tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
    head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)

    dim_range0 = tl.arange(0, HEAD_DIM // 2)
    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)

    token_mask = tokens_range[:, None, None] < q_total_tokens
    q_mask = (head_range[None, :, None] < Q_HEAD_NUM) & token_mask
    k_mask = (head_range[None, :, None] < K_HEAD_NUM) & token_mask
    # Same condition as k_mask, laid out to match kv_range0/kv_range1 ([tokens, heads, 1, dim]).
    # Without it, padded heads write past the cache's head dimension. Out-of-range tokens would also
    # scatter through undefined block ids, since the masked loads below have no `other`.
    cache_mask = (head_range[None, :, None, None] < K_HEAD_NUM) & (
        tokens_range[:, None, None, None] < q_total_tokens
    )

    off_q0 = (
        tokens_range[:, None, None] * q_token_stride
        + head_range[None, :, None] * q_head_stride
        + dim_range0[None, None, :] * head_dim_stride
    )
    off_q1 = (
        tokens_range[:, None, None] * q_token_stride
        + head_range[None, :, None] * q_head_stride
        + dim_range1[None, None, :] * head_dim_stride
    )
    off_k0 = (
        tokens_range[:, None, None] * k_token_stride
        + head_range[None, :, None] * k_head_stride
        + dim_range0[None, None, :] * head_dim_stride
    )
    off_k1 = (
        tokens_range[:, None, None] * k_token_stride
        + head_range[None, :, None] * k_head_stride
        + dim_range1[None, None, :] * head_dim_stride
    )

    loaded_q0 = tl.load(q + off_q0, mask=q_mask, other=0.0)
    loaded_q1 = tl.load(q + off_q1, mask=q_mask, other=0.0)
    loaded_k0 = tl.load(k + off_k0, mask=k_mask, other=0.0)
    loaded_k1 = tl.load(k + off_k1, mask=k_mask, other=0.0)

    # Row t of cos/sin is used for token t; only the first HEAD_DIM // 2 columns are read.
    off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride
    loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
    loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)

    out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :]
    out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :]

    out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
    out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]  # total_tokens, head_num, head_dim

    # Slot of the newest token: position context_lengths[i] - 1 in sequence i's block list.
    past_kv_seq_len = tl.load(context_lengths + tokens_range, mask=(tokens_range < q_total_tokens)) - 1

    last_block_idx = past_kv_seq_len // block_size
    block_table_ptr = BLOCK_TABLES + tokens_range * bts_stride
    block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(tokens_range < q_total_tokens))
    offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride

    kv_range0 = (
        block_ids[:, None, None, None] * cacheb_stride
        + head_range[None, :, None, None] * cacheh_stride
        + offsets_in_last_block[:, None, None, None]
        + dim_range0[None, None, None, :] * cached_stride
    )
    kv_range1 = (
        block_ids[:, None, None, None] * cacheb_stride
        + head_range[None, :, None, None] * cacheh_stride
        + offsets_in_last_block[:, None, None, None]
        + dim_range1[None, None, None, :] * cached_stride
    )

    tl.store(kv_cache + kv_range0, out_k0[:, :, None, :], mask=cache_mask)
    tl.store(kv_cache + kv_range1, out_k1[:, :, None, :], mask=cache_mask)

    # concat
    tl.store(q + off_q0, out_q0, mask=q_mask)
    tl.store(q + off_q1, out_q1, mask=q_mask)
    tl.store(k + off_k0, out_k0, mask=k_mask)
    tl.store(k + off_k1, out_k1, mask=k_mask)
|
|
|
|
|
|
@triton.jit
def fused_rotary_embedding_kernel_v2(
    q,
    k,
    cos,
    sin,
    kv_cache,
    BLOCK_TABLES,
    context_lengths,
    q_token_stride,
    q_head_stride,
    k_token_stride,
    k_head_stride,
    head_dim_stride,
    cos_token_stride,
    cos_stride,
    cacheb_stride,
    cacheh_stride,
    cachebs_stride,
    cached_stride,
    bts_stride,
    btb_stride,
    block_size,
    q_total_tokens,
    Q_HEAD_NUM: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    # One program per (head, token) pair (program axes 0 and 1).
    # Rotates q in place and writes the rotated k straight into kv_cache at position
    # context_lengths[token] - 1 of sequence `token`; k itself is NOT written back.
    # NOTE(review): k and kv_cache are addressed with the same head index as q, so this kernel is
    # only in-bounds when the number of key/value heads equals Q_HEAD_NUM (no grouped-query attention).
    head_id = tl.program_id(0)
    # The head axis of the launch grid may be padded past Q_HEAD_NUM; surplus programs do nothing.
    if head_id >= Q_HEAD_NUM:
        return
    token_id = tl.program_id(1)

    lo_dims = tl.arange(0, HEAD_DIM // 2)
    hi_dims = tl.arange(HEAD_DIM // 2, HEAD_DIM)

    q_base = token_id * q_token_stride + head_id * q_head_stride
    k_base = token_id * k_token_stride + head_id * k_head_stride
    q_lo_off = q_base + lo_dims * head_dim_stride
    q_hi_off = q_base + hi_dims * head_dim_stride
    k_lo_off = k_base + lo_dims * head_dim_stride
    k_hi_off = k_base + hi_dims * head_dim_stride

    q_lo = tl.load(q + q_lo_off)
    q_hi = tl.load(q + q_hi_off)
    k_lo = tl.load(k + k_lo_off)
    k_hi = tl.load(k + k_hi_off)

    # Row `token_id` of cos/sin; only the first HEAD_DIM // 2 columns are read.
    token_ok = token_id < q_total_tokens
    trig_off = token_id * cos_token_stride + lo_dims * cos_stride
    cos_vals = tl.load(cos + trig_off, mask=token_ok, other=0.0)
    sin_vals = tl.load(sin + trig_off, mask=token_ok, other=0.0)

    rot_q_lo = q_lo * cos_vals - q_hi * sin_vals
    rot_q_hi = q_lo * sin_vals + q_hi * cos_vals
    rot_k_lo = k_lo * cos_vals - k_hi * sin_vals
    rot_k_hi = k_lo * sin_vals + k_hi * cos_vals

    # Locate the cache slot of the newest token of this sequence via its block table row.
    past_len = tl.load(context_lengths + token_id) - 1
    table_row = BLOCK_TABLES + token_id * bts_stride
    block_id = tl.load(table_row + (past_len // block_size) * btb_stride, mask=token_ok)
    slot_off = (past_len % block_size) * cachebs_stride

    cache_base = block_id * cacheb_stride + head_id * cacheh_stride + slot_off
    cache_lo_off = cache_base + lo_dims * cached_stride
    cache_hi_off = cache_base + hi_dims * cached_stride
    tl.store(kv_cache + cache_lo_off, rot_k_lo)
    tl.store(kv_cache + cache_hi_off, rot_k_hi)

    # Write the rotated query back in place.
    tl.store(q + q_lo_off, rot_q_lo)
    tl.store(q + q_hi_off, rot_q_hi)
|
|
|
|
|
|
@triton.jit
def decoding_fused_rotary_embedding_kernel(
    q,
    k,
    v,
    cos,
    sin,
    k_cache,
    v_cache,
    BLOCK_TABLES,
    context_lengths,
    x,
    q_token_stride,
    q_head_stride,
    k_token_stride,
    k_head_stride,
    head_dim_stride,
    cos_token_stride,
    cos_stride,
    kcb_stride,
    kch_stride,
    kcsplit_x_stride,
    kcs_stride,
    kcd_stride,
    vcb_stride,
    vch_stride,
    vcs_stride,
    vcd_stride,
    bts_stride,
    btb_stride,
    block_size,
    KV_GROUP_NUM: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    # One program per (query head, token) pair (program axes 0 and 1).
    # Rotates the query head in place. The first query head of each group of KV_GROUP_NUM heads
    # also rotates its key head (head // KV_GROUP_NUM) and writes it, together with the untouched
    # value head, into k_cache / v_cache at position context_lengths[token] - 1. Dimension d of the
    # key goes to cache offset (d // x) * kcsplit_x_stride + (d % x) * kcd_stride. v is addressed
    # with k's token/head strides.
    head_id = tl.program_id(0)
    token_id = tl.program_id(1)

    full_dims = tl.arange(0, HEAD_DIM)
    lo_dims = tl.arange(0, HEAD_DIM // 2)
    hi_dims = tl.arange(HEAD_DIM // 2, HEAD_DIM)

    # --- query: rotate in place ---
    q_base = token_id * q_token_stride + head_id * q_head_stride
    q_lo_off = q_base + lo_dims * head_dim_stride
    q_hi_off = q_base + hi_dims * head_dim_stride

    q_lo = tl.load(q + q_lo_off)
    q_hi = tl.load(q + q_hi_off)
    trig_off = token_id * cos_token_stride + lo_dims * cos_stride
    cos_vals = tl.load(cos + trig_off)
    sin_vals = tl.load(sin + trig_off)

    rot_q_lo = q_lo * cos_vals - q_hi * sin_vals
    rot_q_hi = q_lo * sin_vals + q_hi * cos_vals
    tl.store(q + q_lo_off, rot_q_lo)
    tl.store(q + q_hi_off, rot_q_hi)

    # --- key/value: handled once per query-head group ---
    if head_id % KV_GROUP_NUM == 0:
        kv_head = head_id // KV_GROUP_NUM
        kv_base = token_id * k_token_stride + kv_head * k_head_stride
        k_lo_off = kv_base + lo_dims * head_dim_stride
        k_hi_off = kv_base + hi_dims * head_dim_stride
        k_lo = tl.load(k + k_lo_off)
        k_hi = tl.load(k + k_hi_off)

        rot_k_lo = k_lo * cos_vals - k_hi * sin_vals
        rot_k_hi = k_lo * sin_vals + k_hi * cos_vals

        # NOTE The precondition here is that it's only for unpadded inputs during decoding stage,
        # and so that we could directly use the token index as the sequence index
        past_len = tl.load(context_lengths + token_id) - 1

        block_id = tl.load(BLOCK_TABLES + token_id * bts_stride + (past_len // block_size) * btb_stride)
        slot = past_len % block_size

        k_slot_base = block_id * kcb_stride + kv_head * kch_stride + slot * kcs_stride
        k_cache_lo_off = k_slot_base + (lo_dims // x) * kcsplit_x_stride + (lo_dims % x) * kcd_stride
        k_cache_hi_off = k_slot_base + (hi_dims // x) * kcsplit_x_stride + (hi_dims % x) * kcd_stride
        tl.store(k_cache + k_cache_lo_off, rot_k_lo)
        tl.store(k_cache + k_cache_hi_off, rot_k_hi)

        v_off = kv_base + full_dims * head_dim_stride
        v_vals = tl.load(v + v_off)
        v_cache_off = block_id * vcb_stride + kv_head * vch_stride + slot * vcs_stride + full_dims * vcd_stride
        tl.store(v_cache + v_cache_off, v_vals)
|
|
|
|
|
|
def rotary_embedding(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    k_cache: Optional[torch.Tensor] = None,
    block_tables: Optional[torch.Tensor] = None,
    kv_lengths: Optional[torch.Tensor] = None,
):
    """
    Apply rotary position embedding to ``q`` and ``k`` with Triton kernels.

    Without ``k_cache``, ``q`` and ``k`` are both rotated in place.

    With ``k_cache``, ``q`` is rotated in place and ``k`` is left unchanged. The rotated key of
    token ``i`` is written into ``k_cache`` at position ``kv_lengths[i] - 1`` of sequence ``i``.
    The token index doubles as the sequence index, so every token must be the single new token of
    its own sequence.

    Args:
        q: query tensor, [total_tokens, head_num, head_dim]
        k: key tensor, [total_tokens, kv_head_num, head_dim]. head_num must be a multiple of kv_head_num.
        cos: cosine for rotary embedding. Row ``i`` is used for token ``i``; only the first
            ``head_dim // 2`` columns are read.
        sin: sine for rotary embedding, indexed the same way as ``cos``.
        k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim]
        block_tables: Block tables for each sequence. [bsz, max_blocks_per_sequence]. Required with ``k_cache``.
        kv_lengths: Past key/value sequence lengths plus current sequence length for each sequence. [bsz].
            Required with ``k_cache``.

    Raises:
        ValueError: if ``k_cache`` is given without ``block_tables`` or ``kv_lengths``.
    """
    q_total_tokens, q_head_num, head_dim = q.shape
    assert q.size(0) == k.size(0)
    BLOCK_TOKENS = 4

    if head_dim >= 512:
        num_warps = 16
    elif head_dim >= 256:
        num_warps = 8
    else:
        num_warps = 4

    k_head_num = k.size(1)
    q_token_stride, q_head_stride, head_dim_stride = q.stride()
    k_token_stride, k_head_stride, _ = k.stride()
    cos_token_stride, cos_stride = cos.stride()

    assert q_head_num % k_head_num == 0
    kv_group_num = q_head_num // k_head_num

    def _rotate_in_place(key: torch.Tensor) -> None:
        # Rotate q and `key` in place; GQA-aware (each key head is rotated exactly once).
        key_token_stride, key_head_stride, _ = key.stride()

        def grid(META):
            return (q_head_num, triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]))

        rotary_embedding_kernel[grid](
            q,
            key,
            cos,
            sin,
            q_token_stride,
            q_head_stride,
            key_token_stride,
            key_head_stride,
            head_dim_stride,
            cos_token_stride,
            cos_stride,
            q_total_tokens,
            Q_HEAD_NUM=q_head_num,
            KV_GROUP_NUM=kv_group_num,
            HEAD_DIM=head_dim,
            BLOCK_TOKENS=BLOCK_TOKENS,
            num_warps=num_warps,
        )

    if k_cache is None:
        _rotate_in_place(k)
        return

    warnings.warn("Fused rotary embedding Triton kernel will be deprecated as the new kcache layout is supported")
    if block_tables is None or kv_lengths is None:
        raise ValueError("block_tables and kv_lengths are required when k_cache is provided")

    if kv_group_num == 1:
        grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
        fused_rotary_embedding_kernel_v2[grid](
            q,
            k,
            cos,
            sin,
            k_cache,
            block_tables,
            kv_lengths,
            q_token_stride,
            q_head_stride,
            k_token_stride,
            k_head_stride,
            head_dim_stride,
            cos_token_stride,
            cos_stride,
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            block_tables.stride(0),
            block_tables.stride(1),
            k_cache.size(-2),
            q_total_tokens,
            Q_HEAD_NUM=q_head_num,
            HEAD_DIM=head_dim,
            num_warps=num_warps,
        )
        return

    # fused_rotary_embedding_kernel_v2 addresses k and k_cache with the *query* head index, which
    # reads and writes out of bounds once kv_head_num < head_num. For grouped-query attention,
    # rotate a copy of k with the GQA-aware kernel (this also rotates q in place, exactly once).
    # Then scatter the copy into the cache with plain tensor indexing, so k stays untouched as in
    # the fused path.
    k_rot = k.clone()
    _rotate_in_place(k_rot)

    block_size = k_cache.size(-2)
    token_idx = torch.arange(q_total_tokens, device=kv_lengths.device)
    past_kv_seq_len = kv_lengths[token_idx].long() - 1
    block_ids = block_tables[token_idx, past_kv_seq_len // block_size].long()
    offsets_in_last_block = past_kv_seq_len % block_size
    # Advanced indices on dims 0 and 2 are separated by a slice, so the target view is
    # [total_tokens, num_kv_heads, head_dim], the same shape as k_rot.
    k_cache[block_ids, :, offsets_in_last_block, :] = k_rot.to(k_cache.dtype)
    return
|
|
|
|
|
|
def decoding_fused_rotary_embedding(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    k_cache: Optional[torch.Tensor] = None,
    v_cache: Optional[torch.Tensor] = None,
    block_tables: Optional[torch.Tensor] = None,
    kv_lengths: Optional[torch.Tensor] = None,
    use_new_kcache_layout: bool = False,
):
    """
    Decoding-stage rotary embedding fused with the KV-cache write.

    Rotates ``q`` in place. Writes the rotated ``k`` and the unmodified ``v`` of token ``i`` into
    ``k_cache`` / ``v_cache`` at position ``kv_lengths[i] - 1`` of sequence ``i``. ``k`` itself is
    not written back. Token index is used as the sequence index, so inputs must be unpadded with
    one new token per sequence.

    Args:
        q: query tensor, [total_tokens, head_num, head_dim]
        k: key tensor, [total_tokens, kv_head_num, head_dim]. head_num must be a multiple of kv_head_num.
        v: value tensor, [total tokens, kv_head_num, head_dim]. Read with k's token/head strides and
            q's last-dim stride, so it must share k's memory layout.
        cos: cosine for rotary embedding. Row ``i`` is used for token ``i``; only the first
            ``head_dim // 2`` columns are read.
        sin: sine for rotary embedding, indexed the same way as ``cos``.
        k_cache (torch.Tensor): Blocked key cache. [num_blocks, kv_head_num, block_size, head_dim], or
            [num_blocks, kv_head_num, head_dim // x, block_size, x] when ``use_new_kcache_layout`` is True.
        v_cache (torch.Tensor): Blocked value cache. [num_blocks, kv_head_num, block_size, head_dim]
        block_tables: Block tables for each sequence. [bsz, max_blocks_per_sequence]
        kv_lengths: Past key/value sequence lengths plus current sequence length for each sequence. [bsz]
        use_new_kcache_layout: Whether ``k_cache`` uses the 5-D split-``x`` layout described above.
    """
    q_total_tokens, q_head_num, head_dim = q.shape
    assert q.size(0) == k.size(0) == v.size(0)

    if head_dim >= 512:
        num_warps = 16
    elif head_dim >= 256:
        num_warps = 8
    else:
        num_warps = 4
    k_head_num = k.size(1)
    # The kernel maps query head h to kv head h // kv_group_num, so the division must be exact;
    # otherwise query heads would silently be paired with the wrong key/value heads.
    assert (
        q_head_num % k_head_num == 0
    ), f"head_num ({q_head_num}) must be a multiple of kv_head_num ({k_head_num})"
    kv_group_num = q_head_num // k_head_num

    # For KCache and VCache with the same layout: x == head_dim makes (d // x) == 0 and
    # (d % x) == d, so the split-x stride is never used.
    x = head_dim
    kcsplit_x_stride, kcs_stride, kcd_stride = 0, k_cache.stride(2), k_cache.stride(3)
    # For KCache layout [num_blocks, num_kv_heads, head_dim//x, block_size, x]
    if use_new_kcache_layout:
        assert (
            k_cache.dim() == 5
            and k_cache.shape[1] == v_cache.shape[1]
            and k_cache.shape[2] * k_cache.shape[4] == v_cache.shape[3]
        ), f"Invalid KCache shape {k_cache.shape} and VCache shape {v_cache.shape}"
        x = k_cache.size(-1)
        kcsplit_x_stride, kcs_stride, kcd_stride = k_cache.stride()[-3:]

    grid = (q_head_num, q_total_tokens)
    decoding_fused_rotary_embedding_kernel[grid](
        q,
        k,
        v,
        cos,
        sin,
        k_cache,
        v_cache,
        block_tables,
        kv_lengths,
        x,
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        q.stride(2),
        cos.stride(0),
        cos.stride(1),
        k_cache.stride(0),
        k_cache.stride(1),
        kcsplit_x_stride,
        kcs_stride,
        kcd_stride,
        v_cache.stride(0),
        v_cache.stride(1),
        v_cache.stride(2),
        v_cache.stride(3),
        block_tables.stride(0),
        block_tables.stride(1),
        k_cache.size(-2),  # block_size: dim -2 in both supported k_cache layouts
        KV_GROUP_NUM=kv_group_num,
        HEAD_DIM=head_dim,
        num_warps=num_warps,
    )
    return
|