From 730103819dc0636c85af1af80cc17914dcf196c1 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Wed, 21 Feb 2024 11:31:48 +0800 Subject: [PATCH] [Inference]Fused kv copy into rotary calculation (#5383) * revise rotary embedding * remove useless print * adapt * fix * add * fix * modeling * fix * fix * fix * fused kv copy * fused copy * colossalai/kernel/triton/no_pad_rotary_embedding.py * del padding llama * del --- .../modeling/models/nopadding_llama.py | 17 +- .../modeling/models/padding_llama.py | 451 ------------------ colossalai/kernel/triton/__init__.py | 3 +- colossalai/kernel/triton/kvcache_copy.py | 8 +- .../kernel/triton/no_pad_rotary_embedding.py | 334 ++++++++++++- examples/inference/benchmark_llama.py | 2 +- examples/inference/run_benchmark.sh | 7 +- .../triton/test_rotary_embdding_unpad.py | 67 ++- 8 files changed, 391 insertions(+), 498 deletions(-) delete mode 100644 colossalai/inference/modeling/models/padding_llama.py diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 6b6a5876b..4dfe6dbd7 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -16,7 +16,7 @@ from colossalai.inference.batch_bucket import BatchBucket from colossalai.inference.flash_decoding_utils import FDIntermTensors from colossalai.kernel.triton import ( context_attention_unpadded, - copy_kv_to_blocked_cache, + decoding_fused_rotary_embedding, flash_decoding_attention, get_xine_cache, rotary_embedding, @@ -281,11 +281,10 @@ class NopadLlamaAttention(LlamaAttention): torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0) ) - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - block_size = k_cache.size(-2) if is_prompts: + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) attn_output = context_attention_unpadded( q=query_states, k=key_states, @@ -300,8 +299,16 @@ class NopadLlamaAttention(LlamaAttention): sm_scale=sm_scale, ) else: - copy_kv_to_blocked_cache( - key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables + decoding_fused_rotary_embedding( + query_states, + key_states, + value_states, + cos_sin[0], + cos_sin[1], + k_cache, + v_cache, + block_tables, + sequence_lengths, ) attn_output = flash_decoding_attention( q=query_states, diff --git a/colossalai/inference/modeling/models/padding_llama.py b/colossalai/inference/modeling/models/padding_llama.py deleted file mode 100644 index 63050cd6d..000000000 --- a/colossalai/inference/modeling/models/padding_llama.py +++ /dev/null @@ -1,451 +0,0 @@ -# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py -from typing import List, Optional, Tuple - -import torch -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaConfig, - LlamaDecoderLayer, - LlamaForCausalLM, - LlamaModel, -) - -from colossalai.inference.flash_decoding_utils import FDIntermTensors -from colossalai.inference.modeling.layers.attention import PagedAttention -from colossalai.inference.struct import BatchInfo -from colossalai.kernel.triton import ( - context_attention_unpadded, - copy_kv_to_blocked_cache, - flash_decoding_attention, - get_xine_cache, - rotary_embedding, -) -from colossalai.logging import get_dist_logger - -from 
flash_attn.bert_padding import index_first_axis, pad_input # noqa - -logger = get_dist_logger(__name__) - -try: - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.") - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def llama_causal_lm_forward( - self: LlamaForCausalLM, - batch: BatchInfo = None, - k_caches: List[torch.Tensor] = None, - v_caches: List[torch.Tensor] = None, -): - """This function will replace the forward function of LlamaForCausalLM. - - Args: - batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None. - k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None. - v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None. - """ - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - hidden_states = llama_model_forward( - self.model, - batch=batch, - k_caches=k_caches, - v_caches=v_caches, - ) - logits = self.lm_head(hidden_states) - return logits - - -def llama_model_forward( - self: LlamaModel, - batch: BatchInfo = None, - k_caches: List[torch.Tensor] = None, - v_caches: List[torch.Tensor] = None, -): - """This function will replace the forward function of LlamaModel. - - Args: - batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None. - k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None. - v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None. - """ - input_ids = batch.get_batch_inputs() - block_tables = batch.get_block_table_tensor() - attention_mask = batch.get_attn_mask() - - if attention_mask is not None: - if HAS_TRITON: - sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32) - else: - sequence_lengths = batch.get_sequence_lengths() - else: - sequence_lengths = batch.get_sequence_lengths() - - batch_size, _ = input_ids.shape - kv_seq_len = sequence_lengths.max().item() - - if attention_mask is not None: - if batch.is_prompts: - # Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer. 
- position_ids = generate_padding_position_id(attention_mask) - else: - position_ids = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1) - else: - if batch.is_prompts: - position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=batch.device) - position_ids = position_ids.unsqueeze(0) - else: - position_ids = torch.arange(kv_seq_len - 1, kv_seq_len, dtype=torch.long, device=batch.device) - position_ids = position_ids.unsqueeze(0) - - hidden_states = self.embed_tokens(input_ids) - - cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts) - - if batch.is_prompts: - output_tensor = torch.zeros( - (sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device - ) - else: - output_tensor = torch.zeros( - (batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device - ) - sm_scale = 1.0 / (batch.head_dim**0.5) - - norm_output = torch.empty_like(hidden_states) - - for layer_id, decoder_layer in enumerate(self.layers): - hidden_states = decoder_layer( - hidden_states, - position_ids=position_ids, - block_tables=block_tables, - k_cache=k_caches[layer_id], - v_cache=v_caches[layer_id], - is_prompts=batch.is_prompts, - sequence_lengths=sequence_lengths, - attention_mask=attention_mask, - kv_seq_len=kv_seq_len, - cos_sin=cos_sin, - fd_inter_tensor=batch.fd_inter_tensor, - output_tensor=output_tensor, - norm_output=norm_output, - sm_scale=sm_scale, - ) - - if batch.is_prompts: - hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous() - norm_output = torch.empty_like(hidden_states) - hidden_states = self.norm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output) - - return hidden_states - - -def llama_decoder_layer_forward( - self: LlamaDecoderLayer, - hidden_states: torch.Tensor, - position_ids: torch.LongTensor, - block_tables: torch.Tensor = None, - k_cache: torch.Tensor = None, - v_cache: torch.Tensor = None, - is_prompts: bool = True, - sequence_lengths: torch.Tensor = None, - attention_mask: torch.Tensor = None, - kv_seq_len: int = 0, - cos_sin: Tuple[torch.Tensor] = None, - fd_inter_tensor: FDIntermTensors = None, - output_tensor: torch.Tensor = None, - norm_output: torch.Tensor = None, - sm_scale: int = None, -) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """This function will replace the forward function of LlamaDecoderLayer. - - Args: - hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. - position_ids (torch.LongTensor), The position ids of input sequences. - block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], - storing mapping of token_position_id -> block_id. Defaults to None. - k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. - v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. - is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. - sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None. - kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. - cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None. - fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None. 
- output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. - norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None. - sm_scale (int, optional): Used for flash attention. Defaults to None. - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output) - # Self Attention - hidden_states = self.self_attn( - hidden_states=hidden_states, - position_ids=position_ids, - block_tables=block_tables, - k_cache=k_cache, - v_cache=v_cache, - is_prompts=is_prompts, - sequence_lengths=sequence_lengths, - attention_mask=attention_mask, - kv_seq_len=kv_seq_len, - cos_sin=cos_sin, - fd_inter_tensor=fd_inter_tensor, - output_tensor=output_tensor, - sm_scale=sm_scale, - ) - - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states - - -class PadLlamaAttention(LlamaAttention): - def __init__( - self, - config: LlamaConfig, - layer_idx: Optional[int] = None, - attn_qproj_w: torch.nn.Parameter = None, - attn_kproj_w: torch.nn.Parameter = None, - attn_vproj_w: torch.nn.Parameter = None, - attn_oproj_w: torch.nn.Parameter = None, - ): - """This layer will replace the LlamaAttention. - - Args: - config (LlamaConfig): Holding the Llama model config. - layer_idx (Optional[int], optional): The decode layer id of this attention layer. Defaults to None. - attn_qproj_w (torch.nn.Parameter, optional): The q_proj weight. Defaults to None. - attn_kproj_w (torch.nn.Parameter, optional): The k_proj weight. Defaults to None. - attn_vproj_w (torch.nn.Parameter, optional): The v_proj weight. Defaults to None. - attn_oproj_w (torch.nn.Parameter, optional): The o_proj weight. Defaults to None. - """ - super().__init__(config, layer_idx) - self.q_proj.weight = attn_qproj_w - self.k_proj.weight = attn_kproj_w - self.v_proj.weight = attn_vproj_w - self.o_proj.weight = attn_oproj_w - - @staticmethod - def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention: - """Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention - - Args: - module (LlamaAttention): The origin LlamaAttention layer. 
- """ - config = module.config - layer_idx = module.layer_idx - - attn_qproj_w = module.q_proj.weight - attn_kproj_w = module.k_proj.weight - attn_vproj_w = module.v_proj.weight - attn_oproj_w = module.o_proj.weight - - attn_layer = PadLlamaAttention( - config=config, - layer_idx=layer_idx, - attn_qproj_w=attn_qproj_w, - attn_kproj_w=attn_kproj_w, - attn_vproj_w=attn_vproj_w, - attn_oproj_w=attn_oproj_w, - ) - - return attn_layer - - def forward( - self, - hidden_states: torch.Tensor, - position_ids: torch.LongTensor, - block_tables: torch.Tensor = None, - k_cache: torch.Tensor = None, - v_cache: torch.Tensor = None, - is_prompts: bool = True, - sequence_lengths: torch.Tensor = None, - attention_mask: torch.Tensor = None, - kv_seq_len: int = 0, - cos_sin: Tuple[torch.Tensor] = None, - fd_inter_tensor: FDIntermTensors = None, - output_tensor: torch.Tensor = None, - sm_scale: int = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Args: - hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim] - position_ids (torch.LongTensor), The position ids of input sequences. - block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], - storing mapping of token_position_id -> block_id. Defaults to None. - k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. - v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. - is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. - sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None. - attention_mask (torch.Tensor, optional): The padding mask - corresponds to a tensor of size [batch_size, seq_len] - where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. - kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. - cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None. - fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for - storing intermediate values in flash-decoding. Defaults to None. - output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. - sm_scale (int, optional): Used for flash attention. Defaults to None. 
- """ - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) - - if HAS_TRITON: - if is_prompts: - if attention_mask is not None: - query_states, key_states, value_states, indices = unpading_input( - query_states, key_states, value_states, attention_mask - ) - else: - query_states = query_states.view(-1, self.num_heads, self.head_dim) - key_states = key_states.view(-1, self.num_heads, self.head_dim) - value_states = value_states.view(-1, self.num_heads, self.head_dim) - else: - query_states = query_states.squeeze(dim=1) - key_states = key_states.squeeze(dim=1) - value_states = value_states.squeeze(dim=1) - - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - - block_size = k_cache.size(-2) - - if is_prompts: - attn_output = context_attention_unpadded( - q=query_states, - k=key_states, - v=value_states, - k_cache=k_cache, - v_cache=v_cache, - context_lengths=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - output=output_tensor, - max_seq_len=kv_seq_len, - sm_scale=sm_scale, - ) - if attention_mask is not None: - attn_output = pad_input(attn_output, indices, bsz, q_len) - else: - copy_kv_to_blocked_cache( - key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables - ) - attn_output = flash_decoding_attention( - q=query_states, - k_cache=k_cache, - v_cache=v_cache, - kv_seq_len=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - max_seq_len_in_batch=kv_seq_len, - output=output_tensor, - mid_output=fd_inter_tensor.mid_output, - mid_output_lse=fd_inter_tensor.mid_output_lse, - sm_scale=sm_scale, - ) - attn_output = attn_output.squeeze(1) - else: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if is_prompts: - attn_output = PagedAttention.pad_context_forward( - query_states, - key_states, - value_states, - k_cache, - v_cache, - sequence_lengths, - block_tables, - attention_mask, - ) - else: - attn_output = PagedAttention.pad_decoding_forward( - query_states, - key_states, - value_states, - k_cache, - v_cache, - sequence_lengths, - block_tables, - attention_mask, - ) - - attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - return attn_output - - -def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor: - """Generate padding position_id through attention mask. - - Args: - attention_mask (`torch.Tensor` of shape [batch_size, sequence_length]: - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - Returns: - torch.Tensor: The padding position_id. 
- """ - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - return position_ids - - -def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor): - """Convert padding input to nopad input. - - Args: - q (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim] - k (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim] - v (torch.Tensor): [batch_size, q_seq_len, head_num, head_dim] - attention_mask (torch.Tensor): [batch_size, sequence_length] - - Returns: - Tuple[torch.Tensor]: The unpad q, k, v and The index of valid data in each batch. - - """ - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape - q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices) - k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices) - v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices) - return (q, k, v, indices) diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 8715f9981..8d41dff13 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -13,7 +13,7 @@ if HAS_TRITON: from .fused_rotary_embedding import fused_rotary_embedding from .gptq_triton import gptq_fused_linear_triton from .kvcache_copy import copy_kv_to_blocked_cache - from .no_pad_rotary_embedding import rotary_embedding + from .no_pad_rotary_embedding import decoding_fused_rotary_embedding, rotary_embedding from .rms_layernorm import rms_layernorm from .rotary_cache_copy import get_xine_cache from .softmax import softmax @@ -28,4 +28,5 @@ if HAS_TRITON: "rotary_embedding", "fused_rotary_embedding", "get_xine_cache", + "decoding_fused_rotary_embedding", ] diff --git a/colossalai/kernel/triton/kvcache_copy.py b/colossalai/kernel/triton/kvcache_copy.py index 4f056acf6..96ab922e3 100644 --- a/colossalai/kernel/triton/kvcache_copy.py +++ b/colossalai/kernel/triton/kvcache_copy.py @@ -45,21 +45,21 @@ def _copy_to_kvcache_seqlen1_kernel( k = tl.load(K + offsets_kv) v = tl.load(V + offsets_kv) - offsets_kvcache = ( + offsets_kcache = ( block_id * stride_cachekb + cur_kv_head_idx * stride_cachekh + offsets_in_last_block * stride_cachekbs + offsets_dmodel * stride_cachekd ) - offsets_kvcache = ( + offsets_vcache = ( block_id * stride_cachevb + cur_kv_head_idx * stride_cachevh + offsets_in_last_block * stride_cachevbs + offsets_dmodel * stride_cachevd ) - tl.store(KCache + offsets_kvcache, k) - tl.store(VCache + offsets_kvcache, v) + tl.store(KCache + offsets_kcache, k) + tl.store(VCache + offsets_vcache, v) return diff --git a/colossalai/kernel/triton/no_pad_rotary_embedding.py b/colossalai/kernel/triton/no_pad_rotary_embedding.py index 9194319d5..4b294a399 100644 --- a/colossalai/kernel/triton/no_pad_rotary_embedding.py +++ b/colossalai/kernel/triton/no_pad_rotary_embedding.py @@ -222,11 +222,11 @@ def fused_rotary_embedding_kernel( out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :] out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :] # total_tokens, head_num, head_dim - past_kv_seq_len = tl.load(context_lengths + tokens_range) - 1 + past_kv_seq_len = tl.load(context_lengths + tokens_range, mask=(tokens_range < q_total_tokens)) - 1 last_block_idx = past_kv_seq_len // block_size block_table_ptr = BLOCK_TABLES 
+ tokens_range * bts_stride - block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride) + block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(tokens_range < q_total_tokens)) offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride kv_range0 = ( @@ -274,6 +274,241 @@ def fused_rotary_embedding_kernel( ) +@triton.jit +def fused_rotary_embedding_kernel_v2( + q, + k, + cos, + sin, + kv_cache, + BLOCK_TABLES, + context_lengths, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_stride, + cacheb_stride, + cacheh_stride, + cachebs_stride, + cached_stride, + bts_stride, + btb_stride, + block_size, + q_total_tokens, + Q_HEAD_NUM: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + block_head_index = tl.program_id(0) + if block_head_index >= Q_HEAD_NUM: + return + block_token_index = tl.program_id(1) + + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + + off_q0 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range0 * head_dim_stride + off_q1 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range1 * head_dim_stride + off_k0 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range0 * head_dim_stride + off_k1 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range1 * head_dim_stride + + loaded_q0 = tl.load( + q + off_q0, + ) + loaded_q1 = tl.load( + q + off_q1, + ) + + loaded_k0 = tl.load( + k + off_k0, + ) + + loaded_k1 = tl.load( + k + off_k1, + ) + + off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride + + loaded_cos = tl.load(cos + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0) + loaded_sin = tl.load(sin + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0) + + out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin + out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos + + out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin + out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos # total_tokens, head_num, head_dim + + past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1 + + last_block_idx = past_kv_seq_len // block_size + block_table_ptr = BLOCK_TABLES + block_token_index * bts_stride + block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(block_token_index < q_total_tokens)) + offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride + + kv_range0 = ( + block_ids * cacheb_stride + + block_head_index * cacheh_stride + + offsets_in_last_block + + dim_range0 * cached_stride + ) + kv_range1 = ( + block_ids * cacheb_stride + + block_head_index * cacheh_stride + + offsets_in_last_block + + dim_range1 * cached_stride + ) + + tl.store( + kv_cache + kv_range0, + out_k0, + ) + tl.store( + kv_cache + kv_range1, + out_k1, + ) + + # concat + tl.store( + q + off_q0, + out_q0, + ) + tl.store( + q + off_q1, + out_q1, + ) + + +@triton.jit +def decoding_fused_rotary_embedding_kernel( + q, + k, + v, + cos, + sin, + k_cache, + v_cache, + BLOCK_TABLES, + context_lengths, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_stride, + cache_b_stride, + cache_h_stride, + cache_bs_stride, + cache_d_stride, + bts_stride, + btb_stride, + block_size, + Q_HEAD_NUM: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + block_head_index = tl.program_id(0) + if block_head_index >= Q_HEAD_NUM: + return + 
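+    # Launch grid is (next_power_of_2(num_q_heads), num_decoding_tokens): axis 0 is padded
+    # to a power of two, hence the head-index guard above, and axis 1 selects the token.
+    # Each program rotates one (token, head) pair of q/k, writes the rotated q back in
+    # place, and stores the rotated k and the unmodified v into the blocked KV caches.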
+ block_token_index = tl.program_id(1) + + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + total_dim_range = tl.arange(0, HEAD_DIM) + + q_off_base = block_token_index * q_token_stride + block_head_index * q_head_stride + off_q0 = q_off_base + dim_range0 * head_dim_stride + off_q1 = q_off_base + dim_range1 * head_dim_stride + + off_base = block_token_index * k_token_stride + block_head_index * k_head_stride + off_k0 = off_base + dim_range0 * head_dim_stride + off_k1 = off_base + dim_range1 * head_dim_stride + + off_v = off_base + total_dim_range * head_dim_stride + + loaded_q0 = tl.load( + q + off_q0, + ) + loaded_q1 = tl.load( + q + off_q1, + ) + + loaded_k0 = tl.load( + k + off_k0, + ) + + loaded_k1 = tl.load( + k + off_k1, + ) + + loaded_v = tl.load( + v + off_v, + ) + + off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride + + loaded_cos = tl.load(cos + off_cos_sin) + loaded_sin = tl.load(sin + off_cos_sin) + + out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin + out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos + + out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin + out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos # total_tokens, head_num, head_dim + + past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1 + + last_block_idx = past_kv_seq_len // block_size + block_ids = tl.load(BLOCK_TABLES + block_token_index * bts_stride + last_block_idx * btb_stride) + offsets_in_last_block = past_kv_seq_len % block_size + + k_range0 = ( + block_ids * cache_b_stride + + block_head_index * cache_h_stride + + offsets_in_last_block * cache_bs_stride + + dim_range0 * cache_d_stride + ) + k_range1 = ( + block_ids * cache_b_stride + + block_head_index * cache_h_stride + + offsets_in_last_block * cache_bs_stride + + dim_range1 * cache_d_stride + ) + v_range = ( + block_ids * cache_b_stride + + block_head_index * cache_h_stride + + offsets_in_last_block * cache_bs_stride + + total_dim_range * cache_d_stride + ) + + tl.store( + v_cache + v_range, + loaded_v, + ) + + tl.store( + k_cache + k_range0, + out_k0, + ) + + tl.store( + k_cache + k_range1, + out_k1, + ) + + # concat + tl.store( + q + off_q0, + out_q0, + ) + tl.store( + q + off_q1, + out_q1, + ) + + def rotary_embedding( q: torch.Tensor, k: torch.Tensor, @@ -297,12 +532,13 @@ def rotary_embedding( assert q.size(0) == k.size(0) BLOCK_HEAD = 4 BLOCK_TOKENS = 4 - grid = lambda META: (triton.cdiv(q_head_num, META["BLOCK_HEAD"]), triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"])) - if head_dim >= 256: + if head_dim >= 1024: num_warps = 32 - elif head_dim >= 128: + elif head_dim >= 512: num_warps = 16 + elif head_dim >= 256: + num_warps = 8 else: num_warps = 4 @@ -318,6 +554,10 @@ def rotary_embedding( cos_token_stride = cos.stride(0) cos_stride = cos.stride(1) if k_cache == None: + grid = lambda META: ( + triton.cdiv(q_head_num, META["BLOCK_HEAD"]), + triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]), + ) rotary_embedding_kernel[grid]( q, k, @@ -339,7 +579,8 @@ def rotary_embedding( num_warps=num_warps, ) else: - fused_rotary_embedding_kernel[grid]( + grid = (triton.next_power_of_2(q_head_num), q_total_tokens) + fused_rotary_embedding_kernel_v2[grid]( q, k, cos, @@ -363,10 +604,85 @@ def rotary_embedding( k_cache.size(-2), q_total_tokens, Q_HEAD_NUM=q_head_num, - K_HEAD_NUM=k_head_num, HEAD_DIM=head_dim, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_TOKENS=BLOCK_TOKENS, num_warps=num_warps, ) return + + +def decoding_fused_rotary_embedding( + q: torch.Tensor, + k: 
torch.Tensor, + v: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + k_cache: Optional[torch.Tensor] = None, + v_cache: Optional[torch.Tensor] = None, + block_tables: Optional[torch.Tensor] = None, + kv_lengths: Optional[torch.Tensor] = None, +): + """ + Args: + q: query tensor, [total_tokens, head_num, head_dim] + k: key tensor, [total_tokens, head_num, head_dim] + v: value tensor, [total tokens, head_num, head_dim] + cos: cosine for rotary embedding, [max_position_len, head_dim] + sin: sine for rotary embedding, [max_position_len, head_dim] + k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim] + v_cache (torch.Tensor): Blocked value cache. [num_blocks, num_kv_heads, block_size, head_dim] + kv_lengths, Past key/value sequence lengths plus current sequence length for each sequence. [bsz] + block_tables: Block tables for each sequence. [bsz, max_blocks_per_sequence] + """ + q_total_tokens, q_head_num, head_dim = q.shape + assert q.size(0) == k.size(0) == v.size(0) + assert q.size(1) == k.size(1) == v.size(1) + assert k_cache.size(-1) == v_cache.size(-1) + + if head_dim >= 1024: + num_warps = 32 + elif head_dim >= 512: + num_warps = 16 + elif head_dim >= 256: + num_warps = 8 + else: + num_warps = 4 + + q_token_stride = q.stride(0) + q_head_stride = q.stride(1) + head_dim_stride = q.stride(2) + + k_token_stride = k.stride(0) + k_head_stride = k.stride(1) + + cos_token_stride = cos.stride(0) + cos_stride = cos.stride(1) + grid = (triton.next_power_of_2(q_head_num), q_total_tokens) + decoding_fused_rotary_embedding_kernel[grid]( + q, + k, + v, + cos, + sin, + k_cache, + v_cache, + block_tables, + kv_lengths, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_stride, + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + k_cache.size(-2), + Q_HEAD_NUM=q_head_num, + HEAD_DIM=head_dim, + num_warps=num_warps, + ) + return diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py index 4665b4594..8098f4891 100644 --- a/examples/inference/benchmark_llama.py +++ b/examples/inference/benchmark_llama.py @@ -204,7 +204,7 @@ def benchmark_inference(args): torch.cuda.cudart().cudaProfilerStop() if args.profile: ctx.step() - + print(f"config:batch_size {args.batch_size}, input_len{ args.seq_len}, output_len {args.output_len}") print_details_info(model.config, args, whole_end2end, total_token_num) diff --git a/examples/inference/run_benchmark.sh b/examples/inference/run_benchmark.sh index c835a79df..9a68f86e2 100755 --- a/examples/inference/run_benchmark.sh +++ b/examples/inference/run_benchmark.sh @@ -1,7 +1,8 @@ ROOT=$(realpath $(dirname $0)) +echo $ROOT PY_SCRIPT=${ROOT}/benchmark_llama.py GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1) -mode=$1 +mode="colossalai" mkdir -p logs @@ -23,10 +24,10 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() { CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 # benchmark llama2-7b one single GPU -for input_len in 128 512 1024; do +for input_len in 128 512 1024; do for output_len in 128 256; do for bsz in 16 32 64; do - python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} --test_random_weight | tee logs/${input_len}_${output_len}_${mode}_${GPU}_${bsz}.txt + python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b ${bsz} -s ${input_len} --output_len 
${output_len} --mode ${mode} --model_path "/home/caidi/llama_model/" | tee logs/${input_len}_${output_len}_${mode}_${GPU}_${bsz}.txt done done done diff --git a/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py b/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py index 6a8dc85f0..d3f61325c 100644 --- a/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py +++ b/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py @@ -3,8 +3,8 @@ import torch from packaging import version from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb -from colossalai.kernel.triton import rotary_embedding -from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2 +from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding +from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token try: import triton # noqa @@ -67,25 +67,14 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype): ) new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda") new_q = torch.randn_like(new_k) + new_v = torch.randn_like(new_k) + kv_seq_lengths = past_kv_seq_lengths + 1 block_tables = block_tables.to(device="cuda") q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE]) - k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE]) - rotary_embedding(new_q, new_k, cos, sin, k_cache, block_tables, kv_seq_lengths) + decoding_fused_rotary_embedding(new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths) assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4) - assert torch.allclose(new_k, k_ref, atol=1e-4, rtol=1e-4) - - # check one by one - for seq_i in range(BATCH_SIZE): - ki = new_k[seq_i] - ki = ki.squeeze() - past_kv_seq_len = kv_seq_lengths[seq_i] - 1 - target_block_id = block_tables[seq_i, past_kv_seq_len // block_size] - offsets_in_block = past_kv_seq_len % block_size - target = k_cache[target_block_id, :, offsets_in_block, :] - orig = new_k[seq_i].squeeze(dim=0) - assert torch.equal(orig, target) BATCH = 16 @@ -94,8 +83,8 @@ configs = [ x_names=["num_tokens"], x_vals=[2**i for i in range(4, 11)], line_arg="provider", - line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"], - line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"], + line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"], + line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"], styles=[("red", "-"), ("blue", "-")], ylabel="ms", plot_name=f"rotary_emb-batch-{BATCH}", @@ -110,23 +99,53 @@ def benchmark_rotary_emb( num_tokens: int, num_kv_heads: int, ): + BATCH_SIZE = 4 + SEQ_LEN = num_tokens // BATCH_SIZE + max_num_blocks_per_seq = 8 + block_size = 64 warmup = 10 rep = 100 - head_dim = 128 + head_dim = 4096 dtype = torch.float16 + q_shape = (num_tokens, num_kv_heads, head_dim) q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda") k_shape = (num_tokens, num_kv_heads, head_dim) k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda") + v = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda") + cos_shape = (num_tokens, head_dim // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim) + k_cache = torch.zeros(size=cache_shape, 
dtype=dtype, device="cuda")
+    v_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
+
+    past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
+    block_tables = mock_alloc_block_table_and_kvcache_v2(
+        k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+    )
+    new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
+    new_q = torch.randn_like(new_k)
+    new_v = torch.randn_like(new_k)
+
+    mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
+    kv_seq_lengths = past_kv_seq_lengths + 1
+    block_tables = block_tables.to(device="cuda")

-    if provider == "torch_rotary_emb_func":
-        fn = lambda: torch_rotary_emb(q, cos, sin)
-    elif provider == "triton_rotary_emb_func":
-        fn = lambda: rotary_embedding(q, k, cos, sin)
+    if provider == "no_fused_rotary_emb_func":
+        fn = lambda: [
+            rotary_embedding(new_q, new_k, cos, sin),
+            copy_kv_to_blocked_cache(
+                new_k, new_v, k_cache, v_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables
+            ),
+        ]
+    elif provider == "fused_triton_rotary_emb_func":
+        fn = lambda: decoding_fused_rotary_embedding(
+            new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths
+        )
     else:
         raise ValueError("Undefined provider")

@@ -136,4 +155,4 @@ def benchmark_rotary_emb(

 if __name__ == "__main__":
     test_rotary_emb(4, 64, 32, 64, torch.float32)
-    # benchmark_rotary_emb.run(save_path=".",print_data=True)
+    # benchmark_rotary_emb.run(save_path=".", print_data=True)
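
A minimal usage sketch of the decode-path change introduced above, kept outside the diff. It only exercises the call signatures that appear in this patch (rotary_embedding, copy_kv_to_blocked_cache, decoding_fused_rotary_embedding); the batch size, head count, head dimension, block size, and block-table layout below are illustrative assumptions, not values taken from the patch.

# Sketch only: shapes and sizes are assumptions for illustration.
import torch

from colossalai.kernel.triton import (
    copy_kv_to_blocked_cache,
    decoding_fused_rotary_embedding,
    rotary_embedding,
)

bsz, num_heads, head_dim, block_size, max_blocks = 4, 32, 128, 16, 8
dtype, device = torch.float16, "cuda"

# One new token per sequence in the decoding phase: [bsz, num_heads, head_dim].
q = torch.randn(bsz, num_heads, head_dim, dtype=dtype, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)
# cos/sin rows correspond to the current position of each sequence: [bsz, head_dim // 2].
cos = torch.randn(bsz, head_dim // 2, dtype=dtype, device=device)
sin = torch.randn_like(cos)
# Blocked KV caches: [num_blocks, num_kv_heads, block_size, head_dim].
k_cache = torch.zeros(bsz * max_blocks, num_heads, block_size, head_dim, dtype=dtype, device=device)
v_cache = torch.zeros_like(k_cache)
# Trivial block table: sequence i owns blocks [i * max_blocks, (i + 1) * max_blocks).
block_tables = torch.arange(bsz * max_blocks, dtype=torch.int32, device=device).view(bsz, max_blocks)
# Sequence lengths including the token currently being decoded.
kv_lengths = torch.full((bsz,), 1, dtype=torch.int32, device=device)

use_fused = True
if use_fused:
    # After this patch: one kernel rotates q/k and writes k/v into the blocked caches.
    decoding_fused_rotary_embedding(q, k, v, cos, sin, k_cache, v_cache, block_tables, kv_lengths)
else:
    # Before this patch: rotate q/k first, then copy k/v to the caches in a second kernel.
    rotary_embedding(q, k, cos, sin)
    copy_kv_to_blocked_cache(k, v, k_cache, v_cache, kv_lengths=kv_lengths, block_tables=block_tables)

The fused call replaces the two-kernel sequence with a single launch, which is the point of this patch.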