mirror of https://github.com/hpcaitech/ColossalAI
[Inference]Fused kv copy into rotary calculation (#5383)
* revise rotary embedding * remove useless print * adapt * fix * add * fix * modeling * fix * fix * fix * fused kv copy * fused copy * colossalai/kernel/triton/no_pad_rotary_embedding.py * del padding llama * delpull/5399/head
parent
b21aac5bae
commit
730103819d
|
@ -16,7 +16,7 @@ from colossalai.inference.batch_bucket import BatchBucket
|
||||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||||
from colossalai.kernel.triton import (
|
from colossalai.kernel.triton import (
|
||||||
context_attention_unpadded,
|
context_attention_unpadded,
|
||||||
copy_kv_to_blocked_cache,
|
decoding_fused_rotary_embedding,
|
||||||
flash_decoding_attention,
|
flash_decoding_attention,
|
||||||
get_xine_cache,
|
get_xine_cache,
|
||||||
rotary_embedding,
|
rotary_embedding,
|
||||||
|
@ -281,11 +281,10 @@ class NopadLlamaAttention(LlamaAttention):
|
||||||
torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
|
torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
|
||||||
)
|
)
|
||||||
|
|
||||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
|
||||||
|
|
||||||
block_size = k_cache.size(-2)
|
block_size = k_cache.size(-2)
|
||||||
|
|
||||||
if is_prompts:
|
if is_prompts:
|
||||||
|
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||||
attn_output = context_attention_unpadded(
|
attn_output = context_attention_unpadded(
|
||||||
q=query_states,
|
q=query_states,
|
||||||
k=key_states,
|
k=key_states,
|
||||||
|
@ -300,8 +299,16 @@ class NopadLlamaAttention(LlamaAttention):
|
||||||
sm_scale=sm_scale,
|
sm_scale=sm_scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
copy_kv_to_blocked_cache(
|
decoding_fused_rotary_embedding(
|
||||||
key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
cos_sin[0],
|
||||||
|
cos_sin[1],
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
block_tables,
|
||||||
|
sequence_lengths,
|
||||||
)
|
)
|
||||||
attn_output = flash_decoding_attention(
|
attn_output = flash_decoding_attention(
|
||||||
q=query_states,
|
q=query_states,
|
||||||
|
|
|
@ -1,451 +0,0 @@
|
||||||
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transformers.models.llama.modeling_llama import (
|
|
||||||
LlamaAttention,
|
|
||||||
LlamaConfig,
|
|
||||||
LlamaDecoderLayer,
|
|
||||||
LlamaForCausalLM,
|
|
||||||
LlamaModel,
|
|
||||||
)
|
|
||||||
|
|
||||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
|
||||||
from colossalai.inference.modeling.layers.attention import PagedAttention
|
|
||||||
from colossalai.inference.struct import BatchInfo
|
|
||||||
from colossalai.kernel.triton import (
|
|
||||||
context_attention_unpadded,
|
|
||||||
copy_kv_to_blocked_cache,
|
|
||||||
flash_decoding_attention,
|
|
||||||
get_xine_cache,
|
|
||||||
rotary_embedding,
|
|
||||||
)
|
|
||||||
from colossalai.logging import get_dist_logger
|
|
||||||
|
|
||||||
from flash_attn.bert_padding import index_first_axis, pad_input # noqa
|
|
||||||
|
|
||||||
logger = get_dist_logger(__name__)
|
|
||||||
|
|
||||||
try:
|
|
||||||
HAS_TRITON = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_TRITON = False
|
|
||||||
logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.")
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
"""Rotates half the hidden dims of the input."""
|
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
|
||||||
x2 = x[..., x.shape[-1] // 2 :]
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
|
||||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
|
||||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
|
||||||
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
|
||||||
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
|
||||||
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
|
||||||
|
|
||||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
|
||||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
|
||||||
return q_embed, k_embed
|
|
||||||
|
|
||||||
|
|
||||||
def llama_causal_lm_forward(
|
|
||||||
self: LlamaForCausalLM,
|
|
||||||
batch: BatchInfo = None,
|
|
||||||
k_caches: List[torch.Tensor] = None,
|
|
||||||
v_caches: List[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
"""This function will replace the forward function of LlamaForCausalLM.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
|
|
||||||
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
|
|
||||||
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
hidden_states = llama_model_forward(
|
|
||||||
self.model,
|
|
||||||
batch=batch,
|
|
||||||
k_caches=k_caches,
|
|
||||||
v_caches=v_caches,
|
|
||||||
)
|
|
||||||
logits = self.lm_head(hidden_states)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
def llama_model_forward(
|
|
||||||
self: LlamaModel,
|
|
||||||
batch: BatchInfo = None,
|
|
||||||
k_caches: List[torch.Tensor] = None,
|
|
||||||
v_caches: List[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
"""This function will replace the forward function of LlamaModel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None.
|
|
||||||
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
|
|
||||||
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
|
|
||||||
"""
|
|
||||||
input_ids = batch.get_batch_inputs()
|
|
||||||
block_tables = batch.get_block_table_tensor()
|
|
||||||
attention_mask = batch.get_attn_mask()
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
if HAS_TRITON:
|
|
||||||
sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
|
|
||||||
else:
|
|
||||||
sequence_lengths = batch.get_sequence_lengths()
|
|
||||||
else:
|
|
||||||
sequence_lengths = batch.get_sequence_lengths()
|
|
||||||
|
|
||||||
batch_size, _ = input_ids.shape
|
|
||||||
kv_seq_len = sequence_lengths.max().item()
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
if batch.is_prompts:
|
|
||||||
# Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer.
|
|
||||||
position_ids = generate_padding_position_id(attention_mask)
|
|
||||||
else:
|
|
||||||
position_ids = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1)
|
|
||||||
else:
|
|
||||||
if batch.is_prompts:
|
|
||||||
position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=batch.device)
|
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
else:
|
|
||||||
position_ids = torch.arange(kv_seq_len - 1, kv_seq_len, dtype=torch.long, device=batch.device)
|
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
|
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
|
||||||
|
|
||||||
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)
|
|
||||||
|
|
||||||
if batch.is_prompts:
|
|
||||||
output_tensor = torch.zeros(
|
|
||||||
(sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
output_tensor = torch.zeros(
|
|
||||||
(batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
|
|
||||||
)
|
|
||||||
sm_scale = 1.0 / (batch.head_dim**0.5)
|
|
||||||
|
|
||||||
norm_output = torch.empty_like(hidden_states)
|
|
||||||
|
|
||||||
for layer_id, decoder_layer in enumerate(self.layers):
|
|
||||||
hidden_states = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
position_ids=position_ids,
|
|
||||||
block_tables=block_tables,
|
|
||||||
k_cache=k_caches[layer_id],
|
|
||||||
v_cache=v_caches[layer_id],
|
|
||||||
is_prompts=batch.is_prompts,
|
|
||||||
sequence_lengths=sequence_lengths,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
kv_seq_len=kv_seq_len,
|
|
||||||
cos_sin=cos_sin,
|
|
||||||
fd_inter_tensor=batch.fd_inter_tensor,
|
|
||||||
output_tensor=output_tensor,
|
|
||||||
norm_output=norm_output,
|
|
||||||
sm_scale=sm_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch.is_prompts:
|
|
||||||
hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
|
|
||||||
norm_output = torch.empty_like(hidden_states)
|
|
||||||
hidden_states = self.norm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
def llama_decoder_layer_forward(
|
|
||||||
self: LlamaDecoderLayer,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
position_ids: torch.LongTensor,
|
|
||||||
block_tables: torch.Tensor = None,
|
|
||||||
k_cache: torch.Tensor = None,
|
|
||||||
v_cache: torch.Tensor = None,
|
|
||||||
is_prompts: bool = True,
|
|
||||||
sequence_lengths: torch.Tensor = None,
|
|
||||||
attention_mask: torch.Tensor = None,
|
|
||||||
kv_seq_len: int = 0,
|
|
||||||
cos_sin: Tuple[torch.Tensor] = None,
|
|
||||||
fd_inter_tensor: FDIntermTensors = None,
|
|
||||||
output_tensor: torch.Tensor = None,
|
|
||||||
norm_output: torch.Tensor = None,
|
|
||||||
sm_scale: int = None,
|
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
|
||||||
"""This function will replace the forward function of LlamaDecoderLayer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
|
|
||||||
position_ids (torch.LongTensor), The position ids of input sequences.
|
|
||||||
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
|
|
||||||
storing mapping of token_position_id -> block_id. Defaults to None.
|
|
||||||
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
|
|
||||||
v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
|
|
||||||
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
|
|
||||||
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
|
|
||||||
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
|
|
||||||
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
|
|
||||||
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None.
|
|
||||||
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
|
|
||||||
norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
|
|
||||||
sm_scale (int, optional): Used for flash attention. Defaults to None.
|
|
||||||
"""
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
|
|
||||||
# Self Attention
|
|
||||||
hidden_states = self.self_attn(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
position_ids=position_ids,
|
|
||||||
block_tables=block_tables,
|
|
||||||
k_cache=k_cache,
|
|
||||||
v_cache=v_cache,
|
|
||||||
is_prompts=is_prompts,
|
|
||||||
sequence_lengths=sequence_lengths,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
kv_seq_len=kv_seq_len,
|
|
||||||
cos_sin=cos_sin,
|
|
||||||
fd_inter_tensor=fd_inter_tensor,
|
|
||||||
output_tensor=output_tensor,
|
|
||||||
sm_scale=sm_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
# Fully Connected
|
|
||||||
residual = hidden_states
|
|
||||||
hidden_states = self.post_attention_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class PadLlamaAttention(LlamaAttention):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: LlamaConfig,
|
|
||||||
layer_idx: Optional[int] = None,
|
|
||||||
attn_qproj_w: torch.nn.Parameter = None,
|
|
||||||
attn_kproj_w: torch.nn.Parameter = None,
|
|
||||||
attn_vproj_w: torch.nn.Parameter = None,
|
|
||||||
attn_oproj_w: torch.nn.Parameter = None,
|
|
||||||
):
|
|
||||||
"""This layer will replace the LlamaAttention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (LlamaConfig): Holding the Llama model config.
|
|
||||||
layer_idx (Optional[int], optional): The decode layer id of this attention layer. Defaults to None.
|
|
||||||
attn_qproj_w (torch.nn.Parameter, optional): The q_proj weight. Defaults to None.
|
|
||||||
attn_kproj_w (torch.nn.Parameter, optional): The k_proj weight. Defaults to None.
|
|
||||||
attn_vproj_w (torch.nn.Parameter, optional): The v_proj weight. Defaults to None.
|
|
||||||
attn_oproj_w (torch.nn.Parameter, optional): The o_proj weight. Defaults to None.
|
|
||||||
"""
|
|
||||||
super().__init__(config, layer_idx)
|
|
||||||
self.q_proj.weight = attn_qproj_w
|
|
||||||
self.k_proj.weight = attn_kproj_w
|
|
||||||
self.v_proj.weight = attn_vproj_w
|
|
||||||
self.o_proj.weight = attn_oproj_w
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
|
|
||||||
"""Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module (LlamaAttention): The origin LlamaAttention layer.
|
|
||||||
"""
|
|
||||||
config = module.config
|
|
||||||
layer_idx = module.layer_idx
|
|
||||||
|
|
||||||
attn_qproj_w = module.q_proj.weight
|
|
||||||
attn_kproj_w = module.k_proj.weight
|
|
||||||
attn_vproj_w = module.v_proj.weight
|
|
||||||
attn_oproj_w = module.o_proj.weight
|
|
||||||
|
|
||||||
attn_layer = PadLlamaAttention(
|
|
||||||
config=config,
|
|
||||||
layer_idx=layer_idx,
|
|
||||||
attn_qproj_w=attn_qproj_w,
|
|
||||||
attn_kproj_w=attn_kproj_w,
|
|
||||||
attn_vproj_w=attn_vproj_w,
|
|
||||||
attn_oproj_w=attn_oproj_w,
|
|
||||||
)
|
|
||||||
|
|
||||||
return attn_layer
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
position_ids: torch.LongTensor,
|
|
||||||
block_tables: torch.Tensor = None,
|
|
||||||
k_cache: torch.Tensor = None,
|
|
||||||
v_cache: torch.Tensor = None,
|
|
||||||
is_prompts: bool = True,
|
|
||||||
sequence_lengths: torch.Tensor = None,
|
|
||||||
attention_mask: torch.Tensor = None,
|
|
||||||
kv_seq_len: int = 0,
|
|
||||||
cos_sin: Tuple[torch.Tensor] = None,
|
|
||||||
fd_inter_tensor: FDIntermTensors = None,
|
|
||||||
output_tensor: torch.Tensor = None,
|
|
||||||
sm_scale: int = None,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]
|
|
||||||
position_ids (torch.LongTensor), The position ids of input sequences.
|
|
||||||
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
|
|
||||||
storing mapping of token_position_id -> block_id. Defaults to None.
|
|
||||||
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
|
|
||||||
v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
|
|
||||||
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
|
|
||||||
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
|
|
||||||
attention_mask (torch.Tensor, optional): The padding mask - corresponds to a tensor of size [batch_size, seq_len]
|
|
||||||
where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens.
|
|
||||||
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
|
|
||||||
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
|
|
||||||
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
|
|
||||||
storing intermediate values in flash-decoding. Defaults to None.
|
|
||||||
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
|
|
||||||
sm_scale (int, optional): Used for flash attention. Defaults to None.
|
|
||||||
"""
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
|
||||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
|
||||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
|
||||||
|
|
||||||
if HAS_TRITON:
|
|
||||||
if is_prompts:
|
|
||||||
if attention_mask is not None:
|
|
||||||
query_states, key_states, value_states, indices = unpading_input(
|
|
||||||
query_states, key_states, value_states, attention_mask
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
query_states = query_states.view(-1, self.num_heads, self.head_dim)
|
|
||||||
key_states = key_states.view(-1, self.num_heads, self.head_dim)
|
|
||||||
value_states = value_states.view(-1, self.num_heads, self.head_dim)
|
|
||||||
else:
|
|
||||||
query_states = query_states.squeeze(dim=1)
|
|
||||||
key_states = key_states.squeeze(dim=1)
|
|
||||||
value_states = value_states.squeeze(dim=1)
|
|
||||||
|
|
||||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
|
||||||
|
|
||||||
block_size = k_cache.size(-2)
|
|
||||||
|
|
||||||
if is_prompts:
|
|
||||||
attn_output = context_attention_unpadded(
|
|
||||||
q=query_states,
|
|
||||||
k=key_states,
|
|
||||||
v=value_states,
|
|
||||||
k_cache=k_cache,
|
|
||||||
v_cache=v_cache,
|
|
||||||
context_lengths=sequence_lengths,
|
|
||||||
block_tables=block_tables,
|
|
||||||
block_size=block_size,
|
|
||||||
output=output_tensor,
|
|
||||||
max_seq_len=kv_seq_len,
|
|
||||||
sm_scale=sm_scale,
|
|
||||||
)
|
|
||||||
if attention_mask is not None:
|
|
||||||
attn_output = pad_input(attn_output, indices, bsz, q_len)
|
|
||||||
else:
|
|
||||||
copy_kv_to_blocked_cache(
|
|
||||||
key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
|
|
||||||
)
|
|
||||||
attn_output = flash_decoding_attention(
|
|
||||||
q=query_states,
|
|
||||||
k_cache=k_cache,
|
|
||||||
v_cache=v_cache,
|
|
||||||
kv_seq_len=sequence_lengths,
|
|
||||||
block_tables=block_tables,
|
|
||||||
block_size=block_size,
|
|
||||||
max_seq_len_in_batch=kv_seq_len,
|
|
||||||
output=output_tensor,
|
|
||||||
mid_output=fd_inter_tensor.mid_output,
|
|
||||||
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
|
||||||
sm_scale=sm_scale,
|
|
||||||
)
|
|
||||||
attn_output = attn_output.squeeze(1)
|
|
||||||
else:
|
|
||||||
query_states = query_states.transpose(1, 2)
|
|
||||||
key_states = key_states.transpose(1, 2)
|
|
||||||
value_states = value_states.transpose(1, 2)
|
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
||||||
|
|
||||||
query_states = query_states.transpose(1, 2)
|
|
||||||
key_states = key_states.transpose(1, 2)
|
|
||||||
value_states = value_states.transpose(1, 2)
|
|
||||||
|
|
||||||
if is_prompts:
|
|
||||||
attn_output = PagedAttention.pad_context_forward(
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
k_cache,
|
|
||||||
v_cache,
|
|
||||||
sequence_lengths,
|
|
||||||
block_tables,
|
|
||||||
attention_mask,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
attn_output = PagedAttention.pad_decoding_forward(
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
k_cache,
|
|
||||||
v_cache,
|
|
||||||
sequence_lengths,
|
|
||||||
block_tables,
|
|
||||||
attention_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
|
|
||||||
def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Generate padding position_id through attention mask.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
attention_mask (`torch.Tensor` of shape [batch_size, sequence_length]:
|
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The padding position_id.
|
|
||||||
"""
|
|
||||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
|
||||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
|
||||||
return position_ids
|
|
||||||
|
|
||||||
|
|
||||||
def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor):
|
|
||||||
"""Convert padding input to nopad input.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
q (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim]
|
|
||||||
k (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim]
|
|
||||||
v (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim]
|
|
||||||
attention_mask (torch.Tensor): [batch_size, sequence_length]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[torch.Tensor]: The unpad q, k, v and The index of valid data in each batch.
|
|
||||||
|
|
||||||
"""
|
|
||||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
|
||||||
batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape
|
|
||||||
q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
|
|
||||||
k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
|
|
||||||
v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
|
|
||||||
return (q, k, v, indices)
|
|
|
@ -13,7 +13,7 @@ if HAS_TRITON:
|
||||||
from .fused_rotary_embedding import fused_rotary_embedding
|
from .fused_rotary_embedding import fused_rotary_embedding
|
||||||
from .gptq_triton import gptq_fused_linear_triton
|
from .gptq_triton import gptq_fused_linear_triton
|
||||||
from .kvcache_copy import copy_kv_to_blocked_cache
|
from .kvcache_copy import copy_kv_to_blocked_cache
|
||||||
from .no_pad_rotary_embedding import rotary_embedding
|
from .no_pad_rotary_embedding import decoding_fused_rotary_embedding, rotary_embedding
|
||||||
from .rms_layernorm import rms_layernorm
|
from .rms_layernorm import rms_layernorm
|
||||||
from .rotary_cache_copy import get_xine_cache
|
from .rotary_cache_copy import get_xine_cache
|
||||||
from .softmax import softmax
|
from .softmax import softmax
|
||||||
|
@ -28,4 +28,5 @@ if HAS_TRITON:
|
||||||
"rotary_embedding",
|
"rotary_embedding",
|
||||||
"fused_rotary_embedding",
|
"fused_rotary_embedding",
|
||||||
"get_xine_cache",
|
"get_xine_cache",
|
||||||
|
"decoding_fused_rotary_embedding",
|
||||||
]
|
]
|
||||||
|
|
|
@ -45,21 +45,21 @@ def _copy_to_kvcache_seqlen1_kernel(
|
||||||
k = tl.load(K + offsets_kv)
|
k = tl.load(K + offsets_kv)
|
||||||
v = tl.load(V + offsets_kv)
|
v = tl.load(V + offsets_kv)
|
||||||
|
|
||||||
offsets_kvcache = (
|
offsets_kcache = (
|
||||||
block_id * stride_cachekb
|
block_id * stride_cachekb
|
||||||
+ cur_kv_head_idx * stride_cachekh
|
+ cur_kv_head_idx * stride_cachekh
|
||||||
+ offsets_in_last_block * stride_cachekbs
|
+ offsets_in_last_block * stride_cachekbs
|
||||||
+ offsets_dmodel * stride_cachekd
|
+ offsets_dmodel * stride_cachekd
|
||||||
)
|
)
|
||||||
offsets_kvcache = (
|
offsets_vcache = (
|
||||||
block_id * stride_cachevb
|
block_id * stride_cachevb
|
||||||
+ cur_kv_head_idx * stride_cachevh
|
+ cur_kv_head_idx * stride_cachevh
|
||||||
+ offsets_in_last_block * stride_cachevbs
|
+ offsets_in_last_block * stride_cachevbs
|
||||||
+ offsets_dmodel * stride_cachevd
|
+ offsets_dmodel * stride_cachevd
|
||||||
)
|
)
|
||||||
|
|
||||||
tl.store(KCache + offsets_kvcache, k)
|
tl.store(KCache + offsets_kcache, k)
|
||||||
tl.store(VCache + offsets_kvcache, v)
|
tl.store(VCache + offsets_vcache, v)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -222,11 +222,11 @@ def fused_rotary_embedding_kernel(
|
||||||
out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
|
out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
|
||||||
out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :] # total_tokens, head_num, head_dim
|
out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :] # total_tokens, head_num, head_dim
|
||||||
|
|
||||||
past_kv_seq_len = tl.load(context_lengths + tokens_range) - 1
|
past_kv_seq_len = tl.load(context_lengths + tokens_range, mask=(tokens_range < q_total_tokens)) - 1
|
||||||
|
|
||||||
last_block_idx = past_kv_seq_len // block_size
|
last_block_idx = past_kv_seq_len // block_size
|
||||||
block_table_ptr = BLOCK_TABLES + tokens_range * bts_stride
|
block_table_ptr = BLOCK_TABLES + tokens_range * bts_stride
|
||||||
block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride)
|
block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(tokens_range < q_total_tokens))
|
||||||
offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride
|
offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride
|
||||||
|
|
||||||
kv_range0 = (
|
kv_range0 = (
|
||||||
|
@ -274,6 +274,241 @@ def fused_rotary_embedding_kernel(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def fused_rotary_embedding_kernel_v2(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
kv_cache,
|
||||||
|
BLOCK_TABLES,
|
||||||
|
context_lengths,
|
||||||
|
q_token_stride,
|
||||||
|
q_head_stride,
|
||||||
|
k_token_stride,
|
||||||
|
k_head_stride,
|
||||||
|
head_dim_stride,
|
||||||
|
cos_token_stride,
|
||||||
|
cos_stride,
|
||||||
|
cacheb_stride,
|
||||||
|
cacheh_stride,
|
||||||
|
cachebs_stride,
|
||||||
|
cached_stride,
|
||||||
|
bts_stride,
|
||||||
|
btb_stride,
|
||||||
|
block_size,
|
||||||
|
q_total_tokens,
|
||||||
|
Q_HEAD_NUM: tl.constexpr,
|
||||||
|
HEAD_DIM: tl.constexpr,
|
||||||
|
):
|
||||||
|
block_head_index = tl.program_id(0)
|
||||||
|
if block_head_index >= Q_HEAD_NUM:
|
||||||
|
return
|
||||||
|
block_token_index = tl.program_id(1)
|
||||||
|
|
||||||
|
dim_range0 = tl.arange(0, HEAD_DIM // 2)
|
||||||
|
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
|
||||||
|
|
||||||
|
off_q0 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range0 * head_dim_stride
|
||||||
|
off_q1 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range1 * head_dim_stride
|
||||||
|
off_k0 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range0 * head_dim_stride
|
||||||
|
off_k1 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range1 * head_dim_stride
|
||||||
|
|
||||||
|
loaded_q0 = tl.load(
|
||||||
|
q + off_q0,
|
||||||
|
)
|
||||||
|
loaded_q1 = tl.load(
|
||||||
|
q + off_q1,
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded_k0 = tl.load(
|
||||||
|
k + off_k0,
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded_k1 = tl.load(
|
||||||
|
k + off_k1,
|
||||||
|
)
|
||||||
|
|
||||||
|
off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride
|
||||||
|
|
||||||
|
loaded_cos = tl.load(cos + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)
|
||||||
|
loaded_sin = tl.load(sin + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)
|
||||||
|
|
||||||
|
out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin
|
||||||
|
out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos
|
||||||
|
|
||||||
|
out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
|
||||||
|
out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos # total_tokens, head_num, head_dim
|
||||||
|
|
||||||
|
past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1
|
||||||
|
|
||||||
|
last_block_idx = past_kv_seq_len // block_size
|
||||||
|
block_table_ptr = BLOCK_TABLES + block_token_index * bts_stride
|
||||||
|
block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(block_token_index < q_total_tokens))
|
||||||
|
offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride
|
||||||
|
|
||||||
|
kv_range0 = (
|
||||||
|
block_ids * cacheb_stride
|
||||||
|
+ block_head_index * cacheh_stride
|
||||||
|
+ offsets_in_last_block
|
||||||
|
+ dim_range0 * cached_stride
|
||||||
|
)
|
||||||
|
kv_range1 = (
|
||||||
|
block_ids * cacheb_stride
|
||||||
|
+ block_head_index * cacheh_stride
|
||||||
|
+ offsets_in_last_block
|
||||||
|
+ dim_range1 * cached_stride
|
||||||
|
)
|
||||||
|
|
||||||
|
tl.store(
|
||||||
|
kv_cache + kv_range0,
|
||||||
|
out_k0,
|
||||||
|
)
|
||||||
|
tl.store(
|
||||||
|
kv_cache + kv_range1,
|
||||||
|
out_k1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# concat
|
||||||
|
tl.store(
|
||||||
|
q + off_q0,
|
||||||
|
out_q0,
|
||||||
|
)
|
||||||
|
tl.store(
|
||||||
|
q + off_q1,
|
||||||
|
out_q1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def decoding_fused_rotary_embedding_kernel(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
BLOCK_TABLES,
|
||||||
|
context_lengths,
|
||||||
|
q_token_stride,
|
||||||
|
q_head_stride,
|
||||||
|
k_token_stride,
|
||||||
|
k_head_stride,
|
||||||
|
head_dim_stride,
|
||||||
|
cos_token_stride,
|
||||||
|
cos_stride,
|
||||||
|
cache_b_stride,
|
||||||
|
cache_h_stride,
|
||||||
|
cache_bs_stride,
|
||||||
|
cache_d_stride,
|
||||||
|
bts_stride,
|
||||||
|
btb_stride,
|
||||||
|
block_size,
|
||||||
|
Q_HEAD_NUM: tl.constexpr,
|
||||||
|
HEAD_DIM: tl.constexpr,
|
||||||
|
):
|
||||||
|
block_head_index = tl.program_id(0)
|
||||||
|
if block_head_index >= Q_HEAD_NUM:
|
||||||
|
return
|
||||||
|
|
||||||
|
block_token_index = tl.program_id(1)
|
||||||
|
|
||||||
|
dim_range0 = tl.arange(0, HEAD_DIM // 2)
|
||||||
|
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
|
||||||
|
total_dim_range = tl.arange(0, HEAD_DIM)
|
||||||
|
|
||||||
|
q_off_base = block_token_index * q_token_stride + block_head_index * q_head_stride
|
||||||
|
off_q0 = q_off_base + dim_range0 * head_dim_stride
|
||||||
|
off_q1 = q_off_base + dim_range1 * head_dim_stride
|
||||||
|
|
||||||
|
off_base = block_token_index * k_token_stride + block_head_index * k_head_stride
|
||||||
|
off_k0 = off_base + dim_range0 * head_dim_stride
|
||||||
|
off_k1 = off_base + dim_range1 * head_dim_stride
|
||||||
|
|
||||||
|
off_v = off_base + total_dim_range * head_dim_stride
|
||||||
|
|
||||||
|
loaded_q0 = tl.load(
|
||||||
|
q + off_q0,
|
||||||
|
)
|
||||||
|
loaded_q1 = tl.load(
|
||||||
|
q + off_q1,
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded_k0 = tl.load(
|
||||||
|
k + off_k0,
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded_k1 = tl.load(
|
||||||
|
k + off_k1,
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded_v = tl.load(
|
||||||
|
v + off_v,
|
||||||
|
)
|
||||||
|
|
||||||
|
off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride
|
||||||
|
|
||||||
|
loaded_cos = tl.load(cos + off_cos_sin)
|
||||||
|
loaded_sin = tl.load(sin + off_cos_sin)
|
||||||
|
|
||||||
|
out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin
|
||||||
|
out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos
|
||||||
|
|
||||||
|
out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
|
||||||
|
out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos # total_tokens, head_num, head_dim
|
||||||
|
|
||||||
|
past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1
|
||||||
|
|
||||||
|
last_block_idx = past_kv_seq_len // block_size
|
||||||
|
block_ids = tl.load(BLOCK_TABLES + block_token_index * bts_stride + last_block_idx * btb_stride)
|
||||||
|
offsets_in_last_block = past_kv_seq_len % block_size
|
||||||
|
|
||||||
|
k_range0 = (
|
||||||
|
block_ids * cache_b_stride
|
||||||
|
+ block_head_index * cache_h_stride
|
||||||
|
+ offsets_in_last_block * cache_bs_stride
|
||||||
|
+ dim_range0 * cache_d_stride
|
||||||
|
)
|
||||||
|
k_range1 = (
|
||||||
|
block_ids * cache_b_stride
|
||||||
|
+ block_head_index * cache_h_stride
|
||||||
|
+ offsets_in_last_block * cache_bs_stride
|
||||||
|
+ dim_range1 * cache_d_stride
|
||||||
|
)
|
||||||
|
v_range = (
|
||||||
|
block_ids * cache_b_stride
|
||||||
|
+ block_head_index * cache_h_stride
|
||||||
|
+ offsets_in_last_block * cache_bs_stride
|
||||||
|
+ total_dim_range * cache_d_stride
|
||||||
|
)
|
||||||
|
|
||||||
|
tl.store(
|
||||||
|
v_cache + v_range,
|
||||||
|
loaded_v,
|
||||||
|
)
|
||||||
|
|
||||||
|
tl.store(
|
||||||
|
k_cache + k_range0,
|
||||||
|
out_k0,
|
||||||
|
)
|
||||||
|
|
||||||
|
tl.store(
|
||||||
|
k_cache + k_range1,
|
||||||
|
out_k1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# concat
|
||||||
|
tl.store(
|
||||||
|
q + off_q0,
|
||||||
|
out_q0,
|
||||||
|
)
|
||||||
|
tl.store(
|
||||||
|
q + off_q1,
|
||||||
|
out_q1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def rotary_embedding(
|
def rotary_embedding(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
|
@ -297,12 +532,13 @@ def rotary_embedding(
|
||||||
assert q.size(0) == k.size(0)
|
assert q.size(0) == k.size(0)
|
||||||
BLOCK_HEAD = 4
|
BLOCK_HEAD = 4
|
||||||
BLOCK_TOKENS = 4
|
BLOCK_TOKENS = 4
|
||||||
grid = lambda META: (triton.cdiv(q_head_num, META["BLOCK_HEAD"]), triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]))
|
|
||||||
|
|
||||||
if head_dim >= 256:
|
if head_dim >= 1024:
|
||||||
num_warps = 32
|
num_warps = 32
|
||||||
elif head_dim >= 128:
|
elif head_dim >= 512:
|
||||||
num_warps = 16
|
num_warps = 16
|
||||||
|
elif head_dim >= 256:
|
||||||
|
num_warps = 8
|
||||||
else:
|
else:
|
||||||
num_warps = 4
|
num_warps = 4
|
||||||
|
|
||||||
|
@ -318,6 +554,10 @@ def rotary_embedding(
|
||||||
cos_token_stride = cos.stride(0)
|
cos_token_stride = cos.stride(0)
|
||||||
cos_stride = cos.stride(1)
|
cos_stride = cos.stride(1)
|
||||||
if k_cache == None:
|
if k_cache == None:
|
||||||
|
grid = lambda META: (
|
||||||
|
triton.cdiv(q_head_num, META["BLOCK_HEAD"]),
|
||||||
|
triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]),
|
||||||
|
)
|
||||||
rotary_embedding_kernel[grid](
|
rotary_embedding_kernel[grid](
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
|
@ -339,7 +579,8 @@ def rotary_embedding(
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
fused_rotary_embedding_kernel[grid](
|
grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
|
||||||
|
fused_rotary_embedding_kernel_v2[grid](
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
cos,
|
cos,
|
||||||
|
@ -363,10 +604,85 @@ def rotary_embedding(
|
||||||
k_cache.size(-2),
|
k_cache.size(-2),
|
||||||
q_total_tokens,
|
q_total_tokens,
|
||||||
Q_HEAD_NUM=q_head_num,
|
Q_HEAD_NUM=q_head_num,
|
||||||
K_HEAD_NUM=k_head_num,
|
|
||||||
HEAD_DIM=head_dim,
|
HEAD_DIM=head_dim,
|
||||||
BLOCK_HEAD=BLOCK_HEAD,
|
|
||||||
BLOCK_TOKENS=BLOCK_TOKENS,
|
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def decoding_fused_rotary_embedding(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
k_cache: Optional[torch.Tensor] = None,
|
||||||
|
v_cache: Optional[torch.Tensor] = None,
|
||||||
|
block_tables: Optional[torch.Tensor] = None,
|
||||||
|
kv_lengths: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
q: query tensor, [total_tokens, head_num, head_dim]
|
||||||
|
k: key tensor, [total_tokens, head_num, head_dim]
|
||||||
|
v: value tensor, [total tokens, head_num, head_dim]
|
||||||
|
cos: cosine for rotary embedding, [max_position_len, head_dim]
|
||||||
|
sin: sine for rotary embedding, [max_position_len, head_dim]
|
||||||
|
k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim]
|
||||||
|
v_cache (torch.Tensor): Blocked value cache. [num_blocks, num_kv_heads, block_size, head_dim]
|
||||||
|
kv_lengths, Past key/value sequence lengths plus current sequence length for each sequence. [bsz]
|
||||||
|
block_tables: Block tables for each sequence. [bsz, max_blocks_per_sequence]
|
||||||
|
"""
|
||||||
|
q_total_tokens, q_head_num, head_dim = q.shape
|
||||||
|
assert q.size(0) == k.size(0) == v.size(0)
|
||||||
|
assert q.size(1) == k.size(1) == v.size(1)
|
||||||
|
assert k_cache.size(-1) == v_cache.size(-1)
|
||||||
|
|
||||||
|
if head_dim >= 1024:
|
||||||
|
num_warps = 32
|
||||||
|
elif head_dim >= 512:
|
||||||
|
num_warps = 16
|
||||||
|
elif head_dim >= 256:
|
||||||
|
num_warps = 8
|
||||||
|
else:
|
||||||
|
num_warps = 4
|
||||||
|
|
||||||
|
q_token_stride = q.stride(0)
|
||||||
|
q_head_stride = q.stride(1)
|
||||||
|
head_dim_stride = q.stride(2)
|
||||||
|
|
||||||
|
k_token_stride = k.stride(0)
|
||||||
|
k_head_stride = k.stride(1)
|
||||||
|
|
||||||
|
cos_token_stride = cos.stride(0)
|
||||||
|
cos_stride = cos.stride(1)
|
||||||
|
grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
|
||||||
|
decoding_fused_rotary_embedding_kernel[grid](
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
block_tables,
|
||||||
|
kv_lengths,
|
||||||
|
q_token_stride,
|
||||||
|
q_head_stride,
|
||||||
|
k_token_stride,
|
||||||
|
k_head_stride,
|
||||||
|
head_dim_stride,
|
||||||
|
cos_token_stride,
|
||||||
|
cos_stride,
|
||||||
|
k_cache.stride(0),
|
||||||
|
k_cache.stride(1),
|
||||||
|
k_cache.stride(2),
|
||||||
|
k_cache.stride(3),
|
||||||
|
block_tables.stride(0),
|
||||||
|
block_tables.stride(1),
|
||||||
|
k_cache.size(-2),
|
||||||
|
Q_HEAD_NUM=q_head_num,
|
||||||
|
HEAD_DIM=head_dim,
|
||||||
|
num_warps=num_warps,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
|
@ -204,7 +204,7 @@ def benchmark_inference(args):
|
||||||
torch.cuda.cudart().cudaProfilerStop()
|
torch.cuda.cudart().cudaProfilerStop()
|
||||||
if args.profile:
|
if args.profile:
|
||||||
ctx.step()
|
ctx.step()
|
||||||
|
print(f"config:batch_size {args.batch_size}, input_len{ args.seq_len}, output_len {args.output_len}")
|
||||||
print_details_info(model.config, args, whole_end2end, total_token_num)
|
print_details_info(model.config, args, whole_end2end, total_token_num)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
ROOT=$(realpath $(dirname $0))
|
ROOT=$(realpath $(dirname $0))
|
||||||
|
echo $ROOT
|
||||||
PY_SCRIPT=${ROOT}/benchmark_llama.py
|
PY_SCRIPT=${ROOT}/benchmark_llama.py
|
||||||
GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
|
GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
|
||||||
mode=$1
|
mode="colossalai"
|
||||||
|
|
||||||
mkdir -p logs
|
mkdir -p logs
|
||||||
|
|
||||||
|
@ -23,10 +24,10 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
|
||||||
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
|
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
|
||||||
|
|
||||||
# benchmark llama2-7b one single GPU
|
# benchmark llama2-7b one single GPU
|
||||||
for input_len in 128 512 1024; do
|
for input_len in 128 512 1024; do
|
||||||
for output_len in 128 256; do
|
for output_len in 128 256; do
|
||||||
for bsz in 16 32 64; do
|
for bsz in 16 32 64; do
|
||||||
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} --test_random_weight | tee logs/${input_len}_${output_len}_${mode}_${GPU}_${bsz}.txt
|
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} --model_path "/home/caidi/llama_model/" | tee logs/${input_len}_${output_len}_${mode}_${GPU}_${bsz}.txt
|
||||||
done
|
done
|
||||||
done
|
done
|
||||||
done
|
done
|
||||||
|
|
|
@ -3,8 +3,8 @@ import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
||||||
|
|
||||||
from colossalai.kernel.triton import rotary_embedding
|
from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
|
||||||
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
|
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import triton # noqa
|
import triton # noqa
|
||||||
|
@ -67,25 +67,14 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
|
||||||
)
|
)
|
||||||
new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
|
new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
|
||||||
new_q = torch.randn_like(new_k)
|
new_q = torch.randn_like(new_k)
|
||||||
|
new_v = torch.randn_like(new_k)
|
||||||
|
|
||||||
kv_seq_lengths = past_kv_seq_lengths + 1
|
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||||
block_tables = block_tables.to(device="cuda")
|
block_tables = block_tables.to(device="cuda")
|
||||||
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
|
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
|
||||||
k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
|
|
||||||
|
|
||||||
rotary_embedding(new_q, new_k, cos, sin, k_cache, block_tables, kv_seq_lengths)
|
decoding_fused_rotary_embedding(new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths)
|
||||||
assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4)
|
assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4)
|
||||||
assert torch.allclose(new_k, k_ref, atol=1e-4, rtol=1e-4)
|
|
||||||
|
|
||||||
# check one by one
|
|
||||||
for seq_i in range(BATCH_SIZE):
|
|
||||||
ki = new_k[seq_i]
|
|
||||||
ki = ki.squeeze()
|
|
||||||
past_kv_seq_len = kv_seq_lengths[seq_i] - 1
|
|
||||||
target_block_id = block_tables[seq_i, past_kv_seq_len // block_size]
|
|
||||||
offsets_in_block = past_kv_seq_len % block_size
|
|
||||||
target = k_cache[target_block_id, :, offsets_in_block, :]
|
|
||||||
orig = new_k[seq_i].squeeze(dim=0)
|
|
||||||
assert torch.equal(orig, target)
|
|
||||||
|
|
||||||
|
|
||||||
BATCH = 16
|
BATCH = 16
|
||||||
|
@ -94,8 +83,8 @@ configs = [
|
||||||
x_names=["num_tokens"],
|
x_names=["num_tokens"],
|
||||||
x_vals=[2**i for i in range(4, 11)],
|
x_vals=[2**i for i in range(4, 11)],
|
||||||
line_arg="provider",
|
line_arg="provider",
|
||||||
line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
|
line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
|
||||||
line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
|
line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
|
||||||
styles=[("red", "-"), ("blue", "-")],
|
styles=[("red", "-"), ("blue", "-")],
|
||||||
ylabel="ms",
|
ylabel="ms",
|
||||||
plot_name=f"rotary_emb-batch-{BATCH}",
|
plot_name=f"rotary_emb-batch-{BATCH}",
|
||||||
|
@ -110,23 +99,53 @@ def benchmark_rotary_emb(
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
):
|
):
|
||||||
|
BATCH_SIZE = 4
|
||||||
|
SEQ_LEN = num_tokens // BATCH_SIZE
|
||||||
|
max_num_blocks_per_seq = 8
|
||||||
|
block_size = 64
|
||||||
warmup = 10
|
warmup = 10
|
||||||
rep = 100
|
rep = 100
|
||||||
|
|
||||||
head_dim = 128
|
head_dim = 4096
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
|
|
||||||
q_shape = (num_tokens, num_kv_heads, head_dim)
|
q_shape = (num_tokens, num_kv_heads, head_dim)
|
||||||
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
|
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
|
||||||
k_shape = (num_tokens, num_kv_heads, head_dim)
|
k_shape = (num_tokens, num_kv_heads, head_dim)
|
||||||
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
|
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
|
||||||
|
v = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
cos_shape = (num_tokens, head_dim // 2)
|
cos_shape = (num_tokens, head_dim // 2)
|
||||||
|
|
||||||
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
|
cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
|
||||||
|
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
||||||
|
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
if provider == "torch_rotary_emb_func":
|
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
|
||||||
fn = lambda: torch_rotary_emb(q, cos, sin)
|
block_tables = mock_alloc_block_table_and_kvcache_v2(
|
||||||
elif provider == "triton_rotary_emb_func":
|
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
|
||||||
fn = lambda: rotary_embedding(q, k, cos, sin)
|
)
|
||||||
|
new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
|
||||||
|
new_q = torch.randn_like(new_k)
|
||||||
|
new_v = torch.randn_like(new_k)
|
||||||
|
|
||||||
|
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
||||||
|
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||||
|
block_tables = block_tables.to(device="cuda")
|
||||||
|
|
||||||
|
if provider == "no_fused_rotary_emb_func":
|
||||||
|
fn = lambda: [
|
||||||
|
rotary_embedding(new_q, new_k, cos, sin),
|
||||||
|
copy_kv_to_blocked_cache(
|
||||||
|
new_k, new_v, k_cache, v_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables
|
||||||
|
),
|
||||||
|
]
|
||||||
|
elif provider == "fused_triton_rotary_emb_func":
|
||||||
|
fn = lambda: decoding_fused_rotary_embedding(
|
||||||
|
new_q, new_k, new_k, cos, sin, k_cache, k_cache, block_tables, kv_seq_lengths
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Undefined provider")
|
raise ValueError("Undefined provider")
|
||||||
|
|
||||||
|
@ -136,4 +155,4 @@ def benchmark_rotary_emb(
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_rotary_emb(4, 64, 32, 64, torch.float32)
|
test_rotary_emb(4, 64, 32, 64, torch.float32)
|
||||||
# benchmark_rotary_emb.run(save_path=".",print_data=True)
|
# benchmark_rotary_emb.run(save_path=".", print_data=True)
|
||||||
|
|
Loading…
Reference in New Issue