[Inference] Support FP16/BF16 Flash Attention 2 And Add high_precision Flag To Rotary Embedding (#5461)

* Support FP16/BF16 Flash Attention 2

* fix bugs in test_kv_cache_memcpy.py

* add context_kv_cache_memcpy_kernel.cu

* rm typename MT

* add tail process

* add high_precision

* add high_precision to config.py

* rm unused code

* change the comment for the high_precision parameter

* update test_rotary_embdding_unpad.py

* fix vector_copy_utils.h

* add comment for self.high_precision when using float32
yuehuayingxueluo 2024-03-25 13:40:34 +08:00 committed by GitHub
parent 7ff42cc06d
commit 87079cffe8
15 changed files with 550 additions and 138 deletions


@ -55,7 +55,7 @@ class InferenceConfig:
pp_size (int): Pipeline parallel size, defaults to 1.
micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
# NOTE: arrange configs according to their importance and frequency of usage
@ -89,6 +89,7 @@ class InferenceConfig:
pp_size: int = 1
micro_batch_size: int = 1
micro_batch_buffer_size: int = None
high_precision: Optional[bool] = False
def __post_init__(self):
self._verify_config()
@ -108,6 +109,10 @@ class InferenceConfig:
self.dtype in _ALLOWED_DTYPES
), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"
# skip using casting when the data type is float32
if self.dtype == torch.float32:
self.high_precision = False
# check distributed
assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
self.tp_size * self.pp_size == dist.get_world_size()
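A minimal usage sketch of the new flag, assuming `InferenceConfig` is importable from `colossalai.inference.config` and accepts a torch dtype here; all other fields keep their defaults. The float32 case shows the `__post_init__` guard above in action.

import torch
from colossalai.inference.config import InferenceConfig  # assumed import path

# fp16 weights, but rotary-embedding math carried out in fp32
config = InferenceConfig(dtype=torch.float16, high_precision=True)

# with fp32 weights the flag is redundant, so __post_init__ resets it
fp32_config = InferenceConfig(dtype=torch.float32, high_precision=True)
assert fp32_config.high_precision is False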


@ -56,6 +56,7 @@ class InferenceEngine:
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.generation_config = inference_config.to_generation_config(self.model_config)
self.high_precision = inference_config.high_precision
model = model.eval()
model = model.cuda()
model.to(self.dtype)
@ -297,6 +298,7 @@ class InferenceEngine:
batch,
self.k_cahce,
self.v_cache,
self.high_precision,
)
if self.inference_config.pad_input:
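The flag travels from the config to the model in the two hops shown above: cached on the engine at construction time, then passed positionally after the KV caches on every step. A compressed sketch of that call pattern (the real engine also handles padding, sampling, and cache allocation):

def run_step(model, batch, k_caches, v_caches, inference_config):
    # mirrors the call added above: batch, then the caches, then the flag
    high_precision = inference_config.high_precision
    return model(batch, k_caches, v_caches, high_precision)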


@ -2,6 +2,7 @@
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
@ -30,24 +31,28 @@ inference_ops = InferenceOpsLoader().load()
logger = get_dist_logger(__name__)
try:
HAS_TRITON = True
from flash_attn import flash_attn_varlen_func
use_flash_attn2 = True
except ImportError:
HAS_TRITON = False
logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.")
use_flash_attn2 = False
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
def llama_causal_lm_forward(
self: LlamaForCausalLM,
batch: BatchBucket = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
batch: BatchBucket,
k_caches: List[torch.Tensor],
v_caches: List[torch.Tensor],
high_precision: bool = False,
):
"""This function will replace the forward function of LlamaForCausalLM.
Args:
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
batch (BatchBucket): It stores the necessary input information for this inference.
k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@ -56,6 +61,7 @@ def llama_causal_lm_forward(
batch=batch,
k_caches=k_caches,
v_caches=v_caches,
high_precision=high_precision,
)
logits = torch.mm(hidden_states, self.lm_head.weight)
return logits
@ -63,16 +69,18 @@ def llama_causal_lm_forward(
def llama_model_forward(
self: LlamaModel,
batch: BatchBucket = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
batch: BatchBucket,
k_caches: List[torch.Tensor],
v_caches: List[torch.Tensor],
high_precision: bool = False,
):
"""This function will replace the forward function of LlamaModel.
Args:
batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None.
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
batch (BatchBucket): It stores the necessary input information for this inference.
k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
input_ids = batch.get_1D_inputs()
block_tables = batch.get_block_table_tensor()
@ -86,6 +94,11 @@ def llama_model_forward(
if batch_size >= 32 and kv_seq_len > 512:
use_cuda_kernel = False
if use_cuda_kernel and batch.dtype != torch.float32 and use_flash_attn2:
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
else:
cu_seqlens = None
hidden_states = self.embed_tokens(input_ids)
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)
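cu_seqlens is only built when the CUDA-kernel path and flash-attn 2 are both usable, because the varlen kernel indexes the packed q/k/v through cumulative offsets. The construction can be reproduced in isolation:

import torch
import torch.nn.functional as F

# three sequences of 3, 5 and 2 tokens packed back to back
sequence_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

# prepend a zero so cu_seqlens[i]:cu_seqlens[i + 1] selects the tokens of sequence i
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)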
@ -110,15 +123,17 @@ def llama_model_forward(
block_tables=block_tables,
k_cache=k_caches[layer_id],
v_cache=v_caches[layer_id],
is_prompts=batch.is_prompts,
sequence_lengths=sequence_lengths,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
fd_inter_tensor=batch.fd_inter_tensor,
is_prompts=batch.is_prompts,
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
norm_output=norm_output,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
cu_seqlens=cu_seqlens,
high_precision=high_precision,
)
if batch.is_prompts:
@ -135,38 +150,42 @@ def llama_decoder_layer_forward(
self: LlamaDecoderLayer,
hidden_states: torch.Tensor,
residual: torch.Tensor,
block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
block_tables: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
sequence_lengths: torch.Tensor,
cos_sin: Tuple[torch.Tensor],
fd_inter_tensor: FDIntermTensors,
is_prompts: bool = True,
sequence_lengths: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
norm_output: torch.Tensor = None,
sm_scale: int = None,
use_cuda_kernel: bool = True,
cu_seqlens: torch.Tensor = None,
high_precision: bool = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""This function will replace the forward function of LlamaDecoderLayer.
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id.
k_cache (torch.Tensor): It holds the GPU memory for the key cache.
v_cache (torch.Tensor): It holds the GPU memory for the value cache.
sequence_lengths (torch.Tensor): Holding the sequence length of each sequence.
cos_sin (Tuple[torch.Tensor]): Holding cos and sin.
fd_inter_tensor (FDIntermTensors): Holding tensors used for
storing intermediate values in flash-decoding.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
storing intermediate values in flash-decoding. Defaults to None.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
use_cuda_kernel (bool, optional): Whether to use cuda kernel. Defaults to True.
cu_seqlens (torch.Tensor, optional): Holding the cumulative sum of sequence lengths. Defaults to None.
high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
@ -176,14 +195,16 @@ def llama_decoder_layer_forward(
block_tables=block_tables,
k_cache=k_cache,
v_cache=v_cache,
is_prompts=is_prompts,
sequence_lengths=sequence_lengths,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
is_prompts=is_prompts,
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
cu_seqlens=cu_seqlens,
high_precision=high_precision,
)
# Fully Connected
@ -277,43 +298,48 @@ class NopadLlamaAttention(LlamaAttention):
def forward(
self,
hidden_states: torch.Tensor,
block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
block_tables: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
sequence_lengths: torch.Tensor,
cos_sin: Tuple[torch.Tensor],
fd_inter_tensor: FDIntermTensors,
is_prompts: bool = True,
sequence_lengths: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
use_cuda_kernel: bool = True,
cu_seqlens: torch.Tensor = None,
high_precision: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id.
k_cache (torch.Tensor): It holds the GPU memory for the key cache.
v_cache (torch.Tensor): It holds the GPU memory for the value cache.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin.
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
storing intermediate values in flash-decoding. Defaults to None.
storing intermediate values in flash-decoding.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
use_cuda_kernel (bool, optional): Whether to use cuda kernel. Defaults to True.
cu_seqlens (torch.Tensor, optional): Holding the cumulative sum of sequence lengths. Defaults to None.
high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
token_nums = hidden_states.size(0)
if self.num_heads != self.num_key_value_heads:
query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim)
key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
else:
# fused qkv
token_nums = hidden_states.size(0)
hidden_states = hidden_states.expand(3, -1, -1)
query_states, key_states, value_states = (
torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
@ -322,8 +348,26 @@ class NopadLlamaAttention(LlamaAttention):
block_size = k_cache.size(-2)
if is_prompts:
if use_cuda_kernel:
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
if use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
# flash attn 2 currently only supports FP16/BF16.
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
inference_ops.context_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
)
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=kv_seq_len,
max_seqlen_k=kv_seq_len,
dropout_p=0.0,
softmax_scale=sm_scale,
causal=True,
)
attn_output = attn_output.view(token_nums, -1)
else:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = context_attention_unpadded(
@ -351,6 +395,7 @@ class NopadLlamaAttention(LlamaAttention):
v_cache,
sequence_lengths,
block_tables,
high_precision,
)
else:
decoding_fused_rotary_embedding(
@ -436,6 +481,5 @@ class NopadLlamaMLP(LlamaMLP):
"""
hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True)
tmp_out = act_out * gate_up_proj_out[1]
return torch.mm(tmp_out, self.down_proj_weight)
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
return torch.mm(act_out, self.down_proj_weight)
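For reference, a standalone sketch of the new prefill attention path added in this file: unpadded fp16 q/k/v plus the cumulative lengths go straight into flash-attn 2's varlen entry point. It assumes the flash_attn package and a CUDA device; sizes are illustrative only.

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func

num_heads, head_dim = 8, 64
sequence_lengths = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
total_tokens = int(sequence_lengths.sum())
max_seqlen = int(sequence_lengths.max())
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))

# unpadded layout: [total_tokens, num_heads, head_dim], fp16 as flash-attn 2 requires
q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

attn_output = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    dropout_p=0.0,
    softmax_scale=head_dim ** -0.5,
    causal=True,
)
attn_output = attn_output.view(total_tokens, -1)  # flatten heads before the output projection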


@ -136,6 +136,7 @@ def benchmark_inference(args):
data = data_gen(mbsz, args.seq_len)
if args.mode == "colossalai" or args.mode == "vllm":
data = data.tolist()
generation_config = GenerationConfig(


@ -56,6 +56,23 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
TYPE, NAME, ...) \
switch (HIGH_PRECISION) { \
case false: { \
const bool high_precision = false; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
break; \
} \
case true: { \
const bool high_precision = true; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
break; \
} \
default: \
AT_ERROR("HIGH_PRECISION must be bool, but get ", HIGH_PRECISION, "."); \
}
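The macro exists so that a runtime bool coming from Python can still select kernels whose math type was fixed at compile time: it branches once, binds high_precision as a constant, and reuses the existing dtype dispatch inside each branch. A loose Python analogue of the pattern (the variant functions are placeholders, not real kernels):

def rope_storage_precision(q, k, cos, sin):
    ...  # math carried out in the storage dtype (fp16/bf16)

def rope_fp32_precision(q, k, cos, sin):
    ...  # math carried out in fp32, results cast back to the storage dtype

def rotary_embedding(q, k, cos, sin, high_precision: bool):
    # branch on the flag exactly once; downstream code never inspects it again
    variant = rope_fp32_precision if high_precision else rope_storage_precision
    return variant(q, k, cos, sin)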
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \


@ -27,5 +27,18 @@ struct MPTypeTrait<at::BFloat16> {
using Type = float;
};
template <bool high_precision, typename scalar_t>
struct ScalarTypeTrait;
template <typename T>
struct ScalarTypeTrait<true, T> {
using Type = typename MPTypeTrait<T>::Type;
};
template <typename T>
struct ScalarTypeTrait<false, T> {
using Type = T;
};
} // namespace common
} // namespace colossalAI
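ScalarTypeTrait composes with the existing MPTypeTrait: the `<true, T>` specialization maps to the wider math type (float for half/bfloat16), the `<false, T>` specialization keeps the storage type. A torch-side restatement of the same table, for intuition only:

import torch

def compute_dtype(storage_dtype: torch.dtype, high_precision: bool) -> torch.dtype:
    # ScalarTypeTrait<true, T>  -> MPTypeTrait<T>::Type (float for half/bfloat16)
    # ScalarTypeTrait<false, T> -> T
    if high_precision and storage_dtype in (torch.float16, torch.bfloat16):
        return torch.float32
    return storage_dtype

assert compute_dtype(torch.float16, True) == torch.float32
assert compute_dtype(torch.bfloat16, False) == torch.bfloat16
assert compute_dtype(torch.float32, True) == torch.float32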


@ -0,0 +1,195 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "utils/vector_copy_utils.h"
#include "../common/micros.h"
template<typename scalar_t, int VecSize>
__global__ void context_kv_cache_memcpy_kernel(
const scalar_t* __restrict__ key,
const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache,
scalar_t* __restrict__ value_cache,
const int* __restrict__ sequence_lengths,
const int* __restrict__ cu_seqlens,
const int* __restrict__ block_tables,
const int head_num,
const int head_dim,
const int block_size,
const int batch_size,
const int block_table_stride,
const int64_t key_stride,
const int64_t value_stride
)
{
const int seq_token_id = blockIdx.x;
const int seq_id = blockIdx.y;
const int block_id = block_tables[seq_id * block_table_stride + seq_token_id / block_size];
if ( block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) {
return ;
}
const int block_offset = seq_token_id % block_size;
const int hidden_size = head_num * head_dim;
const int total_token_id = cu_seqlens[seq_id] + seq_token_id;
int head_id;
int head_offset;
int64_t key_src_id;
int64_t value_src_id;
int64_t target_id;
int i = threadIdx.x * VecSize;
for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
head_id = i / head_dim;
head_offset = i % head_dim;
key_src_id = total_token_id * key_stride + i;
value_src_id = total_token_id * value_stride + i;
target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
}
// tail process
for (; i < hidden_size; ++i ) {
head_id = i / head_dim;
head_offset = i % head_dim;
key_src_id = total_token_id * key_stride + i;
value_src_id = total_token_id * value_stride + i;
target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
key_cache[target_id] = key[key_src_id];
value_cache[target_id] = value[value_src_id];
}
}
template<typename scalar_t>
void apply_context_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& cu_seqlens, // [batch_size + 1]
at::Tensor& block_tables, // [batch_size, max_seq_len]
int max_seq_len_in_batch)
{
int num_tokens = key.size(0);
int head_num = key.size(1);
int head_dim = key.size(2);
int block_size = key_cache.size(2);
int batch_size = block_tables.size(0);
int64_t key_stride = key.stride(0);
int64_t value_stride = value.stride(0);
int block_table_stride = block_tables.stride(0);
int vec_size = get_vec_size<scalar_t>(key);
if (head_dim % vec_size != 0) {
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
vec_size = 1;
}
int thread_nums = head_num * head_dim / vec_size;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(max_seq_len_in_batch, batch_size);
dim3 block(std::min(thread_nums, 512));
switch (vec_size) {
case 1:
context_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
cu_seqlens.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
batch_size,
block_table_stride,
key_stride,
value_stride
);
break;
case 2:
context_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
cu_seqlens.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
batch_size,
block_table_stride,
key_stride,
value_stride
);
break;
case 4:
context_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
cu_seqlens.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
batch_size,
block_table_stride,
key_stride,
value_stride
);
break;
default:
AT_ERROR("Unsupported vectorized size ", vec_size);
break;
}
AT_CUDA_CHECK(cudaGetLastError());
}
void context_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& cu_seqlens, // [batch_size + 1]
at::Tensor& block_tables, // [batch_size, max_seq_len]
int max_seq_len_in_batch)
{
DISPATCH_FLOAT_HALF_AND_BFLOAT(
key.scalar_type(),
"context_kv_cache_memcpy",
apply_context_kv_cache_memcpy<scalar_t>(
key,
value,
key_cache,
value_cache,
sequence_lengths,
cu_seqlens,
block_tables,
max_seq_len_in_batch
);)
}
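A slow pure-PyTorch reference for what the kernel above computes can help when reasoning about the paged-cache layout (block_id from the block table, block_offset from the position inside the sequence). This follows the shape comments in the entry point and is only a sketch, not part of the kernel:

import torch

def context_kv_cache_memcpy_ref(key, value, k_cache, v_cache,
                                sequence_lengths, block_tables, block_size):
    # key/value: [num_tokens, head_num, head_dim], all sequences packed back to back
    # k_cache/v_cache: [num_blocks, head_num, block_size, head_dim]
    token_start = 0  # plays the role of cu_seqlens[seq_id]
    for seq_id, seq_len in enumerate(sequence_lengths.tolist()):
        for pos in range(seq_len):
            block_id = block_tables[seq_id, pos // block_size]
            block_offset = pos % block_size
            k_cache[block_id, :, block_offset, :] = key[token_start + pos]
            v_cache[block_id, :, block_offset, :] = value[token_start + pos]
        token_start += seq_len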


@ -30,7 +30,9 @@ __global__ void decode_kv_cache_memcpy_kernel(
return ;
}
for (int i = threadIdx.x * VecSize; i < hidden_size; i += blockDim.x * VecSize) {
int i = threadIdx.x * VecSize;
for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
const int head_id = i / head_dim;
const int head_offset = i % head_dim;
const int64_t key_src_id = seq_id * key_stride + i;
@ -43,6 +45,19 @@ __global__ void decode_kv_cache_memcpy_kernel(
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
}
for (; i < hidden_size; ++i ) {
const int head_id = i / head_dim;
const int head_offset = i % head_dim;
const int64_t key_src_id = seq_id * key_stride + i;
const int64_t value_src_id = seq_id * value_stride + i;
const int64_t target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
key_cache[target_id] = key[key_src_id];
value_cache[target_id] = value[value_src_id];
}
}
template<typename scalar_t>
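The "tail process" from the commit message is the second loop: the first loop copies full vectors while one still fits, then a scalar loop finishes whatever remains when hidden_size is not a multiple of VecSize. The same split, stripped of the CUDA thread striding, in a few lines of Python:

def copy_with_tail(dst, src, vec_size):
    i = 0
    while i <= len(src) - vec_size:          # vectorized main loop
        dst[i:i + vec_size] = src[i:i + vec_size]
        i += vec_size
    while i < len(src):                      # scalar tail for the leftover elements
        dst[i] = src[i]
        i += 1

buf = [0] * 10
copy_with_tail(buf, list(range(10)), vec_size=4)
assert buf == list(range(10))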


@ -1,14 +1,15 @@
// in transformers source code, huggingface uses fp16 to compute rope so we follow the same precision
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "utils/vector_copy_utils.h"
#include "../common/micros.h"
#include "../common/mp_type_traits.h"
template <typename scalar_t, int VecSize>
template <typename scalar_t, typename m_scalar_t, int VecSize>
__device__ void apply_emb_rotary_compute(
scalar_t* __restrict__ src, const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ sin_ptr, const int64_t stride,
scalar_t* __restrict__ src, const m_scalar_t* __restrict__ cos_ptr,
const m_scalar_t* __restrict__ sin_ptr, const int64_t stride,
const int token_id, const int shard_block_size, const int half_head_dim,
const int head_num, const int head_dim) {
scalar_t x[VecSize];
@ -30,10 +31,10 @@ __device__ void apply_emb_rotary_compute(
#pragma unroll
for (int j = 0; j < VecSize; j++) {
out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] -
y[j] * sin_ptr[j * 32 + shard_offset];
out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] +
x[j] * sin_ptr[j * 32 + shard_offset];
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
}
copy_vector<scalar_t, VecSize>(src + addr_offset, out_x);
@ -62,10 +63,10 @@ __device__ void apply_kv_memcopy(
}
}
template <typename scalar_t, int VecSize>
template <typename scalar_t, typename m_scalar_t, int VecSize>
__device__ void cos_sin_memory_access(
const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin,
scalar_t* cos_ptr, scalar_t* sin_ptr, const int token_id,
m_scalar_t* cos_ptr, m_scalar_t* sin_ptr, const int token_id,
const int shard_block_size, const int cos_stride, const int sin_stride,
const int half_head_dim) {
for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) {
@ -73,16 +74,16 @@ __device__ void cos_sin_memory_access(
const int shard_offset = (i % shard_block_size) / VecSize;
const int shard_head =
(i / shard_block_size) * shard_block_size + i % VecSize * 32;
cos_ptr[shard_head + shard_offset] = cos[token_id * cos_stride + i];
sin_ptr[shard_head + shard_offset] = sin[token_id * sin_stride + i];
cos_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(cos[token_id * cos_stride + i]);
sin_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(sin[token_id * sin_stride + i]);
}
}
template <typename scalar_t, int VecSize>
template <typename scalar_t, typename m_scalar_t, int VecSize>
__device__ void apply_k_rotary_emb_compute(
scalar_t* __restrict__ key, scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr,
const m_scalar_t* __restrict__ cos_ptr, const m_scalar_t* __restrict__ sin_ptr,
const int* __restrict__ sequence_lengths,
const int* __restrict__ block_tables, const int64_t key_stride,
const int64_t value_stride, const int token_id,
@ -120,10 +121,10 @@ __device__ void apply_k_rotary_emb_compute(
#pragma unroll
for (int j = 0; j < VecSize; j++) {
out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] -
y[j] * sin_ptr[j * 32 + shard_offset];
out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] +
x[j] * sin_ptr[j * 32 + shard_offset];
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
}
copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x);
@ -137,7 +138,7 @@ __device__ void apply_k_rotary_emb_compute(
block_size, block_offset, head_dim, half_head_dim);
}
template<typename scalar_t, int VecSize>
template<typename scalar_t, typename m_scalar_t, int VecSize>
__global__ void rotary_embedding_and_cache_copy_kernel(
scalar_t* __restrict__ query,
scalar_t* __restrict__ key,
@ -167,21 +168,21 @@ __global__ void rotary_embedding_and_cache_copy_kernel(
extern __shared__ char shard_ptr[];
scalar_t *cos_ptr = (scalar_t*)shard_ptr;
scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
// apply cos_sin memcopy
cos_sin_memory_access<scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
__syncthreads();
//compute query
apply_emb_rotary_compute<scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
//compute key and copy kv
apply_k_rotary_emb_compute<scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size);
apply_k_rotary_emb_compute<scalar_t, m_scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size);
}
template<typename scalar_t, int VecSize>
template<typename scalar_t, typename m_scalar_t, int VecSize>
__global__ void rotary_embedding_kernel(
scalar_t* __restrict__ query,
scalar_t* __restrict__ key,
@ -202,21 +203,21 @@ __global__ void rotary_embedding_kernel(
extern __shared__ char shard_ptr[];
scalar_t *cos_ptr = (scalar_t*)shard_ptr;
scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
// apply cos_sin memcopy
cos_sin_memory_access<scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
__syncthreads();
//compute query
apply_emb_rotary_compute<scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
//compute key
apply_emb_rotary_compute<scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
}
template<typename scalar_t>
template<typename scalar_t, bool high_precision>
void apply_rotary_embedding_and_cache_copy(
at::Tensor& query, // [num_tokens, head_num, head_dim]
at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
@ -241,6 +242,8 @@ void apply_rotary_embedding_and_cache_copy(
int sin_stride = sin.stride(0);
int block_table_stride = block_tables.stride(0);
using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
int vec_size = get_vec_size<scalar_t>(query);
if ((head_dim / 2) % vec_size != 0) {
@ -259,7 +262,7 @@ void apply_rotary_embedding_and_cache_copy(
switch (vec_size) {
case 1:
rotary_embedding_and_cache_copy_kernel<scalar_t, 1><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
@ -283,7 +286,7 @@ void apply_rotary_embedding_and_cache_copy(
);
break;
case 2:
rotary_embedding_and_cache_copy_kernel<scalar_t, 2><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
@ -307,7 +310,7 @@ void apply_rotary_embedding_and_cache_copy(
);
break;
case 4:
rotary_embedding_and_cache_copy_kernel<scalar_t, 4><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
@ -338,7 +341,7 @@ void apply_rotary_embedding_and_cache_copy(
AT_CUDA_CHECK(cudaGetLastError());
}
template<typename scalar_t>
template<typename scalar_t, bool high_precision>
void apply_rotary_embedding(
at::Tensor& query, // [total_tokens, head_num, head_dim]
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
@ -355,6 +358,8 @@ void apply_rotary_embedding(
int cos_stride = cos.stride(0);
int sin_stride = sin.stride(0);
using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
int vec_size = get_vec_size<scalar_t>(query);
if ((head_dim / 2) % vec_size != 0) {
@ -373,7 +378,7 @@ void apply_rotary_embedding(
switch (vec_size) {
case 1:
rotary_embedding_kernel<scalar_t, 1><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
rotary_embedding_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
@ -389,7 +394,7 @@ void apply_rotary_embedding(
);
break;
case 2:
rotary_embedding_kernel<scalar_t, 2><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
rotary_embedding_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
@ -405,7 +410,7 @@ void apply_rotary_embedding(
);
break;
case 4:
rotary_embedding_kernel<scalar_t, 4><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
rotary_embedding_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
@ -436,12 +441,14 @@ void rotary_embedding_and_cache_copy(
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables) // [batch_size, max_seq_len]
at::Tensor& block_tables, // [batch_size, max_seq_len]
bool high_precision)
{
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(
high_precision,
query.scalar_type(),
"rotary_embedding_and_cache_copy",
apply_rotary_embedding_and_cache_copy<scalar_t>(
apply_rotary_embedding_and_cache_copy<scalar_t, high_precision>(
query,
key,
value,
@ -458,12 +465,14 @@ void rotary_embedding(
at::Tensor& query, // [total_tokens, head_num, head_dim]
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
at::Tensor& cos, // [total_tokens, head_dim]
at::Tensor& sin // [total_tokens, head_dim]
at::Tensor& sin, // [total_tokens, head_dim]
bool high_precision
){
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(
high_precision,
query.scalar_type(),
"rotary_embedding",
apply_rotary_embedding<scalar_t>(
apply_rotary_embedding<scalar_t, high_precision>(
query,
key,
cos,
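Numerically, the m_scalar_t plumbing above amounts to: upcast cos/sin and both halves of every head to the math type, rotate, and cast the result back to the storage dtype. A torch reference under the assumption of the usual split-halves rotary layout (the kernel's shared-memory sharding is an implementation detail):

import torch

def rotary_ref(x, cos, sin, high_precision):
    # x: [tokens, heads, head_dim]; cos/sin: [tokens, head_dim // 2]
    mdtype = torch.float32 if (high_precision and x.dtype != torch.float32) else x.dtype
    xm, cosm, sinm = x.to(mdtype), cos.to(mdtype).unsqueeze(1), sin.to(mdtype).unsqueeze(1)
    x1, x2 = xm.chunk(2, dim=-1)
    out = torch.cat((x1 * cosm - x2 * sinm, x2 * cosm + x1 * sinm), dim=-1)
    return out.to(x.dtype)  # results are stored back in the original precision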


@ -9,11 +9,22 @@ void decode_kv_cache_memcpy(
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables); // [batch_size, max_seq_len]
void context_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& cu_seqlens, // [batch_size + 1]
at::Tensor& block_tables, // [batch_size, max_seq_len]
int max_seq_len_in_batch);
void rotary_embedding(
torch::Tensor& query, // [total_tokens, head_num, head_dim]
torch::Tensor& key, // [total_tokens, kv_head_num, head_dim]
torch::Tensor& cos, // [total_tokens, head_dim]
torch::Tensor& sin); // [total_tokens, head_dim]
torch::Tensor& sin, // [total_tokens, head_dim]
bool high_precision);
void rotary_embedding_and_cache_copy(
torch::Tensor& query, // [num_tokens, head_num, head_dim]
@ -25,7 +36,9 @@ void rotary_embedding_and_cache_copy(
torch::Tensor&
value_cache, // [num_blocks, num_heads, block_size, head_dim]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables); // [batch_size, max_seq_len]
torch::Tensor& block_tables, // [batch_size, max_seq_len]
bool high_precision);
torch::Tensor silu_and_mul(const torch::Tensor& ins);
void rms_layernorm(torch::Tensor& out, // [..., hidden_size]
@ -42,6 +55,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
"Copy the GPU memory of kvcache during the decode stage.");
m.def("context_kv_cache_memcpy", &context_kv_cache_memcpy,
"Copy the GPU memory of kvcache during the context stage.");
m.def(
"rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy,
"performing Rotary Embedding-related calculations and KVCache Memcopy.");


@ -11,6 +11,8 @@
#include <cfloat>
#include <limits>
#include "utils/vector_copy_utils.h"
namespace {
int log2_ceil(int value) {


@ -11,16 +11,16 @@ template <typename T, int VecSize>
__device__ __inline__ void copy_vector(T *dst, const T *src) {
using VT = typename colossalAI::cuda::utils::VecTypeTrait<T, VecSize>::Type;
// Note(LiuYang): Here static_cast can't be used to cast between two pointer types
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<VT *>(src));
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
}
template <>
__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
// Since the maximum memory alignment length is 128 bits, we choose float4
// here.
*(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<float4 *>(src));
*(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<const float4 *>(src));
*(reinterpret_cast<float4 *>(dst + 4)) =
*(reinterpret_cast<float4 *>(src + 4));
*(reinterpret_cast<const float4 *>(src + 4));
}
template <typename T, int VecSize>


@ -12,6 +12,7 @@ class InferenceOpsCudaExtension(_CudaExtension):
for fname in [
"cuda/pybind/inference.cpp",
"cuda/decode_kv_cache_memcpy_kernel.cu",
"cuda/context_kv_cache_memcpy_kernel.cu",
"cuda/fused_rotary_emb_and_cache_kernel.cu",
"cuda/activation_kernel.cu",
"cuda/rms_layernorm_kernel.cu",


@ -1,8 +1,10 @@
import pytest
import torch
import torch.nn.functional as F
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2
from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data
inference_ops = InferenceOpsLoader().load()
@ -10,12 +12,7 @@ inference_ops = InferenceOpsLoader().load()
HEAD_DIM = 4
@pytest.mark.parametrize("bsz", [4, 7, 32])
@pytest.mark.parametrize("block_size", [16, 32, 64])
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
@pytest.mark.parametrize("num_kv_heads", [16])
@pytest.mark.parametrize("same_context_len", [True, False])
def test_copy_kv_to_caches(
def run_decode_copy_kv_to_caches(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
@ -61,5 +58,65 @@ def test_copy_kv_to_caches(
assert torch.equal(v_target, v_source)
def run_context_copy_kv_to_cache(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
):
torch.manual_seed(123)
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
max_seq_len = max_num_blocks_per_seq * block_size
dtype = torch.float16
device = get_current_device()
if same_context_len:
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
else:
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
num_tokens = torch.sum(context_lengths).item()
max_seq_len_in_batch = context_lengths.max()
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.int32), (1, 0))
kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
block_tables = block_tables.to(device=device)
k_cache = torch.zeros_like(k_cache_ref)
v_cache = torch.zeros_like(v_cache_ref)
inference_ops.context_kv_cache_memcpy(
key, value, k_cache, v_cache, context_lengths, cu_seqlens, block_tables, max_seq_len_in_batch
)
assert torch.equal(k_cache, k_cache_ref)
assert torch.equal(v_cache, v_cache_ref)
@pytest.mark.parametrize("bsz", [4, 7, 32])
@pytest.mark.parametrize("block_size", [16, 32, 64])
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
@pytest.mark.parametrize("num_kv_heads", [16])
@pytest.mark.parametrize("same_context_len", [True, False])
def test_kv_cache_memcopy(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
):
run_context_copy_kv_to_cache(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len)
run_decode_copy_kv_to_caches(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len)
if __name__ == "__main__":
test_copy_kv_to_caches(4, 32, 8, 16, True)
test_kv_cache_memcopy(4, 32, 8, 16, True)


@ -1,3 +1,4 @@
import numpy as np
import pytest
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
@ -10,11 +11,18 @@ from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table
from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb
def numpy_allclose(x, y, rtol, atol):
x_numpy = x.detach().cpu().numpy()
y_numpy = y.detach().cpu().numpy()
np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("SEQ_LEN", [64])
@pytest.mark.parametrize("H", [32])
@pytest.mark.parametrize("D", [64])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
torch.manual_seed(10)
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
@ -54,17 +62,36 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
kv_seq_lengths = past_kv_seq_lengths + 1
block_tables = block_tables.to(device="cuda")
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
new_q_copy = new_q.clone()
new_k_copy = new_k.clone()
if dtype == torch.float16:
rtol = 1e-3
atol = 1e-3
new_q_fp16 = new_q.clone()
new_k_fp16 = new_k.clone()
high_precision_cos = cos[:BATCH_SIZE].to(torch.float32)
high_precision_sin = sin[:BATCH_SIZE].to(torch.float32)
high_precision_q = new_q.to(torch.float32)
high_precision_k = new_k.to(torch.float32)
q_ref = torch_rotary_emb(high_precision_q, high_precision_cos, high_precision_sin).to(torch.float16)
k_ref = torch_rotary_emb(high_precision_k, high_precision_cos, high_precision_sin).to(torch.float16)
else:
rtol = 1e-5
atol = 1e-7
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
inference_ops.rotary_embedding_and_cache_copy(
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables, True
)
inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin)
inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin, True)
past_kv_seq_len = kv_seq_lengths - 1
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
@ -74,18 +101,26 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze()
v_source = new_v.squeeze()
assert torch.allclose(new_q, q_ref, atol=1e-6, rtol=1e-6)
assert torch.allclose(k_target, k_ref, atol=1e-6, rtol=1e-6)
numpy_allclose(new_q, q_ref, rtol=rtol, atol=atol)
numpy_allclose(k_target, k_ref, rtol=rtol, atol=atol)
assert torch.allclose(new_q_copy, q_ref, atol=1e-6, rtol=1e-6)
assert torch.allclose(new_k_copy, k_ref, atol=1e-6, rtol=1e-6)
numpy_allclose(new_q_copy, q_ref, rtol=rtol, atol=atol)
numpy_allclose(new_k_copy, k_ref, rtol=rtol, atol=atol)
assert k_target.shape == k_source.shape
assert torch.allclose(k_target, k_source, atol=1e-6, rtol=1e-6)
numpy_allclose(k_target, k_source, rtol=rtol, atol=atol)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
if dtype == torch.float16:
# The CUDA fp16 high_precision path turns out to be more accurate than plain torch fp16, so the tolerance is relaxed here to let the comparison between the two fp16 paths pass.
rtol = 1e-3
atol = 1e-1
inference_ops.rotary_embedding(new_q_fp16, new_k_fp16, cos, sin, False)
numpy_allclose(new_q_copy, new_q_fp16, rtol=rtol, atol=atol)
numpy_allclose(new_k_copy, new_k_fp16, rtol=rtol, atol=atol)
if __name__ == "__main__":
test_rotary_emb(16, 512, 4, 128, torch.float16)
test_rotary_emb(16, 64, 4, 128, torch.float16)