mirror of https://github.com/hpcaitech/ColossalAI
[Inference]Support FP16/BF16 Flash Attention 2 And Add high_precision Flag To Rotary Embedding (#5461)
* Support FP16/BF16 Flash Attention 2 * fix bugs in test_kv_cache_memcpy.py * add context_kv_cache_memcpy_kernel.cu * rm typename MT * add tail process * add high_precision * add high_precision to config.py * rm unused code * change the comment for the high_precision parameter * update test_rotary_embdding_unpad.py * fix vector_copy_utils.h * add comment for self.high_precision when using float32pull/5434/head^2
parent
7ff42cc06d
commit
87079cffe8
|
@ -55,7 +55,7 @@ class InferenceConfig:
|
|||
pp_size (int): Pipeline parallel size, defaults to 1.
|
||||
micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
|
||||
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
|
||||
|
||||
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
|
||||
"""
|
||||
|
||||
# NOTE: arrange configs according to their importance and frequency of usage
|
||||
|
@ -89,6 +89,7 @@ class InferenceConfig:
|
|||
pp_size: int = 1
|
||||
micro_batch_size: int = 1
|
||||
micro_batch_buffer_size: int = None
|
||||
high_precision: Optional[bool] = False
|
||||
|
||||
def __post_init__(self):
|
||||
self._verify_config()
|
||||
|
@ -108,6 +109,10 @@ class InferenceConfig:
|
|||
self.dtype in _ALLOWED_DTYPES
|
||||
), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"
|
||||
|
||||
# skip using casting when the data type is float32
|
||||
if self.dtype == torch.float32:
|
||||
self.high_precision = False
|
||||
|
||||
# check distributed
|
||||
assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
|
||||
self.tp_size * self.pp_size == dist.get_world_size()
|
||||
|
|
|
@ -56,6 +56,7 @@ class InferenceEngine:
|
|||
self.tokenizer = tokenizer
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
self.generation_config = inference_config.to_generation_config(self.model_config)
|
||||
self.high_precision = inference_config.high_precision
|
||||
model = model.eval()
|
||||
model = model.cuda()
|
||||
model.to(self.dtype)
|
||||
|
@ -297,6 +298,7 @@ class InferenceEngine:
|
|||
batch,
|
||||
self.k_cahce,
|
||||
self.v_cache,
|
||||
self.high_precision,
|
||||
)
|
||||
|
||||
if self.inference_config.pad_input:
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaConfig,
|
||||
|
@ -30,24 +31,28 @@ inference_ops = InferenceOpsLoader().load()
|
|||
logger = get_dist_logger(__name__)
|
||||
|
||||
try:
|
||||
HAS_TRITON = True
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
use_flash_attn2 = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.")
|
||||
use_flash_attn2 = False
|
||||
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
|
||||
|
||||
|
||||
def llama_causal_lm_forward(
|
||||
self: LlamaForCausalLM,
|
||||
batch: BatchBucket = None,
|
||||
k_caches: List[torch.Tensor] = None,
|
||||
v_caches: List[torch.Tensor] = None,
|
||||
batch: BatchBucket,
|
||||
k_caches: List[torch.Tensor],
|
||||
v_caches: List[torch.Tensor],
|
||||
high_precision: bool = False,
|
||||
):
|
||||
"""This function will replace the forward function of LlamaForCausalLM.
|
||||
|
||||
Args:
|
||||
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
|
||||
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
|
||||
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
|
||||
batch (BatchInfo): It stores the necessary input information for this inference.
|
||||
k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
|
||||
v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
|
||||
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
|
||||
"""
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
|
@ -56,6 +61,7 @@ def llama_causal_lm_forward(
|
|||
batch=batch,
|
||||
k_caches=k_caches,
|
||||
v_caches=v_caches,
|
||||
high_precision=high_precision,
|
||||
)
|
||||
logits = torch.mm(hidden_states, self.lm_head.weight)
|
||||
return logits
|
||||
|
@ -63,16 +69,18 @@ def llama_causal_lm_forward(
|
|||
|
||||
def llama_model_forward(
|
||||
self: LlamaModel,
|
||||
batch: BatchBucket = None,
|
||||
k_caches: List[torch.Tensor] = None,
|
||||
v_caches: List[torch.Tensor] = None,
|
||||
batch: BatchBucket,
|
||||
k_caches: List[torch.Tensor],
|
||||
v_caches: List[torch.Tensor],
|
||||
high_precision: bool = False,
|
||||
):
|
||||
"""This function will replace the forward function of LlamaModel.
|
||||
|
||||
Args:
|
||||
batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None.
|
||||
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
|
||||
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
|
||||
batch (BatchInfo): It stores the necessary input information for this inference.
|
||||
k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
|
||||
v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
|
||||
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
|
||||
"""
|
||||
input_ids = batch.get_1D_inputs()
|
||||
block_tables = batch.get_block_table_tensor()
|
||||
|
@ -86,6 +94,11 @@ def llama_model_forward(
|
|||
if batch_size >= 32 and kv_seq_len > 512:
|
||||
use_cuda_kernel = False
|
||||
|
||||
if use_cuda_kernel and batch.dtype != torch.float32 and use_flash_attn2:
|
||||
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
|
||||
else:
|
||||
cu_seqlens = None
|
||||
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)
|
||||
|
@ -110,15 +123,17 @@ def llama_model_forward(
|
|||
block_tables=block_tables,
|
||||
k_cache=k_caches[layer_id],
|
||||
v_cache=v_caches[layer_id],
|
||||
is_prompts=batch.is_prompts,
|
||||
sequence_lengths=sequence_lengths,
|
||||
kv_seq_len=kv_seq_len,
|
||||
cos_sin=cos_sin,
|
||||
fd_inter_tensor=batch.fd_inter_tensor,
|
||||
is_prompts=batch.is_prompts,
|
||||
kv_seq_len=kv_seq_len,
|
||||
output_tensor=output_tensor,
|
||||
norm_output=norm_output,
|
||||
sm_scale=sm_scale,
|
||||
use_cuda_kernel=use_cuda_kernel,
|
||||
cu_seqlens=cu_seqlens,
|
||||
high_precision=high_precision,
|
||||
)
|
||||
|
||||
if batch.is_prompts:
|
||||
|
@ -135,38 +150,42 @@ def llama_decoder_layer_forward(
|
|||
self: LlamaDecoderLayer,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
block_tables: torch.Tensor = None,
|
||||
k_cache: torch.Tensor = None,
|
||||
v_cache: torch.Tensor = None,
|
||||
block_tables: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
cos_sin: Tuple[torch.Tensor],
|
||||
fd_inter_tensor: FDIntermTensors,
|
||||
is_prompts: bool = True,
|
||||
sequence_lengths: torch.Tensor = None,
|
||||
kv_seq_len: int = 0,
|
||||
cos_sin: Tuple[torch.Tensor] = None,
|
||||
fd_inter_tensor: FDIntermTensors = None,
|
||||
output_tensor: torch.Tensor = None,
|
||||
norm_output: torch.Tensor = None,
|
||||
sm_scale: int = None,
|
||||
use_cuda_kernel: bool = True,
|
||||
cu_seqlens: torch.Tensor = None,
|
||||
high_precision: bool = False,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""This function will replace the forward function of LlamaDecoderLayer.
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
|
||||
residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
|
||||
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
|
||||
storing mapping of token_position_id -> block_id. Defaults to None.
|
||||
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
|
||||
v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
|
||||
block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
|
||||
storing mapping of token_position_id -> block_id.
|
||||
k_cache (torch.Tensor): It holds the GPU memory for the key cache.
|
||||
v_cache (torch.Tensor): It holds the GPU memory for the key cache.
|
||||
sequence_lengths (torch.Tensor): Holding the sequence length of each sequence.
|
||||
cos_sin (Tuple[torch.Tensor]): Holding cos and sin.
|
||||
fd_inter_tensor (FDIntermTensors): Holding tensors used for
|
||||
storing intermediate values in flash-decoding.
|
||||
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
|
||||
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
|
||||
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
|
||||
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
|
||||
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
|
||||
storing intermediate values in flash-decoding. Defaults to None.
|
||||
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
|
||||
norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
|
||||
sm_scale (int, optional): Used for flash attention. Defaults to None.
|
||||
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
|
||||
cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
|
||||
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
|
||||
"""
|
||||
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
|
||||
|
@ -176,14 +195,16 @@ def llama_decoder_layer_forward(
|
|||
block_tables=block_tables,
|
||||
k_cache=k_cache,
|
||||
v_cache=v_cache,
|
||||
is_prompts=is_prompts,
|
||||
sequence_lengths=sequence_lengths,
|
||||
kv_seq_len=kv_seq_len,
|
||||
cos_sin=cos_sin,
|
||||
fd_inter_tensor=fd_inter_tensor,
|
||||
is_prompts=is_prompts,
|
||||
kv_seq_len=kv_seq_len,
|
||||
output_tensor=output_tensor,
|
||||
sm_scale=sm_scale,
|
||||
use_cuda_kernel=use_cuda_kernel,
|
||||
cu_seqlens=cu_seqlens,
|
||||
high_precision=high_precision,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
|
@ -277,43 +298,48 @@ class NopadLlamaAttention(LlamaAttention):
|
|||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
block_tables: torch.Tensor = None,
|
||||
k_cache: torch.Tensor = None,
|
||||
v_cache: torch.Tensor = None,
|
||||
block_tables: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
cos_sin: Tuple[torch.Tensor],
|
||||
fd_inter_tensor: FDIntermTensors,
|
||||
is_prompts: bool = True,
|
||||
sequence_lengths: torch.Tensor = None,
|
||||
kv_seq_len: int = 0,
|
||||
cos_sin: Tuple[torch.Tensor] = None,
|
||||
fd_inter_tensor: FDIntermTensors = None,
|
||||
output_tensor: torch.Tensor = None,
|
||||
sm_scale: int = None,
|
||||
use_cuda_kernel: bool = True,
|
||||
cu_seqlens: torch.Tensor = None,
|
||||
high_precision: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
|
||||
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
|
||||
storing mapping of token_position_id -> block_id. Defaults to None.
|
||||
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
|
||||
v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
|
||||
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
|
||||
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
|
||||
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
|
||||
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
|
||||
block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
|
||||
storing mapping of token_position_id -> block_id.
|
||||
k_cache (torch.Tensor): It holds the GPU memory for the key cache.
|
||||
v_cache (torch.Tensor): It holds the GPU memory for the key cache.
|
||||
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence.
|
||||
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin.
|
||||
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
|
||||
storing intermediate values in flash-decoding. Defaults to None.
|
||||
storing intermediate values in flash-decoding.
|
||||
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
|
||||
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
|
||||
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
|
||||
sm_scale (int, optional): Used for flash attention. Defaults to None.
|
||||
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
|
||||
cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
|
||||
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
|
||||
"""
|
||||
|
||||
token_nums = hidden_states.size(0)
|
||||
|
||||
if self.num_heads != self.num_key_value_heads:
|
||||
query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim)
|
||||
key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
|
||||
value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
|
||||
else:
|
||||
# fused qkv
|
||||
token_nums = hidden_states.size(0)
|
||||
hidden_states = hidden_states.expand(3, -1, -1)
|
||||
query_states, key_states, value_states = (
|
||||
torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
|
||||
|
@ -322,23 +348,41 @@ class NopadLlamaAttention(LlamaAttention):
|
|||
block_size = k_cache.size(-2)
|
||||
|
||||
if is_prompts:
|
||||
if use_cuda_kernel:
|
||||
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||
if use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
|
||||
# flash attn 2 currently only supports FP16/BF16.
|
||||
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
|
||||
inference_ops.context_kv_cache_memcpy(
|
||||
key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
|
||||
)
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=kv_seq_len,
|
||||
max_seqlen_k=kv_seq_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=sm_scale,
|
||||
causal=True,
|
||||
)
|
||||
attn_output = attn_output.view(token_nums, -1)
|
||||
else:
|
||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||
attn_output = context_attention_unpadded(
|
||||
q=query_states,
|
||||
k=key_states,
|
||||
v=value_states,
|
||||
k_cache=k_cache,
|
||||
v_cache=v_cache,
|
||||
context_lengths=sequence_lengths,
|
||||
block_tables=block_tables,
|
||||
block_size=block_size,
|
||||
output=output_tensor,
|
||||
max_seq_len=kv_seq_len,
|
||||
sm_scale=sm_scale,
|
||||
)
|
||||
attn_output = context_attention_unpadded(
|
||||
q=query_states,
|
||||
k=key_states,
|
||||
v=value_states,
|
||||
k_cache=k_cache,
|
||||
v_cache=v_cache,
|
||||
context_lengths=sequence_lengths,
|
||||
block_tables=block_tables,
|
||||
block_size=block_size,
|
||||
output=output_tensor,
|
||||
max_seq_len=kv_seq_len,
|
||||
sm_scale=sm_scale,
|
||||
)
|
||||
else:
|
||||
if use_cuda_kernel:
|
||||
inference_ops.rotary_embedding_and_cache_copy(
|
||||
|
@ -351,6 +395,7 @@ class NopadLlamaAttention(LlamaAttention):
|
|||
v_cache,
|
||||
sequence_lengths,
|
||||
block_tables,
|
||||
high_precision,
|
||||
)
|
||||
else:
|
||||
decoding_fused_rotary_embedding(
|
||||
|
@ -436,6 +481,5 @@ class NopadLlamaMLP(LlamaMLP):
|
|||
"""
|
||||
hidden_states = hidden_states.expand(2, -1, -1)
|
||||
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
|
||||
act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True)
|
||||
tmp_out = act_out * gate_up_proj_out[1]
|
||||
return torch.mm(tmp_out, self.down_proj_weight)
|
||||
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
|
||||
return torch.mm(act_out, self.down_proj_weight)
|
||||
|
|
|
@ -136,7 +136,8 @@ def benchmark_inference(args):
|
|||
|
||||
data = data_gen(mbsz, args.seq_len)
|
||||
|
||||
data = data.tolist()
|
||||
if args.mode == "colossalai" or args.mode == "vllm":
|
||||
data = data.tolist()
|
||||
|
||||
generation_config = GenerationConfig(
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
|
|
|
@ -56,6 +56,23 @@
|
|||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
|
||||
TYPE, NAME, ...) \
|
||||
switch (HIGH_PRECISION) { \
|
||||
case false: { \
|
||||
const bool high_precision = false; \
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
|
||||
break; \
|
||||
} \
|
||||
case true: { \
|
||||
const bool high_precision = true; \
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR("HIGH_PRECISION must be bool, but get ", HIGH_PRECISION, "."); \
|
||||
}
|
||||
|
||||
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
||||
switch (TYPEIN) { \
|
||||
case at::ScalarType::Float: { \
|
||||
|
|
|
@ -27,5 +27,18 @@ struct MPTypeTrait<at::BFloat16> {
|
|||
using Type = float;
|
||||
};
|
||||
|
||||
template <bool high_precision, typename scalar_t>
|
||||
struct ScalarTypeTrait;
|
||||
|
||||
template <typename T>
|
||||
struct ScalarTypeTrait<true, T> {
|
||||
using Type = typename MPTypeTrait<T>::Type;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ScalarTypeTrait<false, T> {
|
||||
using Type = T;
|
||||
};
|
||||
|
||||
} // namespace common
|
||||
} // namespace colossalAI
|
||||
|
|
|
@ -0,0 +1,195 @@
|
|||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "utils/vector_copy_utils.h"
|
||||
#include "../common/micros.h"
|
||||
|
||||
template<typename scalar_t, int VecSize>
|
||||
__global__ void context_kv_cache_memcpy_kernel(
|
||||
const scalar_t* __restrict__ key,
|
||||
const scalar_t* __restrict__ value,
|
||||
scalar_t* __restrict__ key_cache,
|
||||
scalar_t* __restrict__ value_cache,
|
||||
const int* __restrict__ sequence_lengths,
|
||||
const int* __restrict__ cu_seqlens,
|
||||
const int* __restrict__ block_tables,
|
||||
const int head_num,
|
||||
const int head_dim,
|
||||
const int block_size,
|
||||
const int batch_size,
|
||||
const int block_table_stride,
|
||||
const int64_t key_stride,
|
||||
const int64_t value_stride
|
||||
)
|
||||
{
|
||||
const int seq_token_id = blockIdx.x;
|
||||
const int seq_id = blockIdx.y;
|
||||
const int block_id = block_tables[seq_id * block_table_stride + seq_token_id / block_size];
|
||||
|
||||
if ( block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) {
|
||||
return ;
|
||||
}
|
||||
|
||||
const int block_offset = seq_token_id % block_size;
|
||||
const int hidden_size = head_num * head_dim;
|
||||
const int total_token_id = cu_seqlens[seq_id] + seq_token_id;
|
||||
int head_id;
|
||||
int head_offset;
|
||||
int64_t key_src_id;
|
||||
int64_t value_src_id;
|
||||
int64_t target_id;
|
||||
|
||||
int i = threadIdx.x * VecSize;
|
||||
|
||||
for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
|
||||
head_id = i / head_dim;
|
||||
head_offset = i % head_dim;
|
||||
key_src_id = total_token_id * key_stride + i;
|
||||
value_src_id = total_token_id * value_stride + i;
|
||||
target_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
|
||||
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
|
||||
}
|
||||
|
||||
// tail process
|
||||
for (; i < hidden_size; ++i ) {
|
||||
head_id = i / head_dim;
|
||||
head_offset = i % head_dim;
|
||||
key_src_id = total_token_id * key_stride + i;
|
||||
value_src_id = total_token_id * value_stride + i;
|
||||
target_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
key_cache[target_id] = key[key_src_id];
|
||||
value_cache[target_id] = value[value_src_id];
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
void apply_context_kv_cache_memcpy(
|
||||
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& cu_seqlens, // [batch_size + 1]
|
||||
at::Tensor& block_tables, // [batch_size, max_seq_len]
|
||||
int max_seq_len_in_batch)
|
||||
{
|
||||
int num_tokens = key.size(0);
|
||||
int head_num = key.size(1);
|
||||
int head_dim = key.size(2);
|
||||
int block_size = key_cache.size(2);
|
||||
int batch_size = block_tables.size(0);
|
||||
|
||||
int64_t key_stride = key.stride(0);
|
||||
int64_t value_stride = value.stride(0);
|
||||
int block_table_stride = block_tables.stride(0);
|
||||
|
||||
int vec_size = get_vec_size<scalar_t>(key);
|
||||
|
||||
if (head_dim % vec_size != 0) {
|
||||
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
|
||||
vec_size = 1;
|
||||
}
|
||||
|
||||
int thread_nums = head_num * head_dim / vec_size;
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 grid(max_seq_len_in_batch, batch_size);
|
||||
dim3 block(std::min(thread_nums, 512));
|
||||
|
||||
switch (vec_size) {
|
||||
case 1:
|
||||
context_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
sequence_lengths.data_ptr<int>(),
|
||||
cu_seqlens.data_ptr<int>(),
|
||||
block_tables.data_ptr<int>(),
|
||||
head_num,
|
||||
head_dim,
|
||||
block_size,
|
||||
batch_size,
|
||||
block_table_stride,
|
||||
key_stride,
|
||||
value_stride
|
||||
);
|
||||
break;
|
||||
case 2:
|
||||
context_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
sequence_lengths.data_ptr<int>(),
|
||||
cu_seqlens.data_ptr<int>(),
|
||||
block_tables.data_ptr<int>(),
|
||||
head_num,
|
||||
head_dim,
|
||||
block_size,
|
||||
batch_size,
|
||||
block_table_stride,
|
||||
key_stride,
|
||||
value_stride
|
||||
);
|
||||
break;
|
||||
case 4:
|
||||
context_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
sequence_lengths.data_ptr<int>(),
|
||||
cu_seqlens.data_ptr<int>(),
|
||||
block_tables.data_ptr<int>(),
|
||||
head_num,
|
||||
head_dim,
|
||||
block_size,
|
||||
batch_size,
|
||||
block_table_stride,
|
||||
key_stride,
|
||||
value_stride
|
||||
);
|
||||
break;
|
||||
default:
|
||||
AT_ERROR("Unsupported vectorized size ", vec_size);
|
||||
break;
|
||||
}
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
}
|
||||
|
||||
void context_kv_cache_memcpy(
|
||||
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& cu_seqlens, // [batch_size + 1]
|
||||
at::Tensor& block_tables, // [batch_size, max_seq_len]
|
||||
int max_seq_len_in_batch)
|
||||
{
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
key.scalar_type(),
|
||||
"context_kv_cache_memcpy",
|
||||
apply_context_kv_cache_memcpy<scalar_t>(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
sequence_lengths,
|
||||
cu_seqlens,
|
||||
block_tables,
|
||||
max_seq_len_in_batch
|
||||
);)
|
||||
}
|
|
@ -30,7 +30,9 @@ __global__ void decode_kv_cache_memcpy_kernel(
|
|||
return ;
|
||||
}
|
||||
|
||||
for (int i = threadIdx.x * VecSize; i < hidden_size; i += blockDim.x * VecSize) {
|
||||
int i = threadIdx.x * VecSize;
|
||||
|
||||
for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
|
||||
const int head_id = i / head_dim;
|
||||
const int head_offset = i % head_dim;
|
||||
const int64_t key_src_id = seq_id * key_stride + i;
|
||||
|
@ -43,6 +45,19 @@ __global__ void decode_kv_cache_memcpy_kernel(
|
|||
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
|
||||
}
|
||||
|
||||
for (; i < hidden_size; ++i ) {
|
||||
const int head_id = i / head_dim;
|
||||
const int head_offset = i % head_dim;
|
||||
const int64_t key_src_id = seq_id * key_stride + i;
|
||||
const int64_t value_src_id = seq_id * value_stride + i;
|
||||
const int64_t target_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
key_cache[target_id] = key[key_src_id];
|
||||
value_cache[target_id] = value[value_src_id];
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
|
||||
// in transformers source code, huggingface uses fp16 to compute rope so we follow the same precision
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "utils/vector_copy_utils.h"
|
||||
#include "../common/micros.h"
|
||||
#include "../common/mp_type_traits.h"
|
||||
|
||||
template <typename scalar_t, int VecSize>
|
||||
template <typename scalar_t, typename m_scalar_t, int VecSize>
|
||||
__device__ void apply_emb_rotary_compute(
|
||||
scalar_t* __restrict__ src, const scalar_t* __restrict__ cos_ptr,
|
||||
const scalar_t* __restrict__ sin_ptr, const int64_t stride,
|
||||
scalar_t* __restrict__ src, const m_scalar_t* __restrict__ cos_ptr,
|
||||
const m_scalar_t* __restrict__ sin_ptr, const int64_t stride,
|
||||
const int token_id, const int shard_block_size, const int half_head_dim,
|
||||
const int head_num, const int head_dim) {
|
||||
scalar_t x[VecSize];
|
||||
|
@ -30,10 +31,10 @@ __device__ void apply_emb_rotary_compute(
|
|||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VecSize; j++) {
|
||||
out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] -
|
||||
y[j] * sin_ptr[j * 32 + shard_offset];
|
||||
out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] +
|
||||
x[j] * sin_ptr[j * 32 + shard_offset];
|
||||
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
|
||||
static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
|
||||
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
|
||||
static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
|
||||
}
|
||||
|
||||
copy_vector<scalar_t, VecSize>(src + addr_offset, out_x);
|
||||
|
@ -62,10 +63,10 @@ __device__ void apply_kv_memcopy(
|
|||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int VecSize>
|
||||
template <typename scalar_t, typename m_scalar_t, int VecSize>
|
||||
__device__ void cos_sin_memory_access(
|
||||
const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin,
|
||||
scalar_t* cos_ptr, scalar_t* sin_ptr, const int token_id,
|
||||
m_scalar_t* cos_ptr, m_scalar_t* sin_ptr, const int token_id,
|
||||
const int shard_block_size, const int cos_stride, const int sin_stride,
|
||||
const int half_head_dim) {
|
||||
for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) {
|
||||
|
@ -73,16 +74,16 @@ __device__ void cos_sin_memory_access(
|
|||
const int shard_offset = (i % shard_block_size) / VecSize;
|
||||
const int shard_head =
|
||||
(i / shard_block_size) * shard_block_size + i % VecSize * 32;
|
||||
cos_ptr[shard_head + shard_offset] = cos[token_id * cos_stride + i];
|
||||
sin_ptr[shard_head + shard_offset] = sin[token_id * sin_stride + i];
|
||||
cos_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(cos[token_id * cos_stride + i]);
|
||||
sin_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(sin[token_id * sin_stride + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int VecSize>
|
||||
template <typename scalar_t, typename m_scalar_t, int VecSize>
|
||||
__device__ void apply_k_rotary_emb_compute(
|
||||
scalar_t* __restrict__ key, scalar_t* __restrict__ value,
|
||||
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
|
||||
const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr,
|
||||
const m_scalar_t* __restrict__ cos_ptr, const m_scalar_t* __restrict__ sin_ptr,
|
||||
const int* __restrict__ sequence_lengths,
|
||||
const int* __restrict__ block_tables, const int64_t key_stride,
|
||||
const int64_t value_stride, const int token_id,
|
||||
|
@ -120,10 +121,10 @@ __device__ void apply_k_rotary_emb_compute(
|
|||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VecSize; j++) {
|
||||
out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] -
|
||||
y[j] * sin_ptr[j * 32 + shard_offset];
|
||||
out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] +
|
||||
x[j] * sin_ptr[j * 32 + shard_offset];
|
||||
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
|
||||
static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
|
||||
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
|
||||
static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
|
||||
}
|
||||
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x);
|
||||
|
@ -137,7 +138,7 @@ __device__ void apply_k_rotary_emb_compute(
|
|||
block_size, block_offset, head_dim, half_head_dim);
|
||||
}
|
||||
|
||||
template<typename scalar_t, int VecSize>
|
||||
template<typename scalar_t, typename m_scalar_t, int VecSize>
|
||||
__global__ void rotary_embedding_and_cache_copy_kernel(
|
||||
scalar_t* __restrict__ query,
|
||||
scalar_t* __restrict__ key,
|
||||
|
@ -167,21 +168,21 @@ __global__ void rotary_embedding_and_cache_copy_kernel(
|
|||
|
||||
extern __shared__ char shard_ptr[];
|
||||
|
||||
scalar_t *cos_ptr = (scalar_t*)shard_ptr;
|
||||
scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
|
||||
m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
|
||||
m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
|
||||
|
||||
// apply cos_sin memcopy
|
||||
cos_sin_memory_access<scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
|
||||
cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
|
||||
__syncthreads();
|
||||
|
||||
//compute query
|
||||
apply_emb_rotary_compute<scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
|
||||
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
|
||||
|
||||
//compute key and copy kv
|
||||
apply_k_rotary_emb_compute<scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size);
|
||||
apply_k_rotary_emb_compute<scalar_t, m_scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size);
|
||||
}
|
||||
|
||||
template<typename scalar_t, int VecSize>
|
||||
template<typename scalar_t, typename m_scalar_t, int VecSize>
|
||||
__global__ void rotary_embedding_kernel(
|
||||
scalar_t* __restrict__ query,
|
||||
scalar_t* __restrict__ key,
|
||||
|
@ -202,21 +203,21 @@ __global__ void rotary_embedding_kernel(
|
|||
|
||||
extern __shared__ char shard_ptr[];
|
||||
|
||||
scalar_t *cos_ptr = (scalar_t*)shard_ptr;
|
||||
scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
|
||||
m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
|
||||
m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
|
||||
|
||||
// apply cos_sin memcopy
|
||||
cos_sin_memory_access<scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
|
||||
cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
|
||||
__syncthreads();
|
||||
|
||||
//compute query
|
||||
apply_emb_rotary_compute<scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
|
||||
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
|
||||
|
||||
//compute key
|
||||
apply_emb_rotary_compute<scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
|
||||
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
template<typename scalar_t, bool high_precision>
|
||||
void apply_rotary_embedding_and_cache_copy(
|
||||
at::Tensor& query, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
|
||||
|
@ -241,6 +242,8 @@ void apply_rotary_embedding_and_cache_copy(
|
|||
int sin_stride = sin.stride(0);
|
||||
int block_table_stride = block_tables.stride(0);
|
||||
|
||||
using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
|
||||
|
||||
int vec_size = get_vec_size<scalar_t>(query);
|
||||
|
||||
if ((head_dim / 2) % vec_size != 0) {
|
||||
|
@ -259,7 +262,7 @@ void apply_rotary_embedding_and_cache_copy(
|
|||
|
||||
switch (vec_size) {
|
||||
case 1:
|
||||
rotary_embedding_and_cache_copy_kernel<scalar_t, 1><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
|
@ -283,7 +286,7 @@ void apply_rotary_embedding_and_cache_copy(
|
|||
);
|
||||
break;
|
||||
case 2:
|
||||
rotary_embedding_and_cache_copy_kernel<scalar_t, 2><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
|
@ -307,7 +310,7 @@ void apply_rotary_embedding_and_cache_copy(
|
|||
);
|
||||
break;
|
||||
case 4:
|
||||
rotary_embedding_and_cache_copy_kernel<scalar_t, 4><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
|
@ -338,12 +341,12 @@ void apply_rotary_embedding_and_cache_copy(
|
|||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
template<typename scalar_t, bool high_precision>
|
||||
void apply_rotary_embedding(
|
||||
at::Tensor& query, // [total_tokens, head_num, head_dim]
|
||||
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
|
||||
at::Tensor& cos, // [total_tokens, head_dim]
|
||||
at::Tensor& sin // [total_tokens, head_dim]
|
||||
at::Tensor& sin // [total_tokens, head_dim]
|
||||
){
|
||||
int num_tokens = query.size(0);
|
||||
int head_num = query.size(1);
|
||||
|
@ -355,6 +358,8 @@ void apply_rotary_embedding(
|
|||
int cos_stride = cos.stride(0);
|
||||
int sin_stride = sin.stride(0);
|
||||
|
||||
using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
|
||||
|
||||
int vec_size = get_vec_size<scalar_t>(query);
|
||||
|
||||
if ((head_dim / 2) % vec_size != 0) {
|
||||
|
@ -373,7 +378,7 @@ void apply_rotary_embedding(
|
|||
|
||||
switch (vec_size) {
|
||||
case 1:
|
||||
rotary_embedding_kernel<scalar_t, 1><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||
rotary_embedding_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos.data_ptr<scalar_t>(),
|
||||
|
@ -389,7 +394,7 @@ void apply_rotary_embedding(
|
|||
);
|
||||
break;
|
||||
case 2:
|
||||
rotary_embedding_kernel<scalar_t, 2><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||
rotary_embedding_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos.data_ptr<scalar_t>(),
|
||||
|
@ -405,7 +410,7 @@ void apply_rotary_embedding(
|
|||
);
|
||||
break;
|
||||
case 4:
|
||||
rotary_embedding_kernel<scalar_t, 4><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||
rotary_embedding_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos.data_ptr<scalar_t>(),
|
||||
|
@ -436,12 +441,14 @@ void rotary_embedding_and_cache_copy(
|
|||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||
at::Tensor& block_tables, // [batch_size, max_seq_len]
|
||||
bool high_precision)
|
||||
{
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(
|
||||
high_precision,
|
||||
query.scalar_type(),
|
||||
"rotary_embedding_and_cache_copy",
|
||||
apply_rotary_embedding_and_cache_copy<scalar_t>(
|
||||
apply_rotary_embedding_and_cache_copy<scalar_t, high_precision>(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
|
@ -458,12 +465,14 @@ void rotary_embedding(
|
|||
at::Tensor& query, // [total_tokens, head_num, head_dim]
|
||||
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
|
||||
at::Tensor& cos, // [total_tokens, head_dim]
|
||||
at::Tensor& sin // [total_tokens, head_dim]
|
||||
at::Tensor& sin, // [total_tokens, head_dim]
|
||||
bool high_precision
|
||||
){
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(
|
||||
high_precision,
|
||||
query.scalar_type(),
|
||||
"rotary_embedding",
|
||||
apply_rotary_embedding<scalar_t>(
|
||||
apply_rotary_embedding<scalar_t, high_precision>(
|
||||
query,
|
||||
key,
|
||||
cos,
|
||||
|
|
|
@ -9,11 +9,22 @@ void decode_kv_cache_memcpy(
|
|||
torch::Tensor& sequence_lengths, // [batch_size]
|
||||
torch::Tensor& block_tables); // [batch_size, max_seq_len]
|
||||
|
||||
void context_kv_cache_memcpy(
|
||||
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& cu_seqlens, // [batch_size + 1]
|
||||
at::Tensor& block_tables, // [batch_size, max_seq_len]
|
||||
int max_seq_len_in_batch);
|
||||
|
||||
void rotary_embedding(
|
||||
torch::Tensor& query, // [total_tokens, head_num, head_dim]
|
||||
torch::Tensor& key, // [total_tokens, kv_head_num, head_dim]
|
||||
torch::Tensor& cos, // [total_tokens, head_dim]
|
||||
torch::Tensor& sin); // [total_tokens, head_dim]
|
||||
torch::Tensor& sin, // [total_tokens, head_dim]
|
||||
bool high_precision);
|
||||
|
||||
void rotary_embedding_and_cache_copy(
|
||||
torch::Tensor& query, // [num_tokens, head_num, head_dim]
|
||||
|
@ -25,7 +36,9 @@ void rotary_embedding_and_cache_copy(
|
|||
torch::Tensor&
|
||||
value_cache, // [num_blocks, num_heads, block_size, head_dim]
|
||||
torch::Tensor& sequence_lengths, // [batch_size]
|
||||
torch::Tensor& block_tables); // [batch_size, max_seq_len]
|
||||
torch::Tensor& block_tables, // [batch_size, max_seq_len]
|
||||
bool high_precision);
|
||||
|
||||
torch::Tensor silu_and_mul(const torch::Tensor& ins);
|
||||
|
||||
void rms_layernorm(torch::Tensor& out, // [..., hidden_size]
|
||||
|
@ -42,6 +55,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
|
||||
"Copy the GPU memory of kvcache during the decode stage.");
|
||||
|
||||
m.def("context_kv_cache_memcpy", &context_kv_cache_memcpy,
|
||||
"Copy the GPU memory of kvcache during the context stage.");
|
||||
|
||||
m.def(
|
||||
"rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy,
|
||||
"performing Rotary Embedding-related calculations and KVCache Memcopy.");
|
||||
|
|
|
@ -11,6 +11,8 @@
|
|||
#include <cfloat>
|
||||
#include <limits>
|
||||
|
||||
#include "utils/vector_copy_utils.h"
|
||||
|
||||
namespace {
|
||||
|
||||
int log2_ceil(int value) {
|
||||
|
|
|
@ -11,16 +11,16 @@ template <typename T, int VecSize>
|
|||
__device__ __inline__ void copy_vector(T *dst, const T *src) {
|
||||
using VT = typename colossalAI::cuda::utils::VecTypeTrait<T, VecSize>::Type;
|
||||
// Note(LiuYang): Here static_cast can't be used for cast between two pointer
|
||||
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<VT *>(src));
|
||||
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
|
||||
// Since the maximum memory alignment length is 128 bits, we choose float4
|
||||
// here.
|
||||
*(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<float4 *>(src));
|
||||
*(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<const float4 *>(src));
|
||||
*(reinterpret_cast<float4 *>(dst + 4)) =
|
||||
*(reinterpret_cast<float4 *>(src + 4));
|
||||
*(reinterpret_cast<const float4 *>(src + 4));
|
||||
}
|
||||
|
||||
template <typename T, int VecSize>
|
||||
|
|
|
@ -12,6 +12,7 @@ class InferenceOpsCudaExtension(_CudaExtension):
|
|||
for fname in [
|
||||
"cuda/pybind/inference.cpp",
|
||||
"cuda/decode_kv_cache_memcpy_kernel.cu",
|
||||
"cuda/context_kv_cache_memcpy_kernel.cu",
|
||||
"cuda/fused_rotary_emb_and_cache_kernel.cu",
|
||||
"cuda/activation_kernel.cu",
|
||||
"cuda/rms_layernorm_kernel.cu",
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.utils import get_current_device
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2
|
||||
from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data
|
||||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
@ -10,12 +12,7 @@ inference_ops = InferenceOpsLoader().load()
|
|||
HEAD_DIM = 4
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bsz", [4, 7, 32])
|
||||
@pytest.mark.parametrize("block_size", [16, 32, 64])
|
||||
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
|
||||
@pytest.mark.parametrize("num_kv_heads", [16])
|
||||
@pytest.mark.parametrize("same_context_len", [True, False])
|
||||
def test_copy_kv_to_caches(
|
||||
def run_decode_copy_kv_to_caches(
|
||||
bsz: int,
|
||||
block_size: int,
|
||||
max_num_blocks_per_seq: int,
|
||||
|
@ -61,5 +58,65 @@ def test_copy_kv_to_caches(
|
|||
assert torch.equal(v_target, v_source)
|
||||
|
||||
|
||||
def run_context_copy_kv_to_cache(
|
||||
bsz: int,
|
||||
block_size: int,
|
||||
max_num_blocks_per_seq: int,
|
||||
num_kv_heads: int,
|
||||
same_context_len: bool,
|
||||
):
|
||||
torch.manual_seed(123)
|
||||
|
||||
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
|
||||
max_seq_len = max_num_blocks_per_seq * block_size
|
||||
dtype = torch.float16
|
||||
device = get_current_device()
|
||||
|
||||
if same_context_len:
|
||||
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
|
||||
else:
|
||||
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
|
||||
|
||||
num_tokens = torch.sum(context_lengths).item()
|
||||
|
||||
max_seq_len_in_batch = context_lengths.max()
|
||||
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
|
||||
|
||||
kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
|
||||
key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
|
||||
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
|
||||
key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||
)
|
||||
|
||||
block_tables = block_tables.to(device=device)
|
||||
k_cache = torch.zeros_like(k_cache_ref)
|
||||
v_cache = torch.zeros_like(v_cache_ref)
|
||||
|
||||
inference_ops.context_kv_cache_memcpy(
|
||||
key, value, k_cache, v_cache, context_lengths, cu_seqlens, block_tables, max_seq_len_in_batch
|
||||
)
|
||||
|
||||
assert torch.equal(k_cache, k_cache_ref)
|
||||
assert torch.equal(v_cache, v_cache_ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bsz", [4, 7, 32])
|
||||
@pytest.mark.parametrize("block_size", [16, 32, 64])
|
||||
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
|
||||
@pytest.mark.parametrize("num_kv_heads", [16])
|
||||
@pytest.mark.parametrize("same_context_len", [True, False])
|
||||
def test_kv_cache_memcopy(
|
||||
bsz: int,
|
||||
block_size: int,
|
||||
max_num_blocks_per_seq: int,
|
||||
num_kv_heads: int,
|
||||
same_context_len: bool,
|
||||
):
|
||||
run_context_copy_kv_to_cache(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len)
|
||||
run_decode_copy_kv_to_caches(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_copy_kv_to_caches(4, 32, 8, 16, True)
|
||||
test_kv_cache_memcopy(4, 32, 8, 16, True)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
||||
|
@ -10,11 +11,18 @@ from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table
|
|||
from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb
|
||||
|
||||
|
||||
def numpy_allclose(x, y, rtol, atol):
|
||||
x_numpy = x.detach().cpu().numpy()
|
||||
y_numpy = y.detach().cpu().numpy()
|
||||
|
||||
np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("BATCH_SIZE", [4])
|
||||
@pytest.mark.parametrize("SEQ_LEN", [64])
|
||||
@pytest.mark.parametrize("H", [32])
|
||||
@pytest.mark.parametrize("D", [64])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
|
||||
torch.manual_seed(10)
|
||||
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
|
||||
|
@ -54,17 +62,36 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
|
|||
|
||||
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||
block_tables = block_tables.to(device="cuda")
|
||||
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
|
||||
k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
|
||||
|
||||
new_q_copy = new_q.clone()
|
||||
new_k_copy = new_k.clone()
|
||||
|
||||
if dtype == torch.float16:
|
||||
rtol = 1e-3
|
||||
atol = 1e-3
|
||||
|
||||
new_q_fp16 = new_q.clone()
|
||||
new_k_fp16 = new_k.clone()
|
||||
|
||||
high_precision_cos = cos[:BATCH_SIZE].to(torch.float32)
|
||||
high_precision_sin = sin[:BATCH_SIZE].to(torch.float32)
|
||||
high_precision_q = new_q.to(torch.float32)
|
||||
high_precision_k = new_k.to(torch.float32)
|
||||
q_ref = torch_rotary_emb(high_precision_q, high_precision_cos, high_precision_sin).to(torch.float16)
|
||||
k_ref = torch_rotary_emb(high_precision_k, high_precision_cos, high_precision_sin).to(torch.float16)
|
||||
|
||||
else:
|
||||
rtol = 1e-5
|
||||
atol = 1e-7
|
||||
|
||||
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
|
||||
k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
|
||||
|
||||
inference_ops.rotary_embedding_and_cache_copy(
|
||||
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
|
||||
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables, True
|
||||
)
|
||||
|
||||
inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin)
|
||||
inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin, True)
|
||||
|
||||
past_kv_seq_len = kv_seq_lengths - 1
|
||||
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
|
||||
|
@ -74,18 +101,26 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
|
|||
v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze()
|
||||
v_source = new_v.squeeze()
|
||||
|
||||
assert torch.allclose(new_q, q_ref, atol=1e-6, rtol=1e-6)
|
||||
assert torch.allclose(k_target, k_ref, atol=1e-6, rtol=1e-6)
|
||||
numpy_allclose(new_q, q_ref, rtol=rtol, atol=atol)
|
||||
numpy_allclose(k_target, k_ref, rtol=rtol, atol=atol)
|
||||
|
||||
assert torch.allclose(new_q_copy, q_ref, atol=1e-6, rtol=1e-6)
|
||||
assert torch.allclose(new_k_copy, k_ref, atol=1e-6, rtol=1e-6)
|
||||
numpy_allclose(new_q_copy, q_ref, rtol=rtol, atol=atol)
|
||||
numpy_allclose(new_k_copy, k_ref, rtol=rtol, atol=atol)
|
||||
|
||||
assert k_target.shape == k_source.shape
|
||||
assert torch.allclose(k_target, k_source, atol=1e-6, rtol=1e-6)
|
||||
numpy_allclose(k_target, k_source, rtol=rtol, atol=atol)
|
||||
|
||||
assert v_target.shape == v_source.shape
|
||||
assert torch.equal(v_target, v_source)
|
||||
|
||||
if dtype == torch.float16:
|
||||
# After testing cuda fp16 high_precision, it was found to have higher precision than torch fp16. Therefore, the threshold here has been relaxed to pass the test.
|
||||
rtol = 1e-3
|
||||
atol = 1e-1
|
||||
inference_ops.rotary_embedding(new_q_fp16, new_k_fp16, cos, sin, False)
|
||||
numpy_allclose(new_q_copy, new_q_fp16, rtol=rtol, atol=atol)
|
||||
numpy_allclose(new_k_copy, new_k_fp16, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rotary_emb(16, 512, 4, 128, torch.float16)
|
||||
test_rotary_emb(16, 64, 4, 128, torch.float16)
|
||||
|
|
Loading…
Reference in New Issue