mirror of https://github.com/hpcaitech/ColossalAI
[Infer] Revise and Adapt Triton Kernels for Spec-Dec (#5401)
* [Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399) fix dependency in pytest * resolve conflicts for revising flash-attn * adapt kv cache copy kernel for spec-dec * fix seqlen-n kvcache copy kernel/tests * test kvcache copy - use torch.equal * add assertions * (trivial) comment outfeat/speculative-decoding
parent
d56c96334e
commit
d63c469f45
|
@ -11,7 +11,7 @@ if HAS_TRITON:
|
||||||
from .context_attn_unpad import context_attention_unpadded
|
from .context_attn_unpad import context_attention_unpadded
|
||||||
from .flash_decoding import flash_decoding_attention
|
from .flash_decoding import flash_decoding_attention
|
||||||
from .fused_rotary_embedding import fused_rotary_embedding
|
from .fused_rotary_embedding import fused_rotary_embedding
|
||||||
from .kvcache_copy import copy_kv_to_blocked_cache
|
from .kvcache_copy import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
|
||||||
from .no_pad_rotary_embedding import decoding_fused_rotary_embedding, rotary_embedding
|
from .no_pad_rotary_embedding import decoding_fused_rotary_embedding, rotary_embedding
|
||||||
from .rms_layernorm import rms_layernorm
|
from .rms_layernorm import rms_layernorm
|
||||||
from .rotary_cache_copy import get_xine_cache
|
from .rotary_cache_copy import get_xine_cache
|
||||||
|
@ -20,6 +20,7 @@ if HAS_TRITON:
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"context_attention_unpadded",
|
"context_attention_unpadded",
|
||||||
"flash_decoding_attention",
|
"flash_decoding_attention",
|
||||||
|
"copy_k_to_blocked_cache",
|
||||||
"copy_kv_to_blocked_cache",
|
"copy_kv_to_blocked_cache",
|
||||||
"softmax",
|
"softmax",
|
||||||
"rms_layernorm",
|
"rms_layernorm",
|
||||||
|
|
|
@ -9,13 +9,14 @@ import triton.language as tl
|
||||||
# Triton 2.1.0
|
# Triton 2.1.0
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _flash_decoding_fwd_kernel(
|
def _flash_decoding_fwd_kernel(
|
||||||
Q, # [batch_size, head_num, q_len(1), head_dim]
|
Q, # [batch_size * q_len, head_num, head_dim]
|
||||||
KCache, # [num_blocks, num_kv_heads, block_size, head_dim]
|
KCache, # [num_blocks, num_kv_heads, block_size, head_dim]
|
||||||
VCache, # [num_blocks, num_kv_heads, block_size, head_dim]
|
VCache, # [num_blocks, num_kv_heads, block_size, head_dim]
|
||||||
block_tables, # [batch_size, max_blocks_per_sequence]
|
block_tables, # [batch_size, max_blocks_per_sequence]
|
||||||
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
|
mid_o, # [batch_size * q_len, head_num, kv_split_num, head_dim]
|
||||||
mid_o_lse, # [batch_size, head_num, kv_split_num]
|
mid_o_lse, # [batch_size * q_len, head_num, kv_split_num]
|
||||||
kv_seq_len, # [batch_size]
|
kv_seq_len, # [batch_size]
|
||||||
|
q_len,
|
||||||
batch_size,
|
batch_size,
|
||||||
stride_qt,
|
stride_qt,
|
||||||
stride_qh,
|
stride_qh,
|
||||||
|
@ -39,44 +40,37 @@ def _flash_decoding_fwd_kernel(
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
HEAD_DIM: tl.constexpr,
|
HEAD_DIM: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_seq_idx = tl.program_id(0)
|
cur_token_idx = tl.program_id(0)
|
||||||
|
cur_seq_idx = cur_token_idx // q_len
|
||||||
if cur_seq_idx >= batch_size:
|
if cur_seq_idx >= batch_size:
|
||||||
return
|
return
|
||||||
cur_head_idx = tl.program_id(1)
|
cur_head_idx = tl.program_id(1)
|
||||||
block_start_kv = tl.program_id(2) # for splitting k/v
|
block_start_kv = tl.program_id(2) # for splitting k/v
|
||||||
|
|
||||||
cur_kv_head_idx = cur_head_idx // KV_GROUPS
|
|
||||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
|
||||||
|
|
||||||
# NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same
|
# NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same
|
||||||
# TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE)
|
# TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE)
|
||||||
# and then support calculating multiple kv cache blocks on an instance
|
# and then support calculating multiple kv cache blocks on an instance
|
||||||
tl.static_assert(BLOCK_KV == BLOCK_SIZE)
|
tl.static_assert(BLOCK_KV == BLOCK_SIZE)
|
||||||
|
# get the current (kv) sequence length
|
||||||
# get the current (kv) sequence length from provided context lengths tensor
|
|
||||||
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
|
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
|
||||||
|
|
||||||
offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
|
|
||||||
q = tl.load(Q + offsets_q)
|
|
||||||
|
|
||||||
# block table for the current sequence
|
|
||||||
block_table_ptr = block_tables + cur_seq_idx * stride_bts
|
|
||||||
|
|
||||||
# actually current block table current block start idx
|
|
||||||
# cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
|
|
||||||
cur_bt_start_idx = block_start_kv
|
|
||||||
cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
|
|
||||||
|
|
||||||
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
|
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||||
|
offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
|
||||||
|
q = tl.load(Q + offsets_q)
|
||||||
|
# block table for the current sequence
|
||||||
|
block_table_ptr = block_tables + cur_seq_idx * stride_bts
|
||||||
|
# cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
|
||||||
|
# cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
|
||||||
|
cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb)
|
||||||
cur_occupied_size = tl.where(
|
cur_occupied_size = tl.where(
|
||||||
(block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE
|
(block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE
|
||||||
)
|
)
|
||||||
tl.device_assert(cur_occupied_size >= 0)
|
tl.device_assert(cur_occupied_size >= 0)
|
||||||
|
|
||||||
|
cur_kv_head_idx = cur_head_idx // KV_GROUPS
|
||||||
offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
|
offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
|
||||||
|
|
||||||
K_block_ptr = tl.make_block_ptr(
|
K_block_ptr = tl.make_block_ptr(
|
||||||
base=KCache + offset_kvcache,
|
base=KCache + offset_kvcache,
|
||||||
shape=(cur_occupied_size, HEAD_DIM),
|
shape=(cur_occupied_size, HEAD_DIM),
|
||||||
|
@ -115,14 +109,14 @@ def _flash_decoding_fwd_kernel(
|
||||||
acc = acc / l
|
acc = acc / l
|
||||||
|
|
||||||
offsets_mid_o = (
|
offsets_mid_o = (
|
||||||
cur_seq_idx * stride_mid_ot
|
cur_token_idx * stride_mid_ot
|
||||||
+ cur_head_idx * stride_mid_oh
|
+ cur_head_idx * stride_mid_oh
|
||||||
+ block_start_kv * stride_mid_ob
|
+ block_start_kv * stride_mid_ob
|
||||||
+ offsets_dmodel * stride_mid_od
|
+ offsets_dmodel * stride_mid_od
|
||||||
)
|
)
|
||||||
tl.store(mid_o + offsets_mid_o, acc)
|
tl.store(mid_o + offsets_mid_o, acc)
|
||||||
offsets_mid_o_lse = (
|
offsets_mid_o_lse = (
|
||||||
cur_seq_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
|
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
|
||||||
)
|
)
|
||||||
# logsumexp L^(j) = m^(j) + log(l^(j))
|
# logsumexp L^(j) = m^(j) + log(l^(j))
|
||||||
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
|
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
|
||||||
|
@ -135,6 +129,7 @@ def _flash_decoding_fwd_reduce_kernel(
|
||||||
mid_o_lse, # [batch_size, head_num, kv_split_num]
|
mid_o_lse, # [batch_size, head_num, kv_split_num]
|
||||||
O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
|
O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
|
||||||
kv_seq_len,
|
kv_seq_len,
|
||||||
|
q_len,
|
||||||
batch_size,
|
batch_size,
|
||||||
stride_mid_ot,
|
stride_mid_ot,
|
||||||
stride_mid_oh,
|
stride_mid_oh,
|
||||||
|
@ -149,7 +144,8 @@ def _flash_decoding_fwd_reduce_kernel(
|
||||||
BLOCK_KV: tl.constexpr,
|
BLOCK_KV: tl.constexpr,
|
||||||
HEAD_DIM: tl.constexpr,
|
HEAD_DIM: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_seq_idx = tl.program_id(0)
|
cur_token_idx = tl.program_id(0)
|
||||||
|
cur_seq_idx = cur_token_idx // q_len
|
||||||
if cur_seq_idx >= batch_size:
|
if cur_seq_idx >= batch_size:
|
||||||
return
|
return
|
||||||
cur_head_idx = tl.program_id(1)
|
cur_head_idx = tl.program_id(1)
|
||||||
|
@ -164,8 +160,8 @@ def _flash_decoding_fwd_reduce_kernel(
|
||||||
l = 0.0 # sum exp
|
l = 0.0 # sum exp
|
||||||
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
|
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
|
||||||
|
|
||||||
offsets_mid_o = cur_seq_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
|
offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
|
||||||
offset_mid_lse = cur_seq_idx * stride_o_lset + cur_head_idx * stride_o_lseh
|
offset_mid_lse = cur_token_idx * stride_o_lset + cur_head_idx * stride_o_lseh
|
||||||
for block_i in range(0, kv_split_num, 1):
|
for block_i in range(0, kv_split_num, 1):
|
||||||
mid_o_block = tl.load(mid_o + offsets_mid_o + block_i * stride_mid_ob)
|
mid_o_block = tl.load(mid_o + offsets_mid_o + block_i * stride_mid_ob)
|
||||||
lse = tl.load(mid_o_lse + offset_mid_lse + block_i * stride_o_lseb)
|
lse = tl.load(mid_o_lse + offset_mid_lse + block_i * stride_o_lseb)
|
||||||
|
@ -179,7 +175,7 @@ def _flash_decoding_fwd_reduce_kernel(
|
||||||
m_i = m_ij
|
m_i = m_ij
|
||||||
|
|
||||||
acc = acc / l
|
acc = acc / l
|
||||||
offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
|
offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
|
||||||
tl.store(O + offsets_O, acc.to(O.type.element_ty))
|
tl.store(O + offsets_O, acc.to(O.type.element_ty))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -199,12 +195,14 @@ def flash_decoding_attention(
|
||||||
mid_output_lse: torch.Tensor = None,
|
mid_output_lse: torch.Tensor = None,
|
||||||
sm_scale: int = None,
|
sm_scale: int = None,
|
||||||
kv_group_num: int = 1,
|
kv_group_num: int = 1,
|
||||||
|
q_len: int = 1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
|
Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
q (torch.Tensor): [bsz, num_heads, head_dim]
|
q (torch.Tensor): [bsz * q_len, num_heads, head_dim]
|
||||||
|
q_len > 1 only for verification process in speculative-decoding.
|
||||||
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
|
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
|
||||||
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
|
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
|
||||||
kv_seq_len (torch.Tensor): [batch_size]
|
kv_seq_len (torch.Tensor): [batch_size]
|
||||||
|
@ -212,19 +210,25 @@ def flash_decoding_attention(
|
||||||
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
|
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
|
||||||
max_seq_len_in_batch (int): Maximum sequence length in the batch.
|
max_seq_len_in_batch (int): Maximum sequence length in the batch.
|
||||||
output (torch.Tensor): [bsz, num_heads * head_dim]
|
output (torch.Tensor): [bsz, num_heads * head_dim]
|
||||||
mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
|
mid_output (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num, head_dim]
|
||||||
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
|
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
|
||||||
mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
|
q_len > 1 only for verification process in speculative-decoding.
|
||||||
|
mid_output_lse (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num]
|
||||||
Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
|
Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
|
||||||
|
q_len > 1 only for verification process in speculative-decoding.
|
||||||
block_size (int): Size of each block in the blocked key/value cache.
|
block_size (int): Size of each block in the blocked key/value cache.
|
||||||
num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
|
num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
|
||||||
|
q_length (int): Query length. Use for speculative decoding when `q_length` > 1 (i.e. the last n tokens).
|
||||||
|
Defaults to 1.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Output tensor with shape [bsz, num_heads * head_dim]
|
Output tensor with shape [bsz * q_len, num_heads * head_dim]
|
||||||
"""
|
"""
|
||||||
q = q.squeeze() if q.dim() == 4 else q
|
q = q.squeeze() if q.dim() == 4 else q
|
||||||
assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
|
assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
|
||||||
bsz, num_heads, head_dim = q.shape
|
n_tokens, num_heads, head_dim = q.shape
|
||||||
|
assert n_tokens % q_len == 0, "Invalid q_len"
|
||||||
|
bsz = n_tokens // q_len
|
||||||
|
|
||||||
assert head_dim in {32, 64, 128, 256}
|
assert head_dim in {32, 64, 128, 256}
|
||||||
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
|
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
|
||||||
|
@ -247,22 +251,31 @@ def flash_decoding_attention(
|
||||||
max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch
|
max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch
|
||||||
# For compatibility (TODO revise modeling in future)
|
# For compatibility (TODO revise modeling in future)
|
||||||
kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV
|
kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV
|
||||||
mid_output = (
|
|
||||||
torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
|
if mid_output is None:
|
||||||
if mid_output is None
|
mid_output = torch.empty(
|
||||||
else mid_output
|
(bsz * q_len, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
|
||||||
)
|
|
||||||
mid_output_lse = (
|
|
||||||
torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
|
|
||||||
if mid_output_lse is None
|
|
||||||
else mid_output_lse
|
|
||||||
)
|
)
|
||||||
|
if mid_output_lse is None:
|
||||||
|
mid_output_lse = torch.empty((bsz * q_len, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
|
||||||
|
if output is None:
|
||||||
|
# A hack to prevent `view` operation in modeling
|
||||||
|
output = torch.empty((bsz * q_len, num_heads * head_dim), dtype=q.dtype, device=q.device)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
mid_output.size(2) == mid_output_lse.size(2) >= kv_max_split_num
|
||||||
|
), "Incompatible kv split number of intermediate output tensors"
|
||||||
|
assert (
|
||||||
|
mid_output.size(0) == mid_output_lse.size(0) >= output.size(0) == n_tokens
|
||||||
|
), f"Incompatible first dimension of output tensors"
|
||||||
|
|
||||||
# NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
|
# NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
|
||||||
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
|
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
|
||||||
grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
|
grid = (
|
||||||
output = torch.empty((bsz, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output
|
triton.next_power_of_2(bsz * q_len),
|
||||||
|
num_heads,
|
||||||
|
triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
|
||||||
|
)
|
||||||
_flash_decoding_fwd_kernel[grid](
|
_flash_decoding_fwd_kernel[grid](
|
||||||
q,
|
q,
|
||||||
k_cache,
|
k_cache,
|
||||||
|
@ -271,6 +284,7 @@ def flash_decoding_attention(
|
||||||
mid_output,
|
mid_output,
|
||||||
mid_output_lse,
|
mid_output_lse,
|
||||||
kv_seq_len,
|
kv_seq_len,
|
||||||
|
q_len,
|
||||||
bsz,
|
bsz,
|
||||||
q.stride(0),
|
q.stride(0),
|
||||||
q.stride(1),
|
q.stride(1),
|
||||||
|
@ -295,13 +309,13 @@ def flash_decoding_attention(
|
||||||
HEAD_DIM=head_dim,
|
HEAD_DIM=head_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
grid = (triton.next_power_of_2(bsz), num_heads)
|
grid = (triton.next_power_of_2(bsz * q_len), num_heads)
|
||||||
|
|
||||||
_flash_decoding_fwd_reduce_kernel[grid](
|
_flash_decoding_fwd_reduce_kernel[grid](
|
||||||
mid_output,
|
mid_output,
|
||||||
mid_output_lse,
|
mid_output_lse,
|
||||||
output,
|
output,
|
||||||
kv_seq_len,
|
kv_seq_len,
|
||||||
|
q_len,
|
||||||
bsz,
|
bsz,
|
||||||
mid_output.stride(0),
|
mid_output.stride(0),
|
||||||
mid_output.stride(1),
|
mid_output.stride(1),
|
||||||
|
|
|
@ -3,6 +3,50 @@ import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
# Triton 2.1.0
|
||||||
|
@triton.jit
|
||||||
|
def _copy_to_kcache_seqlen_n_kernel(
|
||||||
|
KV, # K or V
|
||||||
|
KVCache, # KCache or VCache
|
||||||
|
BLOCK_TABLES,
|
||||||
|
context_lengths,
|
||||||
|
stride_kt,
|
||||||
|
stride_kh,
|
||||||
|
stride_kd,
|
||||||
|
stride_cacheb,
|
||||||
|
stride_cacheh,
|
||||||
|
stride_cachebs,
|
||||||
|
stride_cached,
|
||||||
|
stride_bts,
|
||||||
|
stride_btb,
|
||||||
|
block_size,
|
||||||
|
n,
|
||||||
|
HEAD_DIM: tl.constexpr,
|
||||||
|
):
|
||||||
|
cur_token_idx = tl.program_id(0)
|
||||||
|
cur_seq_idx = cur_token_idx // n
|
||||||
|
cur_token_shift = cur_token_idx - (n * (cur_seq_idx + 1))
|
||||||
|
# cur_token_shift = cur_token_idx - n * cur_seq_idx
|
||||||
|
cur_kv_head_idx = tl.program_id(1)
|
||||||
|
|
||||||
|
past_kv_seq_len = tl.load(context_lengths + cur_seq_idx) + cur_token_shift
|
||||||
|
last_bt_block_idx = past_kv_seq_len // block_size
|
||||||
|
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
|
||||||
|
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
|
||||||
|
offset_last_block = past_kv_seq_len % block_size
|
||||||
|
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||||
|
offsets_kv = cur_token_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
||||||
|
kv = tl.load(KV + offsets_kv)
|
||||||
|
offsets_kvcache = (
|
||||||
|
block_id * stride_cacheb
|
||||||
|
+ cur_kv_head_idx * stride_cacheh
|
||||||
|
+ offset_last_block * stride_cachebs
|
||||||
|
+ offsets_dmodel * stride_cached
|
||||||
|
)
|
||||||
|
tl.store(KVCache + offsets_kvcache, kv)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
# Triton 2.1.0
|
# Triton 2.1.0
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _copy_to_kvcache_seqlen1_kernel(
|
def _copy_to_kvcache_seqlen1_kernel(
|
||||||
|
@ -40,10 +84,11 @@ def _copy_to_kvcache_seqlen1_kernel(
|
||||||
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
|
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
|
||||||
offsets_in_last_block = past_kv_seq_len % block_size
|
offsets_in_last_block = past_kv_seq_len % block_size
|
||||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||||
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
offsets_k = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
||||||
|
offsets_v = cur_seq_idx * stride_vt + cur_kv_head_idx * stride_vh + offsets_dmodel * stride_vd
|
||||||
|
|
||||||
k = tl.load(K + offsets_kv)
|
k = tl.load(K + offsets_k)
|
||||||
v = tl.load(V + offsets_kv)
|
v = tl.load(V + offsets_v)
|
||||||
|
|
||||||
offsets_kcache = (
|
offsets_kcache = (
|
||||||
block_id * stride_cachekb
|
block_id * stride_cachekb
|
||||||
|
@ -63,6 +108,64 @@ def _copy_to_kvcache_seqlen1_kernel(
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def copy_k_to_blocked_cache(
|
||||||
|
k: torch.Tensor, k_cache: torch.Tensor, kv_lengths: torch.Tensor, block_tables: torch.Tensor, n: int = 1
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Copy keys or values to the blocked key/value cache during decoding stage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
|
||||||
|
[bsz * n, num_kv_heads, head_dim] - Keys or values with seq len n
|
||||||
|
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
|
||||||
|
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
||||||
|
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
||||||
|
n (int): Number of tokens to copy for each sequence. Default to 1.
|
||||||
|
"""
|
||||||
|
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
|
||||||
|
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||||
|
|
||||||
|
k = k.reshape(-1, k.size(-2), k.size(-1)) if k.dim() == 4 else k
|
||||||
|
assert k.dim() == 3, f"Invalid k dim {k.dim()}"
|
||||||
|
bsz, num_kv_heads, head_dim = k.shape
|
||||||
|
# NOTE when n > 1, the shape of k is [bsz * n, num_kv_heads, head_dim]
|
||||||
|
if n > 1:
|
||||||
|
assert bsz % n == 0, "Each sequence should have the same number of tokens to be copied"
|
||||||
|
bsz = bsz // n
|
||||||
|
|
||||||
|
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
||||||
|
f"Got incompatible batch size (number of seqs):\n"
|
||||||
|
f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; "
|
||||||
|
f" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Modify if the shape of kv cahce is changed.
|
||||||
|
block_size = k_cache.size(-2)
|
||||||
|
|
||||||
|
num_warps = 8 if head_dim > 128 else 4
|
||||||
|
|
||||||
|
grid = (bsz * n, num_kv_heads)
|
||||||
|
_copy_to_kcache_seqlen_n_kernel[grid](
|
||||||
|
k,
|
||||||
|
k_cache,
|
||||||
|
block_tables,
|
||||||
|
kv_lengths,
|
||||||
|
k.stride(0),
|
||||||
|
k.stride(1),
|
||||||
|
k.stride(2),
|
||||||
|
k_cache.stride(0),
|
||||||
|
k_cache.stride(1),
|
||||||
|
k_cache.stride(2),
|
||||||
|
k_cache.stride(3),
|
||||||
|
block_tables.stride(0),
|
||||||
|
block_tables.stride(1),
|
||||||
|
block_size,
|
||||||
|
n=n,
|
||||||
|
HEAD_DIM=head_dim,
|
||||||
|
num_warps=num_warps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def copy_kv_to_blocked_cache(
|
def copy_kv_to_blocked_cache(
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
|
|
|
@ -19,12 +19,12 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)
|
return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)
|
||||||
|
|
||||||
|
|
||||||
def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, kv_seq_len: int, device="cuda"):
|
def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"):
|
||||||
padding_mask = torch.zeros((bsz, 1, 1, kv_seq_len), dtype=torch.float32, device=device)
|
padding_mask = torch.zeros((bsz, 1, q_len, kv_len), dtype=torch.float32, device=device)
|
||||||
for i in range(bsz):
|
for i in range(bsz):
|
||||||
cur_seq_len = kv_lengths[i].item()
|
cur_seq_len = kv_lengths[i].item()
|
||||||
assert cur_seq_len <= kv_seq_len
|
assert cur_seq_len <= kv_len
|
||||||
padding_mask[i, :, :, : kv_seq_len - cur_seq_len] = float("-inf")
|
padding_mask[i, :, :, : kv_len - cur_seq_len] = float("-inf")
|
||||||
return padding_mask
|
return padding_mask
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,12 +33,12 @@ def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, kv_seq_len: int, de
|
||||||
# https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350
|
# https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350
|
||||||
def torch_attn_ref(
|
def torch_attn_ref(
|
||||||
q: torch.Tensor, # [bsz, num_heads, q_len, head_dim]
|
q: torch.Tensor, # [bsz, num_heads, q_len, head_dim]
|
||||||
k: torch.Tensor, # [bsz, num_heads, kv_seq_len, head_dim]
|
k: torch.Tensor, # [bsz, num_heads, kv_len, head_dim]
|
||||||
v: torch.Tensor, # [bsz, num_heads, kv_seq_len, head_dim]
|
v: torch.Tensor, # [bsz, num_heads, kv_len, head_dim]
|
||||||
attention_mask: torch.Tensor, # [bsz, 1, seq_len, kv_seq_len]
|
attention_mask: torch.Tensor, # [bsz, 1, q_len, kv_len]
|
||||||
bsz: int,
|
bsz: int,
|
||||||
seq_len: int,
|
q_len: int,
|
||||||
kv_seq_len: int,
|
kv_len: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
head_dim: int,
|
head_dim: int,
|
||||||
|
@ -54,22 +54,22 @@ def torch_attn_ref(
|
||||||
|
|
||||||
qk = torch.matmul(q, k.transpose(2, 3))
|
qk = torch.matmul(q, k.transpose(2, 3))
|
||||||
attn_scores = qk / (head_dim**0.5)
|
attn_scores = qk / (head_dim**0.5)
|
||||||
assert attn_scores.shape == (bsz, num_heads, seq_len, kv_seq_len), "Invalid shape of attention scores"
|
|
||||||
|
assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores"
|
||||||
# for left-side padding
|
# for left-side padding
|
||||||
if attention_mask.size() != (bsz, 1, seq_len, kv_seq_len):
|
if attention_mask.size() != (bsz, 1, q_len, kv_len):
|
||||||
raise ValueError(
|
raise ValueError(f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}")
|
||||||
f"Attention mask should be of size {(bsz, 1, seq_len, kv_seq_len)}, but is {attention_mask.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_scores = attn_scores + attention_mask
|
attn_scores = attn_scores + attention_mask
|
||||||
attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype)
|
attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype)
|
||||||
out = torch.matmul(attn_weights, v)
|
out = torch.matmul(attn_weights, v)
|
||||||
if out.size() != (bsz, num_heads, seq_len, head_dim):
|
if out.size() != (bsz, num_heads, q_len, head_dim):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is" f" {out.size()}"
|
f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)}, but is" f" {out.size()}"
|
||||||
)
|
)
|
||||||
out = out.transpose(1, 2).contiguous()
|
out = out.transpose(1, 2).contiguous()
|
||||||
out = out.squeeze(1)
|
out = out.view(-1, out.size(-2), out.size(-1))
|
||||||
|
# out [bsz * q_len, num_heads, head_dim]
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,6 @@ except ImportError:
|
||||||
|
|
||||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||||
|
|
||||||
Q_LEN = 1
|
|
||||||
HEAD_DIM = 128
|
HEAD_DIM = 128
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,6 +63,7 @@ def prepare_data(
|
||||||
@pytest.mark.parametrize("num_attn_heads", [16])
|
@pytest.mark.parametrize("num_attn_heads", [16])
|
||||||
@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
|
@pytest.mark.parametrize("kv_group_num", [1, 2, 16])
|
||||||
@pytest.mark.parametrize("same_context_len", [True, False])
|
@pytest.mark.parametrize("same_context_len", [True, False])
|
||||||
|
@pytest.mark.parametrize("q_len", [1, 5])
|
||||||
def test_flash_decoding(
|
def test_flash_decoding(
|
||||||
bsz: int,
|
bsz: int,
|
||||||
block_size: int,
|
block_size: int,
|
||||||
|
@ -71,6 +71,7 @@ def test_flash_decoding(
|
||||||
num_attn_heads: int,
|
num_attn_heads: int,
|
||||||
kv_group_num: int,
|
kv_group_num: int,
|
||||||
same_context_len: bool,
|
same_context_len: bool,
|
||||||
|
q_len: int,
|
||||||
):
|
):
|
||||||
torch.manual_seed(123)
|
torch.manual_seed(123)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
@ -82,47 +83,57 @@ def test_flash_decoding(
|
||||||
max_seq_len = block_size * max_num_blocks_per_seq
|
max_seq_len = block_size * max_num_blocks_per_seq
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
device = get_current_device()
|
device = get_current_device()
|
||||||
|
q, k_unpad, v_unpad, kv_lengths = prepare_data(
|
||||||
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
|
bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, q_len, max_seq_len, dtype, device
|
||||||
bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
|
|
||||||
)
|
)
|
||||||
|
# The maximum sequence length in the batch (if context lengths randomly generated)
|
||||||
|
max_kv_len_in_b = kv_lengths.max().item()
|
||||||
|
|
||||||
|
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_kv_len_in_b)
|
||||||
|
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_kv_len_in_b)
|
||||||
|
torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device)
|
||||||
|
out_torch = torch_attn_ref(
|
||||||
|
q, k_torch, v_torch, torch_padding_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
|
||||||
|
)
|
||||||
|
|
||||||
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
|
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
|
||||||
k_unpad, v_unpad, kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||||
)
|
)
|
||||||
block_tables = block_tables.to(device=device)
|
block_tables = block_tables.to(device=device)
|
||||||
# The maximum sequence length in the batch (if context lengths randomly generated)
|
|
||||||
max_seq_len_in_b = kv_seq_lengths.max().item()
|
|
||||||
# The maximum block length splitted on kv should be the kv cache block size
|
# The maximum block length splitted on kv should be the kv cache block size
|
||||||
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
|
kv_max_split_num = (max_kv_len_in_b + block_size - 1) // block_size
|
||||||
output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
|
output = torch.empty((bsz * q_len, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
|
||||||
mid_output = torch.empty(
|
mid_output = torch.empty(
|
||||||
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
|
size=(bsz * q_len, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
|
||||||
|
)
|
||||||
|
mid_output_lse = torch.empty(
|
||||||
|
size=(bsz * q_len, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device
|
||||||
)
|
)
|
||||||
mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
|
|
||||||
sm_scale = 1.0 / (HEAD_DIM**0.5)
|
sm_scale = 1.0 / (HEAD_DIM**0.5)
|
||||||
|
# Here we use different methods to hide the q_len dimension,
|
||||||
|
# refer to attention forward function in modeling.
|
||||||
|
if q_len > 1:
|
||||||
|
q = q.transpose(1, 2).contiguous() # [bsz, q_len, num_heads, head_dim]
|
||||||
|
q = q.view(-1, q.size(-2), q.size(-1)) # [bsz * q_len, num_heads, head_dim]
|
||||||
|
else:
|
||||||
|
q = q.squeeze(2)
|
||||||
|
assert q.shape == (bsz * q_len, num_attn_heads, HEAD_DIM)
|
||||||
|
|
||||||
out_triton = flash_decoding_attention(
|
out_triton = flash_decoding_attention(
|
||||||
# Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1),
|
q,
|
||||||
# refer to attention forward in modeling.
|
|
||||||
q.squeeze(2),
|
|
||||||
k_cache,
|
k_cache,
|
||||||
v_cache,
|
v_cache,
|
||||||
kv_seq_lengths,
|
kv_lengths,
|
||||||
block_tables,
|
block_tables,
|
||||||
block_size,
|
block_size,
|
||||||
max_seq_len_in_b,
|
max_kv_len_in_b,
|
||||||
output,
|
output,
|
||||||
mid_output,
|
mid_output,
|
||||||
mid_output_lse,
|
mid_output_lse,
|
||||||
sm_scale=sm_scale,
|
sm_scale=sm_scale,
|
||||||
kv_group_num=kv_group_num,
|
kv_group_num=kv_group_num,
|
||||||
) # [bsz, 1, num_heads, head_dim]
|
q_len=q_len,
|
||||||
|
) # [bsz * q_len, num_heads, head_dim]
|
||||||
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, bsz, max_seq_len_in_b)
|
|
||||||
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, bsz, max_seq_len_in_b)
|
|
||||||
torch_padding_mask = prepare_padding_mask(kv_seq_lengths, bsz, max_seq_len_in_b, q.device)
|
|
||||||
out_torch = torch_attn_ref(
|
|
||||||
q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
|
|
||||||
)
|
|
||||||
|
|
||||||
assert out_torch.shape == out_triton.shape
|
assert out_torch.shape == out_triton.shape
|
||||||
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
|
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
|
||||||
|
|
|
@ -2,7 +2,8 @@ import pytest
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from colossalai.kernel.triton import copy_kv_to_blocked_cache
|
from colossalai.inference.modeling.layers.attention import copy_to_cache
|
||||||
|
from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
|
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
|
||||||
|
|
||||||
|
@ -16,7 +17,7 @@ except ImportError:
|
||||||
|
|
||||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||||
|
|
||||||
HEAD_DIM = 128
|
HEAD_DIM = 32
|
||||||
|
|
||||||
|
|
||||||
def prepare_data(
|
def prepare_data(
|
||||||
|
@ -27,15 +28,16 @@ def prepare_data(
|
||||||
max_num_blocks_per_seq,
|
max_num_blocks_per_seq,
|
||||||
same_context_len,
|
same_context_len,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
|
n,
|
||||||
device,
|
device,
|
||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
):
|
):
|
||||||
# past_kv_seq_lengths in this test records the previous kv seq len
|
assert max_seq_len > n, "max_seq_len must be greater than n"
|
||||||
# (not incorporating the current input whose seq len is 1)
|
|
||||||
past_kv_seq_lengths = (
|
past_kv_seq_lengths = (
|
||||||
torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
|
torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device)
|
||||||
if same_context_len
|
if same_context_len
|
||||||
else torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
|
else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device)
|
||||||
)
|
)
|
||||||
num_tokens = torch.sum(past_kv_seq_lengths).item()
|
num_tokens = torch.sum(past_kv_seq_lengths).item()
|
||||||
|
|
||||||
|
@ -48,14 +50,14 @@ def prepare_data(
|
||||||
)
|
)
|
||||||
block_tables = block_tables.to(device=device)
|
block_tables = block_tables.to(device=device)
|
||||||
|
|
||||||
new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
|
new_k = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device)
|
||||||
new_v = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
|
new_v = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device)
|
||||||
# mock allocating blocks for the new k/v and update block tables
|
# mock allocating blocks for the new k/v and update block tables
|
||||||
|
for _ in range(n):
|
||||||
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
||||||
# kv seq len = past kv seq len + seq len (1 during decoding stage)
|
past_kv_seq_lengths += 1
|
||||||
kv_seq_lengths = past_kv_seq_lengths + 1
|
|
||||||
|
|
||||||
return new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables
|
return new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
||||||
|
@ -64,12 +66,9 @@ def prepare_data(
|
||||||
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
|
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
|
||||||
@pytest.mark.parametrize("num_kv_heads", [16])
|
@pytest.mark.parametrize("num_kv_heads", [16])
|
||||||
@pytest.mark.parametrize("same_context_len", [True, False])
|
@pytest.mark.parametrize("same_context_len", [True, False])
|
||||||
|
@pytest.mark.parametrize("n_tokens", [1, 5])
|
||||||
def test_copy_kv_to_caches(
|
def test_copy_kv_to_caches(
|
||||||
bsz: int,
|
bsz: int, block_size: int, max_num_blocks_per_seq: int, num_kv_heads: int, same_context_len: bool, n_tokens: int
|
||||||
block_size: int,
|
|
||||||
max_num_blocks_per_seq: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
same_context_len: bool,
|
|
||||||
):
|
):
|
||||||
torch.manual_seed(123)
|
torch.manual_seed(123)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
@ -88,21 +87,45 @@ def test_copy_kv_to_caches(
|
||||||
max_num_blocks_per_seq,
|
max_num_blocks_per_seq,
|
||||||
same_context_len,
|
same_context_len,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
|
n_tokens,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
# k_cache_torch = k_cache.clone().detach()
|
k_source = new_k.view(-1, new_k.size(-2), new_k.size(-1))
|
||||||
# copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding")
|
v_source = new_v.view(-1, new_v.size(-2), new_v.size(-1))
|
||||||
|
k_cache_copy = k_cache.detach().clone()
|
||||||
|
past_kv_seq_lengths = kv_seq_lengths - n_tokens
|
||||||
|
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_lengths // block_size]
|
||||||
|
offsets_in_block = past_kv_seq_lengths % block_size
|
||||||
|
|
||||||
|
# Copy k (or v) to k (or v) cache
|
||||||
|
copy_k_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens)
|
||||||
|
# Reshape target k from k cache to compare if matching with original tensor
|
||||||
|
# Mainly to handle cases of n_tokens > 1
|
||||||
|
k_target = []
|
||||||
|
for i in range(bsz):
|
||||||
|
block_table = block_tables[i]
|
||||||
|
curr_kv_len = past_kv_seq_lengths[i].item()
|
||||||
|
offset = offsets_in_block[i].item()
|
||||||
|
tokens_left = n_tokens
|
||||||
|
while tokens_left > 0:
|
||||||
|
tokens_to_fill = min(block_size - offset, tokens_left)
|
||||||
|
curr_block_id = block_table[curr_kv_len // block_size]
|
||||||
|
k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :])
|
||||||
|
curr_kv_len += tokens_to_fill
|
||||||
|
tokens_left -= tokens_to_fill
|
||||||
|
offset = 0
|
||||||
|
k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous() # [bsz * n, num_kv_heads, head_dim]
|
||||||
|
|
||||||
|
assert k_target.shape == k_source.shape
|
||||||
|
assert torch.equal(k_target, k_source)
|
||||||
|
|
||||||
|
if n_tokens == 1:
|
||||||
|
# Copy k and v to k/v caches
|
||||||
|
k_cache = k_cache_copy
|
||||||
copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
|
copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
|
||||||
|
k_target = k_cache_copy[target_block_ids, :, offsets_in_block, :]
|
||||||
past_kv_seq_len = kv_seq_lengths - 1
|
|
||||||
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
|
|
||||||
offsets_in_block = past_kv_seq_len % block_size
|
|
||||||
k_target = k_cache[target_block_ids, :, offsets_in_block, :]
|
|
||||||
k_source = new_k.squeeze()
|
|
||||||
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
|
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
|
||||||
v_source = new_v.squeeze()
|
|
||||||
|
|
||||||
assert k_target.shape == k_source.shape
|
assert k_target.shape == k_source.shape
|
||||||
assert torch.equal(k_target, k_source)
|
assert torch.equal(k_target, k_source)
|
||||||
assert v_target.shape == v_source.shape
|
assert v_target.shape == v_source.shape
|
||||||
|
|
Loading…
Reference in New Issue