[fix] merge conflicts

pull/5434/head
Runyu Lu 2024-03-25 14:48:28 +08:00
commit 68e9396bc0
15 changed files with 544 additions and 132 deletions

View File

@@ -88,7 +88,7 @@ class InferenceConfig:
use_cuda_kernel(bool): Whether to use cuda kernel, faster but lose some precision occasionally
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
# NOTE: arrange configs according to their importance and frequency of usage
@@ -122,6 +122,7 @@ class InferenceConfig:
pp_size: int = 1
micro_batch_size: int = 1
micro_batch_buffer_size: int = None
high_precision: Optional[bool] = False
# cuda kernel option
use_cuda_kernel: bool = False
@@ -149,6 +150,10 @@ class InferenceConfig:
self.dtype in _ALLOWED_DTYPES
), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"
# skip using casting when the data type is float32
if self.dtype == torch.float32:
    self.high_precision = False
# check distributed
assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
self.tp_size * self.pp_size == dist.get_world_size()
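For reference, a minimal usage sketch of the new flag from the configuration side (the import path is assumed from the ColossalAI package layout and is not part of this diff):

import torch
from colossalai.inference.config import InferenceConfig  # path assumed

config = InferenceConfig(
    dtype=torch.float16,
    use_cuda_kernel=True,   # the fp32-accumulation path lives in the CUDA kernels
    high_precision=True,    # do the fp16 rope / kv-cache math in fp32 inside the kernels
)
# As the __post_init__ check above shows, high_precision is reset to False when dtype is torch.float32.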

View File

@@ -61,6 +61,7 @@ class InferenceEngine:
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.generation_config = inference_config.to_generation_config(self.model_config)
self.high_precision = inference_config.high_precision
model = model.eval()
model = model.cuda()
model.to(self.dtype)
@@ -150,8 +151,10 @@ class InferenceEngine:
batch_size=batch_size,
is_prompts=False,
use_cuda_graph=True,
high_precision=False,
kv_seq_len=sequence_lengths[:batch_size].max().item(),
head_dim=head_dim,
dtype=self.dtype,
)
graph_runner = CUDAGraphRunner(self.model)
@@ -391,8 +394,10 @@ class InferenceEngine:
is_prompts=batch.is_prompts,
use_cuda_kernel=self.inference_config.use_cuda_kernel,
use_cuda_graph=use_cuda_graph,
high_precision=self.high_precision,
kv_seq_len=sequence_lengths.max().item(),
head_dim=batch.head_dim,
dtype=batch.dtype,
)
return input_ids, output_tensor, input_meta_data
@@ -421,7 +426,6 @@ class InferenceEngine:
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits)
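As a rough illustration of what the engine now records per forward pass, here is a self-contained stand-in for the metadata fields touched by this change (the real InputMetaData class has more fields and lives elsewhere in ColossalAI; this is only a sketch):

from dataclasses import dataclass
import torch

@dataclass
class InputMetaDataSketch:  # hypothetical stand-in, not the real class
    is_prompts: bool
    use_cuda_kernel: bool
    use_cuda_graph: bool
    high_precision: bool
    kv_seq_len: int
    head_dim: int
    dtype: torch.dtype

# Mirrors the CUDA-graph capture call above, which pins high_precision to False.
meta = InputMetaDataSketch(
    is_prompts=False,
    use_cuda_kernel=True,
    use_cuda_graph=True,
    high_precision=False,
    kv_seq_len=128,
    head_dim=128,
    dtype=torch.float16,
)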

View File

@@ -2,6 +2,7 @@
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
@@ -30,10 +31,12 @@ inference_ops = InferenceOpsLoader().load()
logger = get_dist_logger(__name__)
try:
    from flash_attn import flash_attn_varlen_func
    use_flash_attn2 = True
except ImportError:
    use_flash_attn2 = False
    logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
def llama_causal_lm_forward(
@@ -47,9 +50,10 @@ def llama_causal_lm_forward(
"""This function will replace the forward function of LlamaForCausalLM.
Args:
batch (BatchInfo): It stores the necessary input information for this inference.
k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@@ -61,6 +65,7 @@ def llama_causal_lm_forward(
k_caches=k_caches,
v_caches=v_caches,
use_cuda_kernel=inputmetadata.use_cuda_kernel,  # Note: currently the cuda kernels of layernorm and rotary_embedding_and_cache_copy couldn't pass the unit test, but the triton kernels could
high_precision=inputmetadata.high_precision,
)
logits = torch.mm(hidden_states, self.lm_head.weight)
return logits
@@ -74,13 +79,15 @@ def llama_model_forward(
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
use_cuda_kernel: Optional[bool] = True,
high_precision: bool = False,
) -> torch.Tensor:
"""This function will replace the forward function of LlamaModel.
Args:
batch (BatchInfo): It stores the necessary input information for this inference.
k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
block_tables = inputmetadata.block_tables
sequence_lengths = inputmetadata.sequence_lengths
@@ -94,6 +101,10 @@ def llama_model_forward(
use_cuda_kernel = False
hidden_states = self.embed_tokens(input_tokens_ids)
if use_cuda_kernel and inputmetadata.dtype != torch.float32 and use_flash_attn2:
    cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
else:
    cu_seqlens = None
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
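The cu_seqlens tensor built here is the cumulative-offsets format that flash_attn_varlen_func expects; a small self-contained example of the same F.pad/cumsum pattern (lengths are illustrative only):

import torch
import torch.nn.functional as F

# Three packed sequences of lengths 3, 5 and 2.
sequence_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

# Prepend a zero so that entries i and i+1 bracket the token range of sequence i.
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)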
@@ -111,13 +122,15 @@ def llama_model_forward(
v_cache=v_caches[layer_id],
is_prompts=inputmetadata.is_prompts,
sequence_lengths=sequence_lengths,
cos_sin=cos_sin,
fd_inter_tensor=inputmetadata.fd_inter_tensor,
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
norm_output=norm_output,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
cu_seqlens=cu_seqlens,
high_precision=high_precision,
)
if inputmetadata.is_prompts:
@@ -134,38 +147,42 @@ def llama_decoder_layer_forward(
self: LlamaDecoderLayer,
hidden_states: torch.Tensor,
residual: torch.Tensor,
block_tables: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
sequence_lengths: torch.Tensor,
cos_sin: Tuple[torch.Tensor],
fd_inter_tensor: FDIntermTensors,
is_prompts: bool = True,
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
norm_output: torch.Tensor = None,
sm_scale: int = None,
use_cuda_kernel: bool = True,
cu_seqlens: torch.Tensor = None,
high_precision: bool = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""This function will replace the forward function of LlamaDecoderLayer. """This function will replace the forward function of LlamaDecoderLayer.
Args: Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj. residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None. storing mapping of token_position_id -> block_id.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. k_cache (torch.Tensor): It holds the GPU memory for the key cache.
v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. v_cache (torch.Tensor): It holds the GPU memory for the key cache.
sequence_lengths (torch.Tensor): Holding the sequence length of each sequence.
cos_sin (Tuple[torch.Tensor]): Holding cos and sin.
fd_inter_tensor (FDIntermTensors): Holding tensors used for
storing intermediate values in flash-decoding.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
storing intermediate values in flash-decoding. Defaults to None.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None. norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None. sm_scale (int, optional): Used for flash attention. Defaults to None.
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
""" """
hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel) hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
@@ -175,14 +192,16 @@ def llama_decoder_layer_forward(
block_tables=block_tables,
k_cache=k_cache,
v_cache=v_cache,
sequence_lengths=sequence_lengths,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
is_prompts=is_prompts,
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
cu_seqlens=cu_seqlens,
high_precision=high_precision,
)
# Fully Connected
@@ -276,43 +295,48 @@ class NopadLlamaAttention(LlamaAttention):
def forward(
self,
hidden_states: torch.Tensor,
block_tables: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
sequence_lengths: torch.Tensor,
cos_sin: Tuple[torch.Tensor],
fd_inter_tensor: FDIntermTensors,
is_prompts: bool = True,
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
use_cuda_kernel: bool = True,
cu_seqlens: torch.Tensor = None,
high_precision: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
""" """
Args: Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None. storing mapping of token_position_id -> block_id.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. k_cache (torch.Tensor): It holds the GPU memory for the key cache.
v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. v_cache (torch.Tensor): It holds the GPU memory for the key cache.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None. cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
storing intermediate values in flash-decoding. Defaults to None. storing intermediate values in flash-decoding.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None. sm_scale (int, optional): Used for flash attention. Defaults to None.
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
""" """
token_nums = hidden_states.size(0)
if self.num_heads != self.num_key_value_heads:
query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim)
key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
else:
# fused qkv
hidden_states = hidden_states.expand(3, -1, -1)
query_states, key_states, value_states = (
torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
@@ -321,23 +345,41 @@ class NopadLlamaAttention(LlamaAttention):
block_size = k_cache.size(-2)
if is_prompts:
if use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
# flash attn 2 currently only supports FP16/BF16.
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
inference_ops.context_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
)
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=kv_seq_len,
max_seqlen_k=kv_seq_len,
dropout_p=0.0,
softmax_scale=sm_scale,
causal=True,
)
attn_output = attn_output.view(token_nums, -1)
else:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
else:
if use_cuda_kernel:
inference_ops.rotary_embedding_and_cache_copy(
@@ -350,6 +392,7 @@ class NopadLlamaAttention(LlamaAttention):
v_cache,
sequence_lengths,
block_tables,
high_precision,
)
else:
decoding_fused_rotary_embedding(
@@ -435,6 +478,5 @@ class NopadLlamaMLP(LlamaMLP):
"""
hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
return torch.mm(act_out, self.down_proj_weight)
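For readers unfamiliar with the fused op, the two lines it replaces can be written as a plain-PyTorch reference; a sketch of the computation silu_and_mul performs on the stacked gate/up projections (shapes illustrative):

import torch
import torch.nn.functional as F

def silu_and_mul_reference(gate_up: torch.Tensor) -> torch.Tensor:
    # gate_up[0] is the gate projection, gate_up[1] the up projection,
    # exactly as produced by the batched matmul above.
    return F.silu(gate_up[0]) * gate_up[1]

gate_up = torch.randn(2, 4, 8)  # [2, num_tokens, intermediate_size]
print(silu_and_mul_reference(gate_up).shape)  # torch.Size([4, 8])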

View File

@@ -136,7 +136,8 @@ def benchmark_inference(args):
data = data_gen(mbsz, args.seq_len)
if args.mode == "colossalai" or args.mode == "vllm":
    data = data.tolist()
generation_config = GenerationConfig(
pad_token_id=tokenizer.pad_token_id,

View File

@@ -56,6 +56,23 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
TYPE, NAME, ...) \
switch (HIGH_PRECISION) { \
case false: { \
const bool high_precision = false; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
break; \
} \
case true: { \
const bool high_precision = true; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
break; \
} \
default: \
AT_ERROR("HIGH_PRECISION must be bool, but get ", HIGH_PRECISION, "."); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \

View File

@@ -27,5 +27,18 @@ struct MPTypeTrait<at::BFloat16> {
using Type = float;
};
template <bool high_precision, typename scalar_t>
struct ScalarTypeTrait;
template <typename T>
struct ScalarTypeTrait<true, T> {
using Type = typename MPTypeTrait<T>::Type;
};
template <typename T>
struct ScalarTypeTrait<false, T> {
using Type = T;
};
} // namespace common
} // namespace colossalAI
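A rough Python analogue of what the ScalarTypeTrait/MPTypeTrait pair selects at compile time (illustration only, not the C++ implementation):

import torch

def compute_dtype(input_dtype: torch.dtype, high_precision: bool) -> torch.dtype:
    # high_precision=True promotes half types to float for the intermediate math,
    # mirroring ScalarTypeTrait<true, T>::Type = MPTypeTrait<T>::Type.
    if high_precision and input_dtype in (torch.float16, torch.bfloat16):
        return torch.float32
    return input_dtype

assert compute_dtype(torch.float16, True) is torch.float32
assert compute_dtype(torch.float16, False) is torch.float16
assert compute_dtype(torch.float32, True) is torch.float32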

View File

@@ -0,0 +1,195 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "utils/vector_copy_utils.h"
#include "../common/micros.h"
template<typename scalar_t, int VecSize>
__global__ void context_kv_cache_memcpy_kernel(
const scalar_t* __restrict__ key,
const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache,
scalar_t* __restrict__ value_cache,
const int* __restrict__ sequence_lengths,
const int* __restrict__ cu_seqlens,
const int* __restrict__ block_tables,
const int head_num,
const int head_dim,
const int block_size,
const int batch_size,
const int block_table_stride,
const int64_t key_stride,
const int64_t value_stride
)
{
const int seq_token_id = blockIdx.x;
const int seq_id = blockIdx.y;
const int block_id = block_tables[seq_id * block_table_stride + seq_token_id / block_size];
if ( block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) {
return ;
}
const int block_offset = seq_token_id % block_size;
const int hidden_size = head_num * head_dim;
const int total_token_id = cu_seqlens[seq_id] + seq_token_id;
int head_id;
int head_offset;
int64_t key_src_id;
int64_t value_src_id;
int64_t target_id;
int i = threadIdx.x * VecSize;
for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
head_id = i / head_dim;
head_offset = i % head_dim;
key_src_id = total_token_id * key_stride + i;
value_src_id = total_token_id * value_stride + i;
target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
}
// tail process
for (; i < hidden_size; ++i ) {
head_id = i / head_dim;
head_offset = i % head_dim;
key_src_id = total_token_id * key_stride + i;
value_src_id = total_token_id * value_stride + i;
target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
key_cache[target_id] = key[key_src_id];
value_cache[target_id] = value[value_src_id];
}
}
template<typename scalar_t>
void apply_context_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& cu_seqlens, // [batch_size + 1]
at::Tensor& block_tables, // [batch_size, max_seq_len]
int max_seq_len_in_batch)
{
int num_tokens = key.size(0);
int head_num = key.size(1);
int head_dim = key.size(2);
int block_size = key_cache.size(2);
int batch_size = block_tables.size(0);
int64_t key_stride = key.stride(0);
int64_t value_stride = value.stride(0);
int block_table_stride = block_tables.stride(0);
int vec_size = get_vec_size<scalar_t>(key);
if (head_dim % vec_size != 0) {
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
vec_size = 1;
}
int thread_nums = head_num * head_dim / vec_size;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(max_seq_len_in_batch, batch_size);
dim3 block(std::min(thread_nums, 512));
switch (vec_size) {
case 1:
context_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
cu_seqlens.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
batch_size,
block_table_stride,
key_stride,
value_stride
);
break;
case 2:
context_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
cu_seqlens.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
batch_size,
block_table_stride,
key_stride,
value_stride
);
break;
case 4:
context_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
cu_seqlens.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
batch_size,
block_table_stride,
key_stride,
value_stride
);
break;
default:
AT_ERROR("Unsupported vectorized size ", vec_size);
break;
}
AT_CUDA_CHECK(cudaGetLastError());
}
void context_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& cu_seqlens, // [batch_size + 1]
at::Tensor& block_tables, // [batch_size, max_seq_len]
int max_seq_len_in_batch)
{
DISPATCH_FLOAT_HALF_AND_BFLOAT(
key.scalar_type(),
"context_kv_cache_memcpy",
apply_context_kv_cache_memcpy<scalar_t>(
key,
value,
key_cache,
value_cache,
sequence_lengths,
cu_seqlens,
block_tables,
max_seq_len_in_batch
);)
}
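To sanity-check the target_id arithmetic in the kernel above, a small PyTorch sketch that reproduces the flat index of key_cache[block_id, head_id, block_offset, head_offset] for the [num_blocks, head_num, block_size, head_dim] layout (sizes are arbitrary):

import torch

num_blocks, head_num, block_size, head_dim = 3, 2, 4, 8
key_cache = torch.arange(num_blocks * head_num * block_size * head_dim).reshape(
    num_blocks, head_num, block_size, head_dim
)

block_id, head_id, block_offset, head_offset = 1, 1, 2, 5
hidden_size = head_num * head_dim

# Same formula as target_id in context_kv_cache_memcpy_kernel.
target_id = (
    block_id * hidden_size * block_size
    + head_id * block_size * head_dim
    + block_offset * head_dim
    + head_offset
)
assert key_cache.flatten()[target_id] == key_cache[block_id, head_id, block_offset, head_offset]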

View File

@@ -30,7 +30,9 @@ __global__ void decode_kv_cache_memcpy_kernel(
return ;
}
int i = threadIdx.x * VecSize;
for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
const int head_id = i / head_dim;
const int head_offset = i % head_dim;
const int64_t key_src_id = seq_id * key_stride + i;
@@ -43,6 +45,19 @@
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
}
for (; i < hidden_size; ++i ) {
const int head_id = i / head_dim;
const int head_offset = i % head_dim;
const int64_t key_src_id = seq_id * key_stride + i;
const int64_t value_src_id = seq_id * value_stride + i;
const int64_t target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
key_cache[target_id] = key[key_src_id];
value_cache[target_id] = value[value_src_id];
}
}
template<typename scalar_t>

View File

@@ -1,14 +1,15 @@
// in transformers source code, huggingface uses fp16 to compute rope so we follow the same precision
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "utils/vector_copy_utils.h"
#include "../common/micros.h"
#include "../common/mp_type_traits.h"
template <typename scalar_t, typename m_scalar_t, int VecSize>
__device__ void apply_emb_rotary_compute(
scalar_t* __restrict__ src, const m_scalar_t* __restrict__ cos_ptr,
const m_scalar_t* __restrict__ sin_ptr, const int64_t stride,
const int token_id, const int shard_block_size, const int half_head_dim,
const int head_num, const int head_dim) {
scalar_t x[VecSize];
@@ -30,10 +31,10 @@ __device__ void apply_emb_rotary_compute(
#pragma unroll
for (int j = 0; j < VecSize; j++) {
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
}
copy_vector<scalar_t, VecSize>(src + addr_offset, out_x);
@@ -62,10 +63,10 @@ __device__ void apply_kv_memcopy(
}
}
template <typename scalar_t, typename m_scalar_t, int VecSize>
__device__ void cos_sin_memory_access(
const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin,
m_scalar_t* cos_ptr, m_scalar_t* sin_ptr, const int token_id,
const int shard_block_size, const int cos_stride, const int sin_stride,
const int half_head_dim) {
for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) {
@@ -73,16 +74,16 @@ __device__ void cos_sin_memory_access(
const int shard_offset = (i % shard_block_size) / VecSize;
const int shard_head =
(i / shard_block_size) * shard_block_size + i % VecSize * 32;
cos_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(cos[token_id * cos_stride + i]);
sin_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(sin[token_id * sin_stride + i]);
}
}
template <typename scalar_t, typename m_scalar_t, int VecSize>
__device__ void apply_k_rotary_emb_compute(
scalar_t* __restrict__ key, scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const m_scalar_t* __restrict__ cos_ptr, const m_scalar_t* __restrict__ sin_ptr,
const int* __restrict__ sequence_lengths,
const int* __restrict__ block_tables, const int64_t key_stride,
const int64_t value_stride, const int token_id,
@@ -120,10 +121,10 @@ __device__ void apply_k_rotary_emb_compute(
#pragma unroll
for (int j = 0; j < VecSize; j++) {
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
}
copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x);
@@ -137,7 +138,7 @@ __device__ void apply_k_rotary_emb_compute(
block_size, block_offset, head_dim, half_head_dim);
}
template<typename scalar_t, typename m_scalar_t, int VecSize>
__global__ void rotary_embedding_and_cache_copy_kernel(
scalar_t* __restrict__ query,
scalar_t* __restrict__ key,
@@ -167,21 +168,21 @@ __global__ void rotary_embedding_and_cache_copy_kernel(
extern __shared__ char shard_ptr[];
m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
// apply cos_sin memcopy
cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
__syncthreads();
//compute query
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
//compute key and copy kv
apply_k_rotary_emb_compute<scalar_t, m_scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size);
}
template<typename scalar_t, typename m_scalar_t, int VecSize>
__global__ void rotary_embedding_kernel(
scalar_t* __restrict__ query,
scalar_t* __restrict__ key,
@@ -202,21 +203,21 @@ __global__ void rotary_embedding_kernel(
extern __shared__ char shard_ptr[];
m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
// apply cos_sin memcopy
cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
__syncthreads();
//compute query
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
//compute key
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
}
template<typename scalar_t, bool high_precision>
void apply_rotary_embedding_and_cache_copy(
at::Tensor& query, // [num_tokens, head_num, head_dim]
at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
@@ -241,6 +242,8 @@ void apply_rotary_embedding_and_cache_copy(
int sin_stride = sin.stride(0);
int block_table_stride = block_tables.stride(0);
using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
int vec_size = get_vec_size<scalar_t>(query);
if ((head_dim / 2) % vec_size != 0) {
@@ -259,7 +262,7 @@ void apply_rotary_embedding_and_cache_copy(
switch (vec_size) {
case 1:
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
@@ -283,7 +286,7 @@ void apply_rotary_embedding_and_cache_copy(
);
break;
case 2:
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
@@ -307,7 +310,7 @@
);
break;
case 4:
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
@@ -338,12 +341,12 @@ void apply_rotary_embedding_and_cache_copy(
AT_CUDA_CHECK(cudaGetLastError());
}
template<typename scalar_t, bool high_precision>
void apply_rotary_embedding(
at::Tensor& query, // [total_tokens, head_num, head_dim]
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
at::Tensor& cos, // [total_tokens, head_dim]
at::Tensor& sin // [total_tokens, head_dim]
){
int num_tokens = query.size(0);
int head_num = query.size(1);
@@ -355,6 +358,8 @@ void apply_rotary_embedding(
int cos_stride = cos.stride(0);
int sin_stride = sin.stride(0);
using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
int vec_size = get_vec_size<scalar_t>(query);
if ((head_dim / 2) % vec_size != 0) {
@@ -373,7 +378,7 @@ void apply_rotary_embedding(
switch (vec_size) {
case 1:
rotary_embedding_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
@@ -389,7 +394,7 @@
);
break;
case 2:
rotary_embedding_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
@@ -405,7 +410,7 @@
);
break;
case 4:
rotary_embedding_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
@@ -436,12 +441,14 @@ void rotary_embedding_and_cache_copy(
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables, // [batch_size, max_seq_len]
bool high_precision)
{
DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(
high_precision,
query.scalar_type(),
"rotary_embedding_and_cache_copy",
apply_rotary_embedding_and_cache_copy<scalar_t, high_precision>(
query,
key,
value,
@@ -458,12 +465,14 @@ void rotary_embedding(
at::Tensor& query, // [total_tokens, head_num, head_dim]
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
at::Tensor& cos, // [total_tokens, head_dim]
at::Tensor& sin, // [total_tokens, head_dim]
bool high_precision
){
DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(
high_precision,
query.scalar_type(),
"rotary_embedding",
apply_rotary_embedding<scalar_t, high_precision>(
query,
key,
cos,

View File

@@ -9,11 +9,22 @@ void decode_kv_cache_memcpy(
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables); // [batch_size, max_seq_len]
void context_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& cu_seqlens, // [batch_size + 1]
at::Tensor& block_tables, // [batch_size, max_seq_len]
int max_seq_len_in_batch);
void rotary_embedding(
torch::Tensor& query, // [total_tokens, head_num, head_dim]
torch::Tensor& key, // [total_tokens, kv_head_num, head_dim]
torch::Tensor& cos, // [total_tokens, head_dim]
torch::Tensor& sin, // [total_tokens, head_dim]
bool high_precision);
void rotary_embedding_and_cache_copy(
torch::Tensor& query, // [num_tokens, head_num, head_dim]
@@ -25,7 +36,9 @@ void rotary_embedding_and_cache_copy(
torch::Tensor&
value_cache, // [num_blocks, num_heads, block_size, head_dim]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables, // [batch_size, max_seq_len]
bool high_precision);
torch::Tensor silu_and_mul(const torch::Tensor& ins);
void rms_layernorm(torch::Tensor& out, // [..., hidden_size]
@@ -42,6 +55,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
"Copy the GPU memory of kvcache during the decode stage.");
m.def("context_kv_cache_memcpy", &context_kv_cache_memcpy,
"Copy the GPU memory of kvcache during the context stage.");
m.def(
"rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy,
"performing Rotary Embedding-related calculations and KVCache Memcopy.");

View File

@@ -11,6 +11,8 @@
#include <cfloat>
#include <limits>
#include "utils/vector_copy_utils.h"
namespace {
int log2_ceil(int value) {

View File

@@ -11,16 +11,16 @@ template <typename T, int VecSize>
__device__ __inline__ void copy_vector(T *dst, const T *src) {
using VT = typename colossalAI::cuda::utils::VecTypeTrait<T, VecSize>::Type;
// Note(LiuYang): Here static_cast can't be used for cast between two pointer
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
}
template <>
__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
// Since the maximum memory alignment length is 128 bits, we choose float4
// here.
*(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<const float4 *>(src));
*(reinterpret_cast<float4 *>(dst + 4)) =
*(reinterpret_cast<const float4 *>(src + 4));
}
template <typename T, int VecSize>

View File

@@ -12,6 +12,7 @@ class InferenceOpsCudaExtension(_CudaExtension):
for fname in [
"cuda/pybind/inference.cpp",
"cuda/decode_kv_cache_memcpy_kernel.cu",
"cuda/context_kv_cache_memcpy_kernel.cu",
"cuda/fused_rotary_emb_and_cache_kernel.cu",
"cuda/activation_kernel.cu",
"cuda/rms_layernorm_kernel.cu",

View File

@@ -1,8 +1,10 @@
import pytest
import torch
import torch.nn.functional as F
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2
from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data
inference_ops = InferenceOpsLoader().load()
@@ -10,12 +12,7 @@ inference_ops = InferenceOpsLoader().load()
HEAD_DIM = 4
def run_decode_copy_kv_to_caches(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
@@ -61,5 +58,65 @@ def test_copy_kv_to_caches(
assert torch.equal(v_target, v_source)
def run_context_copy_kv_to_cache(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
):
torch.manual_seed(123)
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
max_seq_len = max_num_blocks_per_seq * block_size
dtype = torch.float16
device = get_current_device()
if same_context_len:
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
else:
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
num_tokens = torch.sum(context_lengths).item()
max_seq_len_in_batch = context_lengths.max()
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.int32), (1, 0))
kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
block_tables = block_tables.to(device=device)
k_cache = torch.zeros_like(k_cache_ref)
v_cache = torch.zeros_like(v_cache_ref)
inference_ops.context_kv_cache_memcpy(
key, value, k_cache, v_cache, context_lengths, cu_seqlens, block_tables, max_seq_len_in_batch
)
assert torch.equal(k_cache, k_cache_ref)
assert torch.equal(v_cache, v_cache_ref)
@pytest.mark.parametrize("bsz", [4, 7, 32])
@pytest.mark.parametrize("block_size", [16, 32, 64])
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
@pytest.mark.parametrize("num_kv_heads", [16])
@pytest.mark.parametrize("same_context_len", [True, False])
def test_kv_cache_memcopy(
bsz: int,
block_size: int,
max_num_blocks_per_seq: int,
num_kv_heads: int,
same_context_len: bool,
):
run_context_copy_kv_to_cache(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len)
run_decode_copy_kv_to_caches(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len)
if __name__ == "__main__":
test_kv_cache_memcopy(4, 32, 8, 16, True)

View File

@@ -1,3 +1,4 @@
import numpy as np
import pytest
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
@@ -10,11 +11,18 @@ from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table
from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb
def numpy_allclose(x, y, rtol, atol):
x_numpy = x.detach().cpu().numpy()
y_numpy = y.detach().cpu().numpy()
np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
@pytest.mark.parametrize("BATCH_SIZE", [4]) @pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("SEQ_LEN", [64]) @pytest.mark.parametrize("SEQ_LEN", [64])
@pytest.mark.parametrize("H", [32]) @pytest.mark.parametrize("H", [32])
@pytest.mark.parametrize("D", [64]) @pytest.mark.parametrize("D", [64])
@pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype): def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
torch.manual_seed(10) torch.manual_seed(10)
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
@ -54,17 +62,36 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
kv_seq_lengths = past_kv_seq_lengths + 1 kv_seq_lengths = past_kv_seq_lengths + 1
block_tables = block_tables.to(device="cuda") block_tables = block_tables.to(device="cuda")
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
new_q_copy = new_q.clone() new_q_copy = new_q.clone()
new_k_copy = new_k.clone() new_k_copy = new_k.clone()
if dtype == torch.float16:
rtol = 1e-3
atol = 1e-3
new_q_fp16 = new_q.clone()
new_k_fp16 = new_k.clone()
high_precision_cos = cos[:BATCH_SIZE].to(torch.float32)
high_precision_sin = sin[:BATCH_SIZE].to(torch.float32)
high_precision_q = new_q.to(torch.float32)
high_precision_k = new_k.to(torch.float32)
q_ref = torch_rotary_emb(high_precision_q, high_precision_cos, high_precision_sin).to(torch.float16)
k_ref = torch_rotary_emb(high_precision_k, high_precision_cos, high_precision_sin).to(torch.float16)
else:
rtol = 1e-5
atol = 1e-7
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
inference_ops.rotary_embedding_and_cache_copy(
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables, True
)
inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin, True)
past_kv_seq_len = kv_seq_lengths - 1
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
@@ -74,18 +101,26 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze()
v_source = new_v.squeeze()
numpy_allclose(new_q, q_ref, rtol=rtol, atol=atol)
numpy_allclose(k_target, k_ref, rtol=rtol, atol=atol)
numpy_allclose(new_q_copy, q_ref, rtol=rtol, atol=atol)
numpy_allclose(new_k_copy, k_ref, rtol=rtol, atol=atol)
assert k_target.shape == k_source.shape
numpy_allclose(k_target, k_source, rtol=rtol, atol=atol)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
if dtype == torch.float16:
# After testing cuda fp16 high_precision, it was found to have higher precision than torch fp16. Therefore, the threshold here has been relaxed to pass the test.
rtol = 1e-3
atol = 1e-1
inference_ops.rotary_embedding(new_q_fp16, new_k_fp16, cos, sin, False)
numpy_allclose(new_q_copy, new_q_fp16, rtol=rtol, atol=atol)
numpy_allclose(new_k_copy, new_k_fp16, rtol=rtol, atol=atol)
if __name__ == "__main__":
test_rotary_emb(16, 64, 4, 128, torch.float16)