mirror of https://github.com/hpcaitech/ColossalAI
[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy oper… (#5663)
* refactor kvcache manager and rotary_embedding and kvcache_memcpy operator * refactor decode_kv_cache_memcpy * enable alibi in pagedattentionpull/5680/head
parent
5f00002e43
commit
5cd75ce4c7
|
@ -90,9 +90,18 @@ class KVCacheManager:
|
|||
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
|
||||
|
||||
# Physical cache allocation
|
||||
alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
|
||||
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
|
||||
self._kv_caches = self._init_device_caches(alloc_shape)
|
||||
if config.use_cuda_kernel:
|
||||
x = 16 // torch.tensor([], dtype=config.dtype).element_size()
|
||||
kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x)
|
||||
valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
|
||||
self.logger.info(
|
||||
f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks."
|
||||
)
|
||||
self._kv_caches = self._init_device_caches(kalloc_shape, valloc_shape)
|
||||
else:
|
||||
alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
|
||||
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
|
||||
self._kv_caches = self._init_device_caches(alloc_shape, alloc_shape)
|
||||
self.total_physical_cache_size_in_bytes = (
|
||||
self.elem_size_in_bytes
|
||||
* self.num_layers
|
||||
|
@ -479,7 +488,9 @@ class KVCacheManager:
|
|||
blocks.append(cache_block)
|
||||
return blocks
|
||||
|
||||
def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _init_device_caches(
|
||||
self, kalloc_shape: Tuple[int, ...], valloc_shape: Tuple[int, ...]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Initialize the physical cache on the device.
|
||||
|
||||
For each layer of the model, we allocate two tensors for key and value respectively,
|
||||
|
@ -488,6 +499,6 @@ class KVCacheManager:
|
|||
k_cache: List[torch.Tensor] = []
|
||||
v_cache: List[torch.Tensor] = []
|
||||
for _ in range(self.num_layers):
|
||||
k_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
|
||||
v_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
|
||||
k_cache.append(torch.zeros(kalloc_shape, dtype=self.dtype, device=self.device))
|
||||
v_cache.append(torch.zeros(valloc_shape, dtype=self.dtype, device=self.device))
|
||||
return k_cache, v_cache
|
||||
|
|
|
@ -310,6 +310,7 @@ class NopadBaichuanAttention(ParallelModule):
|
|||
alibi_slopes=self.alibi_slopes,
|
||||
max_seq_len=kv_seq_len,
|
||||
sm_scale=sm_scale,
|
||||
use_new_kcache_layout=use_cuda_kernel,
|
||||
)
|
||||
else:
|
||||
q_len = tokens_to_verify + 1 if is_verifier else 1
|
||||
|
@ -332,6 +333,21 @@ class NopadBaichuanAttention(ParallelModule):
|
|||
inference_ops.decode_kv_cache_memcpy(
|
||||
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
|
||||
)
|
||||
inference_ops.flash_decoding_attention(
|
||||
output_tensor,
|
||||
query_states,
|
||||
k_cache,
|
||||
v_cache,
|
||||
sequence_lengths,
|
||||
block_tables,
|
||||
block_size,
|
||||
kv_seq_len,
|
||||
fd_inter_tensor.mid_output,
|
||||
fd_inter_tensor.mid_output_lse,
|
||||
self.alibi_slopes,
|
||||
sm_scale,
|
||||
)
|
||||
attn_output = output_tensor
|
||||
else:
|
||||
if not is_verifier and not self.use_alibi_attn:
|
||||
decoding_fused_rotary_embedding(
|
||||
|
@ -355,21 +371,21 @@ class NopadBaichuanAttention(ParallelModule):
|
|||
value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
|
||||
)
|
||||
|
||||
attn_output = flash_decoding_attention(
|
||||
q=query_states,
|
||||
k_cache=k_cache,
|
||||
v_cache=v_cache,
|
||||
kv_seq_len=sequence_lengths,
|
||||
block_tables=block_tables,
|
||||
block_size=block_size,
|
||||
max_seq_len_in_batch=kv_seq_len,
|
||||
output=output_tensor,
|
||||
mid_output=fd_inter_tensor.mid_output,
|
||||
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sm_scale=sm_scale,
|
||||
q_len=q_len,
|
||||
)
|
||||
attn_output = flash_decoding_attention(
|
||||
q=query_states,
|
||||
k_cache=k_cache,
|
||||
v_cache=v_cache,
|
||||
kv_seq_len=sequence_lengths,
|
||||
block_tables=block_tables,
|
||||
block_size=block_size,
|
||||
max_seq_len_in_batch=kv_seq_len,
|
||||
output=output_tensor,
|
||||
mid_output=fd_inter_tensor.mid_output,
|
||||
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sm_scale=sm_scale,
|
||||
q_len=q_len,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(-1, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
|
|
@ -98,15 +98,8 @@ def llama_model_forward(
|
|||
"""
|
||||
block_tables = inputmetadata.block_tables
|
||||
sequence_lengths = inputmetadata.sequence_lengths
|
||||
batch_size = inputmetadata.batch_size
|
||||
kv_seq_len = inputmetadata.kv_seq_len
|
||||
|
||||
# NOTE: After testing, the performance of this configuration is relatively good. With updates
|
||||
# and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
|
||||
# selection should be conducted.
|
||||
if batch_size >= 32 and kv_seq_len > 512:
|
||||
use_cuda_kernel = False
|
||||
|
||||
# NOTE (yuanheng-zhao): fow now, only triton kernels support verification process
|
||||
# during speculative-decoding (`q_len > 1`)
|
||||
# We will expicitly disable `use_cuda_kernel` here when speculative-decoding is enabled
|
||||
|
@ -575,6 +568,7 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
|
|||
output=output_tensor,
|
||||
max_seq_len=kv_seq_len,
|
||||
sm_scale=sm_scale,
|
||||
use_new_kcache_layout=use_cuda_kernel,
|
||||
)
|
||||
else:
|
||||
q_len = tokens_to_verify + 1 if is_verifier else 1
|
||||
|
@ -592,20 +586,21 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
|
|||
block_tables,
|
||||
high_precision,
|
||||
)
|
||||
# inference_ops.flash_decoding_attention(
|
||||
# output_tensor,
|
||||
# query_states,
|
||||
# k_cache,
|
||||
# v_cache,
|
||||
# sequence_lengths,
|
||||
# block_tables,
|
||||
# block_size,
|
||||
# kv_seq_len,
|
||||
# fd_inter_tensor.mid_output,
|
||||
# fd_inter_tensor.mid_output_lse,
|
||||
# sm_scale,
|
||||
# )
|
||||
# attn_output = output_tensor
|
||||
inference_ops.flash_decoding_attention(
|
||||
output_tensor,
|
||||
query_states,
|
||||
k_cache,
|
||||
v_cache,
|
||||
sequence_lengths,
|
||||
block_tables,
|
||||
block_size,
|
||||
kv_seq_len,
|
||||
fd_inter_tensor.mid_output,
|
||||
fd_inter_tensor.mid_output_lse,
|
||||
None,
|
||||
sm_scale,
|
||||
)
|
||||
attn_output = output_tensor
|
||||
else:
|
||||
if is_verifier:
|
||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||
|
@ -627,21 +622,21 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
|
|||
block_tables,
|
||||
sequence_lengths,
|
||||
)
|
||||
attn_output = flash_decoding_attention(
|
||||
q=query_states,
|
||||
k_cache=k_cache,
|
||||
v_cache=v_cache,
|
||||
kv_seq_len=sequence_lengths,
|
||||
block_tables=block_tables,
|
||||
block_size=block_size,
|
||||
max_seq_len_in_batch=kv_seq_len,
|
||||
output=output_tensor,
|
||||
mid_output=fd_inter_tensor.mid_output,
|
||||
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
||||
sm_scale=sm_scale,
|
||||
kv_group_num=self.num_key_value_groups,
|
||||
q_len=q_len,
|
||||
)
|
||||
attn_output = flash_decoding_attention(
|
||||
q=query_states,
|
||||
k_cache=k_cache,
|
||||
v_cache=v_cache,
|
||||
kv_seq_len=sequence_lengths,
|
||||
block_tables=block_tables,
|
||||
block_size=block_size,
|
||||
max_seq_len_in_batch=kv_seq_len,
|
||||
output=output_tensor,
|
||||
mid_output=fd_inter_tensor.mid_output,
|
||||
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
||||
sm_scale=sm_scale,
|
||||
kv_group_num=self.num_key_value_groups,
|
||||
q_len=q_len,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(-1, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
|
|
@ -20,7 +20,7 @@ inference_ops = InferenceOpsLoader().load()
|
|||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["MAX_NUM_BLOCKS_PER_SEQ"],
|
||||
x_vals=[2**i for i in range(3, 8)],
|
||||
x_vals=[2**i for i in range(2, 8)],
|
||||
line_arg="provider",
|
||||
line_vals=[
|
||||
"vllm_paged_decoding_attention",
|
||||
|
@ -113,6 +113,8 @@ def benchmark_flash_decoding_attention(
|
|||
kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
|
||||
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
|
||||
sm_scale = 1.0 / (HEAD_SIZE**0.5)
|
||||
alibi_slopes = None
|
||||
kv_scale = 1.0
|
||||
|
||||
mid_output = torch.empty(
|
||||
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
|
||||
|
@ -136,6 +138,7 @@ def benchmark_flash_decoding_attention(
|
|||
max_seq_len_across_batch,
|
||||
alibi_slopes,
|
||||
"auto",
|
||||
kv_scale,
|
||||
)
|
||||
elif provider == "triton_flash_decoding_attention":
|
||||
fn = lambda: flash_decoding_attention(
|
||||
|
@ -164,6 +167,7 @@ def benchmark_flash_decoding_attention(
|
|||
max_seq_len_across_batch,
|
||||
mid_output,
|
||||
mid_output_lse,
|
||||
alibi_slopes,
|
||||
sm_scale,
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -2,7 +2,11 @@ import torch
|
|||
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import (
|
||||
mock_alloc_block_table_and_kvcache_v2,
|
||||
mock_alloc_block_table_and_kvcache_v3,
|
||||
mock_alloc_single_token,
|
||||
)
|
||||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
|
@ -68,11 +72,17 @@ def benchmark_rotary_emb(
|
|||
cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
|
||||
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
||||
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
new_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
|
||||
new_k_cache = torch.zeros(size=new_cache_shape, dtype=dtype, device="cuda")
|
||||
|
||||
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
|
||||
block_tables = mock_alloc_block_table_and_kvcache_v2(
|
||||
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
|
||||
)
|
||||
_ = mock_alloc_block_table_and_kvcache_v3(
|
||||
k, v, new_k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
|
||||
)
|
||||
new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
|
||||
new_q = torch.randn_like(new_k)
|
||||
new_v = torch.randn_like(new_k)
|
||||
|
@ -94,12 +104,12 @@ def benchmark_rotary_emb(
|
|||
)
|
||||
elif provider == "no_fused_cuda_rotary_emb_func":
|
||||
fn = lambda: [
|
||||
inference_ops.rotary_embedding(new_q, new_k, cos, sin),
|
||||
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables),
|
||||
inference_ops.rotary_embedding(new_q, new_k, cos, sin, True),
|
||||
inference_ops.decode_kv_cache_memcpy(new_k, new_v, new_k_cache, v_cache, kv_seq_lengths, block_tables),
|
||||
]
|
||||
elif provider == "fused_cuda_rotary_emb_func":
|
||||
fn = lambda: inference_ops.rotary_embedding_and_cache_copy(
|
||||
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
|
||||
new_q, new_k, new_v, cos, sin, new_k_cache, v_cache, kv_seq_lengths, block_tables, True
|
||||
)
|
||||
else:
|
||||
raise ValueError("Undefined provider")
|
||||
|
|
|
@ -4,6 +4,7 @@ from colossalai.inference.modeling.layers.attention import copy_to_cache
|
|||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.kernel.triton import copy_kv_to_blocked_cache
|
||||
from colossalai.utils import get_current_device
|
||||
from tests.test_infer.test_ops.cuda.test_kv_cache_memcpy import prepare_data as prepare_data_new_kcache_layout
|
||||
from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data
|
||||
|
||||
try:
|
||||
|
@ -68,6 +69,9 @@ def benchmark_kvcache_copy(
|
|||
elif provider == "triton_copy_func":
|
||||
fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
|
||||
elif provider == "cuda_copy_func":
|
||||
_, _, k_cache, _, _, _, _, _, _ = prepare_data_new_kcache_layout(
|
||||
bsz, num_kv_heads, block_size, max_seq_len // block_size, context_lengths - 1, device, dtype
|
||||
)
|
||||
new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k
|
||||
new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v
|
||||
fn = lambda: inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
|
||||
|
|
|
@ -24,14 +24,15 @@ __global__ void context_kv_cache_memcpy_kernel(
|
|||
const int batch_size,
|
||||
const int block_table_stride,
|
||||
const int64_t key_stride,
|
||||
const int64_t value_stride
|
||||
const int64_t value_stride,
|
||||
const int x
|
||||
)
|
||||
{
|
||||
const int seq_token_id = blockIdx.x;
|
||||
const int seq_id = blockIdx.y;
|
||||
const int block_id = block_tables[seq_id * block_table_stride + seq_token_id / block_size];
|
||||
|
||||
if ( block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) {
|
||||
if (block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) {
|
||||
return ;
|
||||
}
|
||||
|
||||
|
@ -40,23 +41,33 @@ __global__ void context_kv_cache_memcpy_kernel(
|
|||
const int total_token_id = cu_seqlens[seq_id] + seq_token_id;
|
||||
int head_id;
|
||||
int head_offset;
|
||||
int x_id;
|
||||
int x_offset;
|
||||
int64_t key_src_id;
|
||||
int64_t value_src_id;
|
||||
int64_t target_id;
|
||||
int64_t target_key_id;
|
||||
int64_t target_value_id;
|
||||
|
||||
int i = threadIdx.x * VecSize;
|
||||
|
||||
for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
|
||||
head_id = i / head_dim;
|
||||
head_offset = i % head_dim;
|
||||
x_id = head_offset / x;
|
||||
x_offset = head_offset % x;
|
||||
key_src_id = total_token_id * key_stride + i;
|
||||
value_src_id = total_token_id * value_stride + i;
|
||||
target_id = block_id * hidden_size * block_size
|
||||
target_key_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ x_id * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
target_value_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_id);
|
||||
copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_id);
|
||||
copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_key_id);
|
||||
copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_value_id);
|
||||
}
|
||||
|
||||
// tail process
|
||||
|
@ -64,14 +75,21 @@ __global__ void context_kv_cache_memcpy_kernel(
|
|||
for (; i < hidden_size; ++i ) {
|
||||
head_id = i / head_dim;
|
||||
head_offset = i % head_dim;
|
||||
x_id = head_offset / x;
|
||||
x_offset = head_offset % x;
|
||||
key_src_id = total_token_id * key_stride + i;
|
||||
value_src_id = total_token_id * value_stride + i;
|
||||
target_id = block_id * hidden_size * block_size
|
||||
target_key_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ x_id * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
target_value_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
key_cache[target_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
|
||||
value_cache[target_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
|
||||
key_cache[target_key_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
|
||||
value_cache[target_value_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -81,7 +99,7 @@ template<typename T, typename CacheT>
|
|||
void apply_context_kv_cache_memcpy(
|
||||
torch::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
torch::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
torch::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
torch::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
torch::Tensor& sequence_lengths, // [batch_size]
|
||||
torch::Tensor& cu_seqlens, // [batch_size + 1]
|
||||
|
@ -91,7 +109,8 @@ void apply_context_kv_cache_memcpy(
|
|||
int num_tokens = key.size(0);
|
||||
int head_num = key.size(1);
|
||||
int head_dim = key.size(2);
|
||||
int block_size = key_cache.size(2);
|
||||
int block_size = key_cache.size(3);
|
||||
int x = key_cache.size(4);
|
||||
int batch_size = block_tables.size(0);
|
||||
|
||||
int64_t key_stride = key.stride(0);
|
||||
|
@ -127,7 +146,8 @@ void apply_context_kv_cache_memcpy(
|
|||
batch_size, \
|
||||
block_table_stride, \
|
||||
key_stride, \
|
||||
value_stride \
|
||||
value_stride, \
|
||||
x \
|
||||
); \
|
||||
} while(0)
|
||||
|
||||
|
@ -164,7 +184,7 @@ void apply_context_kv_cache_memcpy(
|
|||
void context_kv_cache_memcpy(
|
||||
torch::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
torch::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
torch::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
torch::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
torch::Tensor& sequence_lengths, // [batch_size]
|
||||
torch::Tensor& cu_seqlens, // [batch_size + 1]
|
||||
|
|
|
@ -20,7 +20,8 @@ __global__ void decode_kv_cache_memcpy_kernel(
|
|||
const int block_size,
|
||||
const int64_t key_stride,
|
||||
const int64_t value_stride,
|
||||
const int block_table_stride
|
||||
const int block_table_stride,
|
||||
const int x
|
||||
)
|
||||
{
|
||||
const int seq_id = blockIdx.x;
|
||||
|
@ -38,28 +39,42 @@ __global__ void decode_kv_cache_memcpy_kernel(
|
|||
for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
|
||||
const int head_id = i / head_dim;
|
||||
const int head_offset = i % head_dim;
|
||||
const int x_id = head_offset / x;
|
||||
const int x_offset = head_offset % x;
|
||||
const int64_t key_src_id = seq_id * key_stride + i;
|
||||
const int64_t value_src_id = seq_id * value_stride + i;
|
||||
const int64_t target_id = block_id * hidden_size * block_size
|
||||
const int64_t target_key_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ x_id * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
const int64_t target_value_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
|
||||
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_key_id, key + key_src_id);
|
||||
copy_vector<scalar_t, VecSize>(value_cache + target_value_id, value + value_src_id);
|
||||
}
|
||||
|
||||
if (!Aligned) {
|
||||
for (; i < hidden_size; ++i ) {
|
||||
const int head_id = i / head_dim;
|
||||
const int head_offset = i % head_dim;
|
||||
const int x_id = head_offset / x;
|
||||
const int x_offset = head_offset % x;
|
||||
const int64_t key_src_id = seq_id * key_stride + i;
|
||||
const int64_t value_src_id = seq_id * value_stride + i;
|
||||
const int64_t target_id = block_id * hidden_size * block_size
|
||||
const int64_t target_key_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ x_id * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
const int64_t target_value_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
key_cache[target_id] = key[key_src_id];
|
||||
value_cache[target_id] = value[value_src_id];
|
||||
key_cache[target_key_id] = key[key_src_id];
|
||||
value_cache[target_value_id] = value[value_src_id];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -69,7 +84,7 @@ template<typename scalar_t>
|
|||
void apply_decode_kv_cache_memcpy(
|
||||
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||
|
@ -77,7 +92,8 @@ void apply_decode_kv_cache_memcpy(
|
|||
int num_tokens = key.size(0);
|
||||
int head_num = key.size(1);
|
||||
int head_dim = key.size(2);
|
||||
int block_size = key_cache.size(2);
|
||||
int block_size = key_cache.size(3);
|
||||
int x = key_cache.size(4);
|
||||
|
||||
int64_t key_stride = key.stride(0);
|
||||
int64_t value_stride = value.stride(0);
|
||||
|
@ -110,7 +126,8 @@ void apply_decode_kv_cache_memcpy(
|
|||
block_size, \
|
||||
key_stride, \
|
||||
value_stride, \
|
||||
block_table_stride \
|
||||
block_table_stride, \
|
||||
x \
|
||||
); \
|
||||
} while(0)
|
||||
|
||||
|
@ -146,7 +163,7 @@ void apply_decode_kv_cache_memcpy(
|
|||
void decode_kv_cache_memcpy(
|
||||
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||
|
|
|
@ -67,6 +67,7 @@ __global__ void flash_decoding_attention_kernel(
|
|||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, block_size, head_size]
|
||||
const int* __restrict__ context_lens, // [num_tokens]
|
||||
const int* __restrict__ block_tables, // [num_tokens, max_num_blocks_per_seq]
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int max_seq_len,
|
||||
const int num_kv_heads,
|
||||
const float scale,
|
||||
|
@ -105,6 +106,7 @@ __global__ void flash_decoding_attention_kernel(
|
|||
using FloatVecT = typename FloatVecTypeTrait<LVecT>::Type;
|
||||
|
||||
const int context_len = context_lens[seq_idx];
|
||||
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
||||
const int thread_group_offset = lane % NUM_THREADS_PER_X;
|
||||
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
|
||||
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||
|
@ -164,6 +166,7 @@ __global__ void flash_decoding_attention_kernel(
|
|||
|
||||
if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) {
|
||||
const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X;
|
||||
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
|
||||
const bool mask = token_idx >= context_len;
|
||||
logits[token_idx] = mask ? 0.f : qk;
|
||||
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
|
||||
|
@ -261,6 +264,7 @@ __global__ void flash_decoding_attention_kernel(
|
|||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
context_lens.data_ptr<int>(), \
|
||||
block_tables.data_ptr<int>(), \
|
||||
alibi_slopes_ptr, \
|
||||
max_context_len, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
|
@ -282,7 +286,8 @@ void flash_decoding_attention_v1_launcher(
|
|||
torch::Tensor& context_lens, // [num_tokens]
|
||||
torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq]
|
||||
int max_context_len,
|
||||
float scale) {
|
||||
float scale,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
int num_tokens = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
|
@ -304,6 +309,10 @@ void flash_decoding_attention_v1_launcher(
|
|||
// Keep that in sync with the logic here!
|
||||
int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4);
|
||||
|
||||
const float* alibi_slopes_ptr = alibi_slopes ?
|
||||
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||
: nullptr;
|
||||
|
||||
dim3 grid(num_heads, num_tokens, 1);
|
||||
dim3 block(NUM_THREADS);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
|
@ -336,7 +345,8 @@ void flash_decoding_attention_v1_launcher(
|
|||
context_lens, \
|
||||
block_tables, \
|
||||
max_context_len, \
|
||||
scale);
|
||||
scale, \
|
||||
alibi_slopes);
|
||||
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||
// 1, 2, 4, 64, 128, 256.
|
||||
|
@ -367,6 +377,7 @@ void flash_decoding_attention(
|
|||
int max_context_len,
|
||||
torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size]
|
||||
torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions]
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
float scale) {
|
||||
|
||||
|
||||
|
|
|
@ -91,7 +91,7 @@ __device__ void apply_k_rotary_emb_compute(
|
|||
const int* __restrict__ block_tables, const int64_t key_stride,
|
||||
const int64_t value_stride, const int token_id,
|
||||
const int block_table_stride, const int head_num, const int head_dim,
|
||||
const int kv_head_num, const int block_size, const int half_head_dim,
|
||||
const int kv_head_num, const int block_size, const int x, const int half_head_dim,
|
||||
const int shard_block_size) {
|
||||
const int seq_len = sequence_lengths[token_id] - 1;
|
||||
const int block_offset = seq_len % block_size;
|
||||
|
@ -102,36 +102,40 @@ __device__ void apply_k_rotary_emb_compute(
|
|||
return;
|
||||
}
|
||||
|
||||
scalar_t x[VecSize];
|
||||
scalar_t y[VecSize];
|
||||
scalar_t x0[VecSize];
|
||||
scalar_t x1[VecSize];
|
||||
scalar_t out_x[VecSize];
|
||||
scalar_t out_y[VecSize];
|
||||
|
||||
for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim;
|
||||
i += blockDim.x * VecSize) {
|
||||
const int head_offset = i % half_head_dim;
|
||||
const int half_head_offset = i % half_head_dim;
|
||||
const int x_id = half_head_offset / x;
|
||||
const int x_offset = half_head_offset % x;
|
||||
const int shard_offset =
|
||||
(head_offset / shard_block_size) * shard_block_size +
|
||||
(head_offset % shard_block_size) / VecSize;
|
||||
(half_head_offset / shard_block_size) * shard_block_size +
|
||||
(half_head_offset % shard_block_size) / VecSize;
|
||||
const int64_t addr_offset =
|
||||
token_id * key_stride + (i / half_head_dim) * head_dim + head_offset;
|
||||
const int64_t target_id = block_id * kv_head_num * head_dim * block_size +
|
||||
(i / half_head_dim) * block_size * head_dim +
|
||||
block_offset * head_dim + head_offset;
|
||||
token_id * key_stride + (i / half_head_dim) * head_dim + half_head_offset;
|
||||
const int64_t target_id = block_id * kv_head_num * head_dim * block_size
|
||||
+ (i / half_head_dim) * block_size * head_dim
|
||||
+ x_id * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
|
||||
copy_vector<scalar_t, VecSize>(x, key + addr_offset);
|
||||
copy_vector<scalar_t, VecSize>(y, key + addr_offset + half_head_dim);
|
||||
copy_vector<scalar_t, VecSize>(x0, key + addr_offset);
|
||||
copy_vector<scalar_t, VecSize>(x1, key + addr_offset + half_head_dim);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VecSize; j++) {
|
||||
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
|
||||
static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
|
||||
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
|
||||
static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
|
||||
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x0[j]) * cos_ptr[j * 32 + shard_offset] -
|
||||
static_cast<m_scalar_t>(x1[j]) * sin_ptr[j * 32 + shard_offset]);
|
||||
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x1[j]) * cos_ptr[j * 32 + shard_offset] +
|
||||
static_cast<m_scalar_t>(x0[j]) * sin_ptr[j * 32 + shard_offset]);
|
||||
}
|
||||
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x);
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_id + half_head_dim,
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_id + half_head_dim * block_size,
|
||||
out_y);
|
||||
}
|
||||
|
||||
|
@ -162,7 +166,8 @@ __global__ void rotary_embedding_and_cache_copy_kernel(
|
|||
const int head_num,
|
||||
const int head_dim,
|
||||
const int kv_head_num,
|
||||
const int block_size
|
||||
const int block_size,
|
||||
const int x
|
||||
) {
|
||||
|
||||
const int token_id = blockIdx.x;
|
||||
|
@ -182,7 +187,7 @@ __global__ void rotary_embedding_and_cache_copy_kernel(
|
|||
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
|
||||
|
||||
//compute key and copy kv
|
||||
apply_k_rotary_emb_compute<scalar_t, m_scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size);
|
||||
apply_k_rotary_emb_compute<scalar_t, m_scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, x, half_head_dim, shard_block_size);
|
||||
}
|
||||
|
||||
template<typename scalar_t, typename m_scalar_t, int VecSize>
|
||||
|
@ -220,6 +225,31 @@ __global__ void rotary_embedding_kernel(
|
|||
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
|
||||
}
|
||||
|
||||
#define ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(VEC_SIZE) \
|
||||
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, VEC_SIZE><<<grid, block, shared_memory_size, stream>>>( \
|
||||
query.data_ptr<scalar_t>(), \
|
||||
key.data_ptr<scalar_t>(), \
|
||||
value.data_ptr<scalar_t>(), \
|
||||
cos.data_ptr<scalar_t>(), \
|
||||
sin.data_ptr<scalar_t>(), \
|
||||
key_cache.data_ptr<scalar_t>(), \
|
||||
value_cache.data_ptr<scalar_t>(), \
|
||||
sequence_lengths.data_ptr<int>(), \
|
||||
block_tables.data_ptr<int>(), \
|
||||
query_stride, \
|
||||
key_stride, \
|
||||
value_stride, \
|
||||
shard_element_num / 2, \
|
||||
cos_stride, \
|
||||
sin_stride, \
|
||||
block_table_stride, \
|
||||
head_num, \
|
||||
head_dim, \
|
||||
kv_head_num, \
|
||||
block_size, \
|
||||
x); \
|
||||
|
||||
|
||||
template<typename scalar_t, bool high_precision>
|
||||
void apply_rotary_embedding_and_cache_copy(
|
||||
at::Tensor& query, // [num_tokens, head_num, head_dim]
|
||||
|
@ -227,7 +257,7 @@ void apply_rotary_embedding_and_cache_copy(
|
|||
at::Tensor& value, // [num_tokens, kv_head_num, head_dim]
|
||||
at::Tensor& cos, // [num_tokens, head_dim]
|
||||
at::Tensor& sin, // [num_tokens, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||
|
@ -236,7 +266,8 @@ void apply_rotary_embedding_and_cache_copy(
|
|||
int head_num = query.size(1);
|
||||
int head_dim = query.size(2);
|
||||
int kv_head_num = key.size(1);
|
||||
int block_size = key_cache.size(2);
|
||||
int block_size = key_cache.size(3);
|
||||
int x = key_cache.size(4);
|
||||
|
||||
int64_t query_stride = query.stride(0);
|
||||
int64_t key_stride = key.stride(0);
|
||||
|
@ -261,80 +292,18 @@ void apply_rotary_embedding_and_cache_copy(
|
|||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(thread_nums, 512));
|
||||
int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ;
|
||||
int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size;
|
||||
const int shared_memory_size = shard_element_num * sizeof(m_scalar_t);
|
||||
|
||||
switch (vec_size) {
|
||||
case 1:
|
||||
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
cos.data_ptr<scalar_t>(),
|
||||
sin.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
sequence_lengths.data_ptr<int>(),
|
||||
block_tables.data_ptr<int>(),
|
||||
query_stride,
|
||||
key_stride,
|
||||
value_stride,
|
||||
shard_element_num / 2,
|
||||
cos_stride,
|
||||
sin_stride,
|
||||
block_table_stride,
|
||||
head_num,
|
||||
head_dim,
|
||||
kv_head_num,
|
||||
block_size
|
||||
);
|
||||
ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(1);
|
||||
break;
|
||||
case 2:
|
||||
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
cos.data_ptr<scalar_t>(),
|
||||
sin.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
sequence_lengths.data_ptr<int>(),
|
||||
block_tables.data_ptr<int>(),
|
||||
query_stride,
|
||||
key_stride,
|
||||
value_stride,
|
||||
shard_element_num / 2,
|
||||
cos_stride,
|
||||
sin_stride,
|
||||
block_table_stride,
|
||||
head_num,
|
||||
head_dim,
|
||||
kv_head_num,
|
||||
block_size
|
||||
);
|
||||
ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(2);
|
||||
break;
|
||||
case 4:
|
||||
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
cos.data_ptr<scalar_t>(),
|
||||
sin.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
sequence_lengths.data_ptr<int>(),
|
||||
block_tables.data_ptr<int>(),
|
||||
query_stride,
|
||||
key_stride,
|
||||
value_stride,
|
||||
shard_element_num / 2,
|
||||
cos_stride,
|
||||
sin_stride,
|
||||
block_table_stride,
|
||||
head_num,
|
||||
head_dim,
|
||||
kv_head_num,
|
||||
block_size
|
||||
);
|
||||
ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(4);
|
||||
break;
|
||||
default:
|
||||
AT_ERROR("Unsupported vectorized size ", vec_size);
|
||||
|
@ -441,7 +410,7 @@ void rotary_embedding_and_cache_copy(
|
|||
at::Tensor& value, // [num_tokens, kv_head_num, head_dim]
|
||||
at::Tensor& cos, // [num_tokens, head_dim]
|
||||
at::Tensor& sin, // [num_tokens, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& block_tables, // [batch_size, max_seq_len]
|
||||
|
|
|
@ -1,18 +1,19 @@
|
|||
#include <torch/extension.h>
|
||||
|
||||
void decode_kv_cache_memcpy(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor&
|
||||
key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, num_heads, block_size, head_size]
|
||||
torch::Tensor& sequence_lengths, // [batch_size]
|
||||
torch::Tensor& block_tables); // [batch_size, max_seq_len]
|
||||
|
||||
void context_kv_cache_memcpy(
|
||||
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& cu_seqlens, // [batch_size + 1]
|
||||
|
@ -27,12 +28,13 @@ void rotary_embedding(
|
|||
bool high_precision);
|
||||
|
||||
void rotary_embedding_and_cache_copy(
|
||||
torch::Tensor& query, // [num_tokens, head_num, head_dim]
|
||||
torch::Tensor& key, // [num_tokens, kv_head_num, head_dim]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_dim]
|
||||
torch::Tensor& cos, // [num_tokens, head_dim]
|
||||
torch::Tensor& sin, // [num_tokens, head_dim]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_dim]
|
||||
torch::Tensor& query, // [num_tokens, head_num, head_dim]
|
||||
torch::Tensor& key, // [num_tokens, kv_head_num, head_dim]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_dim]
|
||||
torch::Tensor& cos, // [num_tokens, head_dim]
|
||||
torch::Tensor& sin, // [num_tokens, head_dim]
|
||||
torch::Tensor&
|
||||
key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, num_heads, block_size, head_dim]
|
||||
torch::Tensor& sequence_lengths, // [batch_size]
|
||||
|
@ -71,7 +73,7 @@ void flash_decoding_attention(
|
|||
torch::Tensor&
|
||||
tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size]
|
||||
torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions]
|
||||
float scale);
|
||||
const c10::optional<torch::Tensor>& alibi_slopes, float scale);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
|
||||
|
|
|
@ -4,8 +4,10 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.utils import get_current_device
|
||||
from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
|
||||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
|
@ -60,8 +62,9 @@ def numpy_allclose(x, y, rtol, atol):
|
|||
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
|
||||
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
|
||||
def test_flash_decoding_attention(
|
||||
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype
|
||||
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
|
||||
):
|
||||
torch.manual_seed(123)
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -73,6 +76,11 @@ def test_flash_decoding_attention(
|
|||
MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
|
||||
device = get_current_device()
|
||||
|
||||
if use_alibi_slopes:
|
||||
alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
|
||||
else:
|
||||
alibi_slopes = None
|
||||
|
||||
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
|
||||
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
|
||||
)
|
||||
|
@ -91,6 +99,15 @@ def test_flash_decoding_attention(
|
|||
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
|
||||
torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
|
||||
|
||||
if use_alibi_slopes:
|
||||
alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
|
||||
torch_padding_mask = torch_padding_mask + alibi_mask
|
||||
|
||||
if len(torch_padding_mask.size()) == 4:
|
||||
torch_padding_mask = torch_padding_mask[:, :, -1:, :]
|
||||
else:
|
||||
torch_padding_mask = torch_padding_mask[:, -1:, :]
|
||||
|
||||
mid_output = torch.empty(
|
||||
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
|
||||
)
|
||||
|
@ -146,8 +163,14 @@ def test_flash_decoding_attention(
|
|||
max_seq_len_across_batch,
|
||||
mid_output,
|
||||
mid_output_lse,
|
||||
alibi_slopes,
|
||||
sm_scale,
|
||||
)
|
||||
|
||||
# The alibi may introduce relatively large errors
|
||||
if use_alibi_slopes:
|
||||
rtol = 1e0
|
||||
|
||||
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
|
@ -168,8 +191,9 @@ except ImportError:
|
|||
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
|
||||
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
|
||||
def test_vllm_flash_decoding_attention(
|
||||
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype
|
||||
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
|
||||
):
|
||||
torch.manual_seed(123)
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -199,6 +223,18 @@ def test_vllm_flash_decoding_attention(
|
|||
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
|
||||
torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
|
||||
|
||||
if use_alibi_slopes:
|
||||
alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
|
||||
alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
|
||||
torch_padding_mask = torch_padding_mask + alibi_mask
|
||||
|
||||
if len(torch_padding_mask.size()) == 4:
|
||||
torch_padding_mask = torch_padding_mask[:, :, -1:, :]
|
||||
else:
|
||||
torch_padding_mask = torch_padding_mask[:, -1:, :]
|
||||
else:
|
||||
alibi_slopes = None
|
||||
|
||||
if dtype == torch.float16:
|
||||
rtol = 1e-3
|
||||
atol = 1e-3
|
||||
|
@ -236,8 +272,6 @@ def test_vllm_flash_decoding_attention(
|
|||
HEAD_SIZE,
|
||||
)
|
||||
|
||||
alibi_slopes = None
|
||||
|
||||
vllm_ops.paged_attention_v1(
|
||||
output,
|
||||
q.squeeze(2),
|
||||
|
@ -253,6 +287,11 @@ def test_vllm_flash_decoding_attention(
|
|||
"auto",
|
||||
kv_scale,
|
||||
)
|
||||
|
||||
# The alibi may introduce relatively large errors
|
||||
if use_alibi_slopes:
|
||||
rtol = 1e0
|
||||
|
||||
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
|
@ -277,5 +316,5 @@ if __name__ == "__main__":
|
|||
dtype,
|
||||
) in test_combinations:
|
||||
test_flash_decoding_attention(
|
||||
batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype
|
||||
batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype, True
|
||||
)
|
||||
|
|
|
@ -4,12 +4,40 @@ import torch.nn.functional as F
|
|||
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.utils import get_current_device
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2
|
||||
from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v3, mock_alloc_single_token
|
||||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
HEAD_DIM = 4
|
||||
HEAD_DIM = 72
|
||||
|
||||
|
||||
def prepare_data(
|
||||
bsz,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
max_num_blocks_per_seq,
|
||||
context_lengths,
|
||||
device="cuda",
|
||||
dtype=torch.float16,
|
||||
):
|
||||
num_tokens = torch.sum(context_lengths).item()
|
||||
|
||||
max_seq_len_in_batch = context_lengths.max()
|
||||
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
|
||||
|
||||
kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
|
||||
key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
|
||||
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v3(
|
||||
key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||
)
|
||||
|
||||
block_tables = block_tables.to(device=device)
|
||||
k_cache = torch.zeros_like(k_cache_ref)
|
||||
v_cache = torch.zeros_like(v_cache_ref)
|
||||
|
||||
return key, value, k_cache, v_cache, cu_seqlens, block_tables, max_seq_len_in_batch, k_cache_ref, v_cache_ref
|
||||
|
||||
|
||||
def run_decode_copy_kv_to_caches(
|
||||
|
@ -24,32 +52,41 @@ def run_decode_copy_kv_to_caches(
|
|||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
n = 1
|
||||
|
||||
max_seq_len = block_size * max_num_blocks_per_seq
|
||||
dtype = torch.float32
|
||||
device = get_current_device()
|
||||
|
||||
new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
|
||||
bsz,
|
||||
num_kv_heads,
|
||||
HEAD_DIM,
|
||||
block_size,
|
||||
max_num_blocks_per_seq,
|
||||
same_context_len,
|
||||
max_seq_len,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
assert max_seq_len > n, "max_seq_len must be greater than n"
|
||||
|
||||
past_kv_seq_lengths = (
|
||||
torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device)
|
||||
if same_context_len
|
||||
else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device)
|
||||
)
|
||||
|
||||
new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k
|
||||
new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v
|
||||
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
|
||||
key, value, k_cache, v_cache, _, block_tables, _, _, _ = prepare_data(
|
||||
bsz, num_kv_heads, block_size, max_num_blocks_per_seq, past_kv_seq_lengths, device, dtype
|
||||
)
|
||||
|
||||
past_kv_seq_len = kv_seq_lengths - 1
|
||||
new_k = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device)
|
||||
new_v = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device)
|
||||
|
||||
# mock allocating blocks for the new k/v and update block tables
|
||||
for _ in range(n):
|
||||
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
||||
past_kv_seq_lengths += 1
|
||||
|
||||
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables)
|
||||
|
||||
past_kv_seq_len = past_kv_seq_lengths - 1
|
||||
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
|
||||
offsets_in_block = past_kv_seq_len % block_size
|
||||
k_target = k_cache[target_block_ids, :, offsets_in_block, :]
|
||||
k_target = k_cache[target_block_ids, :, :, offsets_in_block, :]
|
||||
k_source = new_k.squeeze()
|
||||
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
|
||||
k_target = k_target.reshape(v_target.shape)
|
||||
v_source = new_v.squeeze()
|
||||
|
||||
assert k_target.shape == k_source.shape
|
||||
|
@ -77,22 +114,17 @@ def run_context_copy_kv_to_cache(
|
|||
else:
|
||||
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
|
||||
|
||||
num_tokens = torch.sum(context_lengths).item()
|
||||
|
||||
max_seq_len_in_batch = context_lengths.max()
|
||||
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
|
||||
|
||||
kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
|
||||
key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
|
||||
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
|
||||
key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||
)
|
||||
|
||||
block_tables = block_tables.to(device=device)
|
||||
k_cache = torch.zeros_like(k_cache_ref)
|
||||
v_cache = torch.zeros_like(v_cache_ref)
|
||||
(
|
||||
key,
|
||||
value,
|
||||
k_cache,
|
||||
v_cache,
|
||||
cu_seqlens,
|
||||
block_tables,
|
||||
max_seq_len_in_batch,
|
||||
k_cache_ref,
|
||||
v_cache_ref,
|
||||
) = prepare_data(bsz, num_kv_heads, block_size, max_num_blocks_per_seq, context_lengths, device, dtype)
|
||||
|
||||
inference_ops.context_kv_cache_memcpy(
|
||||
key, value, k_cache, v_cache, context_lengths, cu_seqlens, block_tables, max_seq_len_in_batch
|
||||
|
|
|
@ -7,7 +7,7 @@ from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
|||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3
|
||||
from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb
|
||||
|
||||
|
||||
|
@ -49,12 +49,14 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
|
|||
cos_shape = (TOTAL_TOKENS, D // 2)
|
||||
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D)
|
||||
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
k_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, D // x, block_size, x)
|
||||
v_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D)
|
||||
k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device="cuda")
|
||||
v = torch.randn_like(k)
|
||||
v_cache = torch.zeros_like(k_cache)
|
||||
v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda")
|
||||
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
|
||||
block_tables = mock_alloc_block_table_and_kvcache_v2(
|
||||
block_tables = mock_alloc_block_table_and_kvcache_v3(
|
||||
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size
|
||||
)
|
||||
new_k = torch.randn((BATCH_SIZE, K_H, D), dtype=dtype, device="cuda")
|
||||
|
@ -97,9 +99,10 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
|
|||
past_kv_seq_len = kv_seq_lengths - 1
|
||||
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
|
||||
offsets_in_block = past_kv_seq_len % block_size
|
||||
k_target = k_cache[target_block_ids, :, offsets_in_block, :].squeeze()
|
||||
k_target = k_cache[target_block_ids, :, :, offsets_in_block, :].squeeze()
|
||||
k_source = new_k_copy.squeeze()
|
||||
v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze()
|
||||
k_target = k_target.reshape(v_target.shape)
|
||||
v_source = new_v.squeeze()
|
||||
|
||||
numpy_allclose(new_q, q_ref, rtol=rtol, atol=atol)
|
||||
|
|
Loading…
Reference in New Issue