import torch
import triton
import triton.language as tl


# Triton 2.1.0
# supports two types of cache layouts
# 1. [num_blocks, num_kv_heads, block_size, head_dim]
# 2. [num_blocks, num_kv_heads, head_dim // x, block_size, x]
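# Layout 1 is handled by launching the kernels with KCACHE_X = head_dim and
# stride_kcsplit_x = 0 (see the wrapper functions below), which collapses the
# extra split dimension of layout 2 into a single chunk.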
@triton.jit
def _copy_to_kcache_seqlen_n_kernel(
    K,  # K or V
    KCache,  # [num_blocks, num_kv_heads, head_dim // x, block_size, x]
    BLOCK_TABLES,
    seq_lengths,
    stride_kt,
    stride_kh,
    stride_kd,
    stride_kcb,
    stride_kch,
    stride_kcsplit_x,
    stride_kcs,
    stride_kcx,
    stride_bts,
    stride_btb,
    block_size,
    n_tokens,
    HEAD_DIM: tl.constexpr,
    KCACHE_X: tl.constexpr,
):
    # `n_tokens` is used to specify the number of tokens to copy for each sequence
    # When n_tokens > 1, tokens from different sequences are packed into the first dimension of the grid,
    # and `seq_lengths` must be the lengths of sequences counting the number of tokens to copy
    # E.g. if n_tokens = 5, seq_lengths = [12, 15], then the already-copied position ids are [0-6, 0-9]
    # for the two sequences, respectively, and the position ids to be copied are [7-11, 10-14].
    # When n_tokens = 1, consider the token idx as the sequence idx, since this path is only used during the regular decoding stage
    cur_token_idx = tl.program_id(0)
    cur_seq_idx = cur_token_idx // n_tokens
    # `cur_token_shift` is only valid and functional when `n_tokens` > 1
    cur_token_shift = cur_token_idx - (n_tokens * (cur_seq_idx + 1))
    cur_kv_head_idx = tl.program_id(1)
    split_x_idx = tl.program_id(2)

    past_kv_seq_len = tl.load(seq_lengths + cur_seq_idx) + cur_token_shift
    last_bt_block_idx = past_kv_seq_len // block_size
    block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
    block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
    offset_last_block = past_kv_seq_len % block_size
    offsets_dmodel = split_x_idx * KCACHE_X + tl.arange(0, KCACHE_X)
    offsets_k = cur_token_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
    k = tl.load(K + offsets_k)
    offsets_kcache = (
        block_id * stride_kcb
        + cur_kv_head_idx * stride_kch
        + split_x_idx * stride_kcsplit_x
        + offset_last_block * stride_kcs
        + tl.arange(0, KCACHE_X)
    )
    tl.store(KCache + offsets_kcache, k)
    return


# Triton 2.1.0
@triton.jit
def _copy_to_kvcache_seqlen1_kernel(
    K,
    V,
    KCache,
    VCache,
    BLOCK_TABLES,
    context_lengths,
    stride_kt,
    stride_kh,
    stride_kd,
    stride_vt,
    stride_vh,
    stride_vd,
    stride_kcb,
    stride_kch,
    stride_kcsplit_x,
    stride_kcs,
    stride_kcd,
    stride_vcb,
    stride_vch,
    stride_vcs,
    stride_vcd,
    stride_bts,
    stride_btb,
    block_size,
    HEAD_DIM: tl.constexpr,
    KCACHE_X: tl.constexpr,
):
    cur_seq_idx = tl.program_id(0)
    cur_kv_head_idx = tl.program_id(1)

    past_kv_seq_len = tl.load(context_lengths + cur_seq_idx) - 1
    last_bt_block_idx = past_kv_seq_len // block_size
    block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
    block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
    offsets_in_last_block = past_kv_seq_len % block_size

    range_x = tl.arange(0, KCACHE_X)
    offsets_dmodel_x_partition = tl.arange(0, KCACHE_X)

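    # Copy the head dim in chunks of KCACHE_X. Note (derived from the wrappers
    # below): with the default 4-D cache layout, KCACHE_X == HEAD_DIM and
    # stride_kcsplit_x == 0, so this loop runs exactly once; with the new 5-D
    # layout it iterates over the head_dim // x splits.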
    for split_x in tl.static_range(HEAD_DIM // KCACHE_X):
        offsets_dmodel_x_partition = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X)
        offsets_k = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel_x_partition * stride_kd
        k = tl.load(K + offsets_k)
        offsets_v = cur_seq_idx * stride_vt + cur_kv_head_idx * stride_vh + offsets_dmodel_x_partition * stride_vd
        v = tl.load(V + offsets_v)

        offsets_kcache = (
            block_id * stride_kcb
            + cur_kv_head_idx * stride_kch
            + split_x * stride_kcsplit_x
            + offsets_in_last_block * stride_kcs
            + range_x
        )
        tl.store(KCache + offsets_kcache, k)
        offsets_vcache = (
            block_id * stride_vcb
            + cur_kv_head_idx * stride_vch
            + offsets_in_last_block * stride_vcs
            + offsets_dmodel_x_partition * stride_vcd
        )
        tl.store(VCache + offsets_vcache, v)
    return


def copy_k_to_blocked_cache(
    k: torch.Tensor,
    k_cache: torch.Tensor,
    kv_lengths: torch.Tensor,
    block_tables: torch.Tensor,
    n: int = 1,
    use_new_kcache_layout: bool = False,
):
    """
    Copy keys or values to the blocked key/value cache during the decoding stage.

    Args:
        k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim] / [bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1,
            or [bsz * n, num_kv_heads, head_dim] - Keys or values with seq len n.
        k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache,
            or [num_blocks, num_kv_heads, head_dim // x, block_size, x] when using the new KCache layout.
        kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
        block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
        n (int): Number of tokens to copy for each sequence. Defaults to 1.
        use_new_kcache_layout (bool): Whether to use the new layout for the key cache. Defaults to False.
    """
    assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
    if k.dim() == 4:
        k = k.reshape(-1, k.size(-2), k.size(-1))
    k_shape = k.shape
    bsz, num_kv_heads, head_dim = k_shape
    # NOTE when n > 1, the shape of k is [bsz * n, num_kv_heads, head_dim]
    if n > 1:
        assert bsz % n == 0, "Each sequence should have the same number of tokens to be copied"
        bsz = bsz // n

    assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
        f"Got incompatible batch size (number of seqs):\n"
        f"  Past kv sequence lengths bsz {kv_lengths.shape[0]}; "
        f"  block tables bsz {block_tables.shape[0]}, input k batch size {bsz}"
    )

    k_cache_shape = k_cache.shape
    # Modify this if the shape of the kv cache is changed.
    block_size = k_cache_shape[-2]

    x = head_dim
    stride_kcsplit_x, stride_kcs, stride_kcd = 0, k_cache.stride(2), k_cache.stride(3)
    if use_new_kcache_layout:
        # when using the kcache layout [num_blocks, num_kv_heads, head_dim // x, block_size, x]
        assert (
            len(k_cache_shape) == 5
            and k_cache_shape[1] == k_shape[1]
            and k_cache_shape[2] * k_cache_shape[4] == k_shape[2]
        ), f"Incompatible k_cache shape {k_cache_shape} with k shape {k_shape}"
        x = k_cache.size(-1)
        stride_kcsplit_x, stride_kcs, stride_kcd = k_cache.stride()[2:]

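    # Grid: one program per copied token (bsz * n), per kv head, per head_dim
    # split; with the default layout x == head_dim, so the last grid dim is 1.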
    num_warps = 8 if head_dim > 128 else 4
    grid = (bsz * n, num_kv_heads, head_dim // x)
    _copy_to_kcache_seqlen_n_kernel[grid](
        k,
        k_cache,
        block_tables,
        kv_lengths,
        k.stride(0),
        k.stride(1),
        k.stride(2),
        k_cache.stride(0),
        k_cache.stride(1),
        stride_kcsplit_x,
        stride_kcs,
        stride_kcd,
        block_tables.stride(0),
        block_tables.stride(1),
        block_size,
        n_tokens=n,
        HEAD_DIM=head_dim,
        KCACHE_X=x,
        num_warps=num_warps,
    )
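

# A minimal usage sketch for `copy_k_to_blocked_cache` (illustrative only: the
# sizes below are hypothetical and assume the default 4-D cache layout on a CUDA device):
#
#   bsz, num_kv_heads, head_dim, block_size, num_blocks = 2, 8, 64, 16, 8
#   k = torch.randn(bsz, 1, num_kv_heads, head_dim, device="cuda")
#   k_cache = torch.zeros(num_blocks, num_kv_heads, block_size, head_dim, device="cuda")
#   kv_lengths = torch.tensor([3, 7], dtype=torch.int32, device="cuda")  # past + current lengths
#   block_tables = torch.arange(bsz * 4, dtype=torch.int32, device="cuda").view(bsz, 4)
#   copy_k_to_blocked_cache(k, k_cache, kv_lengths, block_tables)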


def copy_kv_to_blocked_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    kv_lengths: torch.Tensor,
    block_tables: torch.Tensor,
    use_new_kcache_layout: bool = False,
):
    """
    Copy keys and values to the blocked key/value cache during the decoding stage.

    Args:
        k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim] / [bsz, num_kv_heads, head_dim] - Keys during decoding with seq len 1.
        v (torch.Tensor): [bsz, 1, num_kv_heads, head_dim] / [bsz, num_kv_heads, head_dim] - Values during decoding with seq len 1.
        k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key cache,
            or [num_blocks, num_kv_heads, head_dim // x, block_size, x] when using the new KCache layout.
        v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked value cache.
        kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
        block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
        use_new_kcache_layout (bool): Whether to use the new layout for the key cache. Defaults to False.
    """
    k_cache_shape = k_cache.shape
    v_cache_shape = v_cache.shape

    if use_new_kcache_layout:
        assert (
            len(k_cache_shape) == 5
            and k_cache_shape[1] == v_cache_shape[1]
            and k_cache_shape[2] * k_cache_shape[4] == v_cache_shape[3]
        ), f"Invalid KCache shape {k_cache_shape} and VCache shape {v_cache_shape}"
    else:
        assert k.size(-1) == k_cache_shape[-1], "Incompatible head dim"
        assert (
            k_cache_shape == v_cache_shape
        ), f"Incompatible KCache shape {k_cache_shape} and VCache shape {v_cache_shape}"
    assert v.size(-1) == v_cache_shape[-1], "Incompatible head dim"

    k = k.squeeze(1) if k.dim() == 4 else k
    assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
    v = v.squeeze(1) if v.dim() == 4 else v
    assert v.dim() == 3, f"Incompatible v dim {v.dim()}"

    bsz, num_kv_heads, head_dim = k.shape
    assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
        f"Got incompatible batch size (number of seqs):\n"
        f"  Past kv sequence lengths bsz {kv_lengths.shape[0]}; "
        f"  block tables bsz {block_tables.shape[0]}, input k batch size {bsz}"
    )

    # Modify this if the shape of the kv cache is changed.
    block_size = k_cache.size(-2)

    x = head_dim
    stride_kcsplit_x, stride_kcs, stride_kcd = 0, k_cache.stride(2), k_cache.stride(3)
    if use_new_kcache_layout:
        x = k_cache.size(-1)
        stride_kcsplit_x, stride_kcs, stride_kcd = k_cache.stride()[2:]

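    # Grid: one program per (sequence, kv head); each program copies the full
    # head dim for both K and V.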
    num_warps = 8 if head_dim > 128 else 4
    grid = (bsz, num_kv_heads)
    _copy_to_kvcache_seqlen1_kernel[grid](
        k,
        v,
        k_cache,
        v_cache,
        block_tables,
        kv_lengths,
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        k_cache.stride(0),
        k_cache.stride(1),
        stride_kcsplit_x,
        stride_kcs,
        stride_kcd,
        v_cache.stride(0),
        v_cache.stride(1),
        v_cache.stride(2),
        v_cache.stride(3),
        block_tables.stride(0),
        block_tables.stride(1),
        block_size,
        HEAD_DIM=head_dim,
        KCACHE_X=x,
        num_warps=num_warps,
    )
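

# A minimal usage sketch for `copy_kv_to_blocked_cache` (illustrative only: the
# sizes below are hypothetical and assume the default 4-D cache layouts on a CUDA device):
#
#   bsz, num_kv_heads, head_dim, block_size, num_blocks = 2, 8, 64, 16, 8
#   k = torch.randn(bsz, 1, num_kv_heads, head_dim, device="cuda")
#   v = torch.randn(bsz, 1, num_kv_heads, head_dim, device="cuda")
#   k_cache = torch.zeros(num_blocks, num_kv_heads, block_size, head_dim, device="cuda")
#   v_cache = torch.zeros(num_blocks, num_kv_heads, block_size, head_dim, device="cuda")
#   kv_lengths = torch.tensor([3, 7], dtype=torch.int32, device="cuda")  # past + current lengths
#   block_tables = torch.arange(bsz * 4, dtype=torch.int32, device="cuda").view(bsz, 4)
#   copy_kv_to_blocked_cache(k, v, k_cache, v_cache, kv_lengths, block_tables)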