From 249644c23b0402ccf9d0908f13ed15b41b95145f Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Thu, 1 Feb 2024 15:49:39 +0800 Subject: [PATCH] =?UTF-8?q?[Inference]Repalce=20Attention=20layer=20and=20?= =?UTF-8?q?MLP=20layer=20by=20shardformer=20to=20optimize=20the=20weight?= =?UTF-8?q?=20transpose=20operation=EF=BC=8Cadd=20fused=5Fqkv=20and=20fuse?= =?UTF-8?q?d=20linear=5Fadd=20(#5340)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add fused qkv * replace attn and mlp by shardformer * fix bugs in mlp * add docstrings * fix test_inference_engine.py * add optimize unbind * add fused_addmm * rm squeeze(1) * refactor codes * fix ci bugs * rename ShardFormerLlamaMLP and ShardFormerLlamaAttention * Removed the dependency on LlamaFlashAttention2 * rollback test_inference_engine.py --- .../modeling/models/nopadding_llama.py | 306 +++++++++++++---- .../modeling/models/padding_llama.py | 321 ++++++++++++------ .../modeling/policy/nopadding_llama.py | 60 ++-- .../modeling/policy/padding_llama.py | 135 +------- colossalai/kernel/triton/flash_decoding.py | 10 +- examples/inference/run_benchmark.sh | 14 +- tests/test_infer_ops/triton/kernel_utils.py | 1 + .../triton/test_decoding_attn.py | 4 +- 8 files changed, 510 insertions(+), 341 deletions(-) diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 569c5f05a..6b108cd4d 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -2,8 +2,10 @@ from typing import List, Optional, Tuple import torch +from torch.nn import Parameter from transformers.models.llama.modeling_llama import ( LlamaAttention, + LlamaConfig, LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP, @@ -39,6 +41,14 @@ def llama_causal_lm_forward( k_caches: List[torch.Tensor] = None, v_caches: List[torch.Tensor] = None, ): + """This function will replace the forward function of LlamaForCausalLM. + + Args: + batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None. + k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None. + v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None. + """ + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) hidden_states = llama_model_forward( self.model, @@ -46,7 +56,7 @@ def llama_causal_lm_forward( k_caches=k_caches, v_caches=v_caches, ) - logits = torch.mm(hidden_states, self.lm_head.weight.transpose(0, 1)) + logits = torch.mm(hidden_states, self.lm_head.weight) return logits @@ -57,6 +67,13 @@ def llama_model_forward( k_caches: List[torch.Tensor] = None, v_caches: List[torch.Tensor] = None, ): + """This function will replace the forward function of LlamaModel. + + Args: + batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None. + k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None. + v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None. + """ input_ids = batch.get_1D_inputs() block_tables = batch.get_block_table_tensor() @@ -74,7 +91,7 @@ def llama_model_forward( ) else: output_tensor = torch.zeros( - (batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device + (batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device ) sm_scale = 1.0 / (batch.head_dim**0.5) @@ -116,12 +133,30 @@ def llama_decoder_layer_forward( output_tensor: torch.Tensor = None, sm_scale: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """This function will replace the forward function of LlamaDecoderLayer. + + Args: + hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`. + block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], + storing mapping of token_position_id -> block_id. Defaults to None. + k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. + v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. + is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. + sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None. + kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. + cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None. + fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for + storing intermediate values in flash-decoding. Defaults to None. + output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. + sm_scale (int, optional): Used for flash attention. Defaults to None. + """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states = self.self_attn( hidden_states=hidden_states, + residual=residual, block_tables=block_tables, k_cache=k_cache, v_cache=v_cache, @@ -134,88 +169,213 @@ def llama_decoder_layer_forward( sm_scale=sm_scale, ) - hidden_states = residual + hidden_states - # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states + hidden_states = self.mlp(hidden_states, residual) return hidden_states -# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward -@torch.no_grad() -def llama_attn_forward( - self: LlamaAttention, - hidden_states: torch.Tensor, - block_tables: torch.Tensor = None, - k_cache: torch.Tensor = None, - v_cache: torch.Tensor = None, - is_prompts: bool = True, - sequence_lengths: torch.Tensor = None, - kv_seq_len: int = 0, - cos_sin: Tuple[torch.Tensor] = None, - fd_inter_tensor: FDIntermTensors = None, - output_tensor: torch.Tensor = None, - sm_scale: int = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - query_states = torch.mm(hidden_states, self.q_proj.weight.transpose(0, 1)).view(-1, self.num_heads, self.head_dim) - key_states = torch.mm(hidden_states, self.k_proj.weight.transpose(0, 1)).view( - -1, self.num_key_value_heads, self.head_dim - ) - value_states = torch.mm(hidden_states, self.v_proj.weight.transpose(0, 1)).view( - -1, self.num_key_value_heads, self.head_dim - ) +class NopadLlamaAttention(LlamaAttention): + def __init__( + self, + config: LlamaConfig, + layer_idx: Optional[int] = None, + attn_qproj_w: torch.Tensor = None, + attn_kproj_w: torch.Tensor = None, + attn_vproj_w: torch.Tensor = None, + attn_oproj_w: torch.Tensor = None, + ): + """This layer will replace the LlamaAttention. - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + Args: + config (LlamaConfig): Holding the Llama model config. + layer_idx (Optional[int], optional): The decode layer id of this attention layer. Defaults to None. + attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None. + attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None. + attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None. + attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None. + """ + super().__init__(config, layer_idx) + self.q_proj.weight = Parameter(attn_qproj_w, requires_grad=False) + self.k_proj.weight = Parameter(attn_kproj_w, requires_grad=False) + self.v_proj.weight = Parameter(attn_vproj_w, requires_grad=False) + self.o_proj.weight = Parameter(attn_oproj_w, requires_grad=False) + if self.num_heads == self.num_key_value_heads: + qkv_weight_list = [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight] + self.qkv_weight = torch.stack(qkv_weight_list, dim=0) + self.q_proj = None + self.k_proj = None + self.v_proj = None - block_size = k_cache.size(-2) + @staticmethod + def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention: + """Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention. - if is_prompts: - attn_output = context_attention_unpadded( - q=query_states, - k=key_states, - v=value_states, - k_cache=k_cache, - v_cache=v_cache, - context_lengths=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - output=output_tensor, - max_seq_len=kv_seq_len, - sm_scale=sm_scale, + Args: + module (LlamaAttention): The origin LlamaAttention layer. + """ + config = module.config + layer_idx = module.layer_idx + + attn_qproj_w = module.q_proj.weight.transpose(0, 1) + attn_kproj_w = module.k_proj.weight.transpose(0, 1) + attn_vproj_w = module.v_proj.weight.transpose(0, 1) + attn_oproj_w = module.o_proj.weight.transpose(0, 1) + + attn_layer = NopadLlamaAttention( + config=config, + layer_idx=layer_idx, + attn_qproj_w=attn_qproj_w, + attn_kproj_w=attn_kproj_w, + attn_vproj_w=attn_vproj_w, + attn_oproj_w=attn_oproj_w, ) - else: - copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables) - copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables) - attn_output = flash_decoding_attention( - q=query_states, - k_cache=k_cache, - v_cache=v_cache, - kv_seq_len=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - max_seq_len_in_batch=kv_seq_len, - output=output_tensor, - mid_output=fd_inter_tensor.mid_output, - mid_output_lse=fd_inter_tensor.mid_output_lse, - sm_scale=sm_scale, + + return attn_layer + + # Replace transformers.models.llama.modeling_llama.LlamaAttention.forward + @torch.no_grad() + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + block_tables: torch.Tensor = None, + k_cache: torch.Tensor = None, + v_cache: torch.Tensor = None, + is_prompts: bool = True, + sequence_lengths: torch.Tensor = None, + kv_seq_len: int = 0, + cos_sin: Tuple[torch.Tensor] = None, + fd_inter_tensor: FDIntermTensors = None, + output_tensor: torch.Tensor = None, + sm_scale: int = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Args: + hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)` + residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in out_proj. + block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], + storing mapping of token_position_id -> block_id. Defaults to None. + k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. + v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. + is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. + sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None. + kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. + cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None. + fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for + storing intermediate values in flash-decoding. Defaults to None. + output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. + sm_scale (int, optional): Used for flash attention. Defaults to None. + """ + + if self.num_heads != self.num_key_value_heads: + query_states = torch.mm(hidden_states, self.q_proj.weight).view(-1, self.num_heads, self.head_dim) + key_states = torch.mm(hidden_states, self.k_proj.weight).view(-1, self.num_key_value_heads, self.head_dim) + value_states = torch.mm(hidden_states, self.v_proj.weight).view(-1, self.num_key_value_heads, self.head_dim) + else: + # fused qkv + token_nums = hidden_states.size(0) + hidden_states = hidden_states.expand(3, -1, -1) + query_states, key_states, value_states = ( + torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0) + ) + + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + + block_size = k_cache.size(-2) + + if is_prompts: + attn_output = context_attention_unpadded( + q=query_states, + k=key_states, + v=value_states, + k_cache=k_cache, + v_cache=v_cache, + context_lengths=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + output=output_tensor, + max_seq_len=kv_seq_len, + sm_scale=sm_scale, + ) + else: + copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables) + copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables) + attn_output = flash_decoding_attention( + q=query_states, + k_cache=k_cache, + v_cache=v_cache, + kv_seq_len=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + max_seq_len_in_batch=kv_seq_len, + output=output_tensor, + mid_output=fd_inter_tensor.mid_output, + mid_output_lse=fd_inter_tensor.mid_output_lse, + sm_scale=sm_scale, + ) + + attn_output = attn_output.reshape(-1, self.hidden_size) + attn_output = torch.addmm(residual, attn_output, self.o_proj.weight) + + return attn_output + + +# NOTE This will cause the result to be different from the transformer in some cases. +class NopadLlamaMLP(LlamaMLP): + def __init__( + self, + config: LlamaConfig, + mlp_gproj_w: torch.Tensor = None, + mlp_uproj_w: torch.Tensor = None, + mlp_dproj_w: torch.Tensor = None, + ): + """This layer will replace the LlamaAttention. + + Args: + config (LlamaConfig): Holding the Llama model config. + mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None. + mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None. + mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None. + """ + super().__init__(config) + self.gate_proj.weight = Parameter(mlp_gproj_w, requires_grad=False) + self.up_proj.weight = Parameter(mlp_uproj_w, requires_grad=False) + self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False) + + @staticmethod + def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP: + """Used for initialize the weight of NopadLlamaMLP by origin LlamaMLP. + + Args: + module (LlamaMLP): The origin LlamaMLP layer. + """ + config = module.config + + mlp_gproj_w = module.gate_proj.weight.transpose(0, 1) + mlp_uproj_w = module.up_proj.weight.transpose(0, 1) + mlp_dproj_w = module.down_proj.weight.transpose(0, 1) + + mlp_layer = NopadLlamaMLP( + config=config, + mlp_gproj_w=mlp_gproj_w, + mlp_uproj_w=mlp_uproj_w, + mlp_dproj_w=mlp_dproj_w, ) - attn_output = attn_output.squeeze(1) - attn_output = attn_output.view(-1, self.num_heads, self.head_dim) - attn_output = attn_output.reshape(-1, self.hidden_size) - attn_output = torch.mm(attn_output, self.o_proj.weight.transpose(0, 1)) + return mlp_layer - return attn_output - - -@torch.no_grad() -def nopad_mlp(self: LlamaMLP, hidden_states: torch.Tensor): - gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight.transpose(0, 1)) - act_out = torch.nn.functional.silu(gate_proj_out, inplace=True) - up_proj_out = torch.mm(hidden_states, self.up_proj.weight.transpose(0, 1)) - tmp_out = act_out * up_proj_out - return torch.mm(tmp_out, self.down_proj.weight.transpose(0, 1)) + @torch.no_grad() + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`. + residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in down_proj. + """ + gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight) + act_out = torch.nn.functional.silu(gate_proj_out, inplace=True) + up_proj_out = torch.mm(hidden_states, self.up_proj.weight) + tmp_out = act_out * up_proj_out + return torch.addmm(residual, tmp_out, self.down_proj.weight) diff --git a/colossalai/inference/modeling/models/padding_llama.py b/colossalai/inference/modeling/models/padding_llama.py index 63a8d3673..51d718a53 100644 --- a/colossalai/inference/modeling/models/padding_llama.py +++ b/colossalai/inference/modeling/models/padding_llama.py @@ -2,7 +2,13 @@ from typing import List, Optional, Tuple import torch -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaConfig, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, +) from colossalai.inference.flash_decoding_utils import FDIntermTensors from colossalai.inference.modeling.layers.attention import PagedAttention @@ -53,6 +59,14 @@ def llama_causal_lm_forward( k_caches: List[torch.Tensor] = None, v_caches: List[torch.Tensor] = None, ): + """This function will replace the forward function of LlamaForCausalLM. + + Args: + batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None. + k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None. + v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None. + """ + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) hidden_states = llama_model_forward( self.model, @@ -71,6 +85,13 @@ def llama_model_forward( k_caches: List[torch.Tensor] = None, v_caches: List[torch.Tensor] = None, ): + """This function will replace the forward function of LlamaModel. + + Args: + batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None. + k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None. + v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None. + """ input_ids = batch.get_batch_inputs() block_tables = batch.get_block_table_tensor() attention_mask = batch.get_attn_mask() @@ -110,7 +131,7 @@ def llama_model_forward( ) else: output_tensor = torch.zeros( - (batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device + (batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device ) sm_scale = 1.0 / (batch.head_dim**0.5) @@ -131,7 +152,8 @@ def llama_model_forward( sm_scale=sm_scale, ) - hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous() + if batch.is_prompts: + hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous() hidden_states = self.norm(hidden_states) return hidden_states @@ -154,6 +176,23 @@ def llama_decoder_layer_forward( output_tensor: torch.Tensor = None, sm_scale: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """This function will replace the forward function of LlamaDecoderLayer. + + Args: + hidden_states (torch.Tensor): _description_ + position_ids (torch.LongTensor), The position ids of input sequences. + block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], + storing mapping of token_position_id -> block_id. Defaults to None. + k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. + v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. + is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. + sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None. + kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. + cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None. + fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None. + output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. + sm_scale (int, optional): Used for flash attention. Defaults to None. + """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -185,108 +224,192 @@ def llama_decoder_layer_forward( return hidden_states -# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward -@torch.no_grad() -def llama_attn_forward( - self: LlamaAttention, - hidden_states: torch.Tensor, - position_ids: torch.LongTensor, - block_tables: torch.Tensor = None, - k_cache: torch.Tensor = None, - v_cache: torch.Tensor = None, - is_prompts: bool = True, - sequence_lengths: torch.Tensor = None, - attention_mask: torch.Tensor = None, - kv_seq_len: int = 0, - cos_sin: Tuple[torch.Tensor] = None, - fd_inter_tensor: FDIntermTensors = None, - output_tensor: torch.Tensor = None, - sm_scale: int = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() +class PadLlamaAttention(LlamaAttention): + def __init__( + self, + config: LlamaConfig, + layer_idx: Optional[int] = None, + attn_qproj_w: torch.nn.Parameter = None, + attn_kproj_w: torch.nn.Parameter = None, + attn_vproj_w: torch.nn.Parameter = None, + attn_oproj_w: torch.nn.Parameter = None, + ): + """This layer will replace the LlamaAttention. - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) + Args: + config (LlamaConfig): Holding the Llama model config. + layer_idx (Optional[int], optional): The decode layer id of this attention layer. Defaults to None. + attn_qproj_w (torch.nn.Parameter, optional): The q_proj weight. Defaults to None. + attn_kproj_w (torch.nn.Parameter, optional): The k_proj weight. Defaults to None. + attn_vproj_w (torch.nn.Parameter, optional): The v_proj weight. Defaults to None. + attn_oproj_w (torch.nn.Parameter, optional): The o_proj weight. Defaults to None. + """ + super().__init__(config, layer_idx) + self.q_proj.weight = attn_qproj_w + self.k_proj.weight = attn_kproj_w + self.v_proj.weight = attn_vproj_w + self.o_proj.weight = attn_oproj_w - if HAS_TRITON: - if is_prompts: - if attention_mask is not None: - query_states, key_states, value_states, indices = unpading_input( - query_states, key_states, value_states, attention_mask + @staticmethod + def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention: + """Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention + + Args: + module (LlamaAttention): The origin LlamaAttention layer. + """ + config = module.config + layer_idx = module.layer_idx + + attn_qproj_w = module.q_proj.weight + attn_kproj_w = module.k_proj.weight + attn_vproj_w = module.v_proj.weight + attn_oproj_w = module.o_proj.weight + + attn_layer = PadLlamaAttention( + config=config, + layer_idx=layer_idx, + attn_qproj_w=attn_qproj_w, + attn_kproj_w=attn_kproj_w, + attn_vproj_w=attn_vproj_w, + attn_oproj_w=attn_oproj_w, + ) + + return attn_layer + + @torch.no_grad() + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.LongTensor, + block_tables: torch.Tensor = None, + k_cache: torch.Tensor = None, + v_cache: torch.Tensor = None, + is_prompts: bool = True, + sequence_lengths: torch.Tensor = None, + attention_mask: torch.Tensor = None, + kv_seq_len: int = 0, + cos_sin: Tuple[torch.Tensor] = None, + fd_inter_tensor: FDIntermTensors = None, + output_tensor: torch.Tensor = None, + sm_scale: int = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Args: + hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)` + position_ids (torch.LongTensor), The position ids of input sequences. + block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], + storing mapping of token_position_id -> block_id. Defaults to None. + k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. + v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. + is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. + sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None. + attention_mask (torch.Tensor, optional): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` + where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. + kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. + cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None. + fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for + storing intermediate values in flash-decoding. Defaults to None. + output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. + sm_scale (int, optional): Used for flash attention. Defaults to None. + """ + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + if HAS_TRITON: + if is_prompts: + if attention_mask is not None: + query_states, key_states, value_states, indices = unpading_input( + query_states, key_states, value_states, attention_mask + ) + else: + query_states = query_states.view(-1, self.num_heads, self.head_dim) + key_states = key_states.view(-1, self.num_heads, self.head_dim) + value_states = value_states.view(-1, self.num_heads, self.head_dim) + else: + query_states = query_states.squeeze(dim=1) + key_states = key_states.squeeze(dim=1) + value_states = value_states.squeeze(dim=1) + + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + + block_size = k_cache.size(-2) + + if is_prompts: + attn_output = context_attention_unpadded( + q=query_states, + k=key_states, + v=value_states, + k_cache=k_cache, + v_cache=v_cache, + context_lengths=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + output=output_tensor, + max_seq_len=kv_seq_len, + sm_scale=sm_scale, + ) + if attention_mask is not None: + attn_output = pad_input(attn_output, indices, bsz, q_len) + else: + copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables) + copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables) + attn_output = flash_decoding_attention( + q=query_states, + k_cache=k_cache, + v_cache=v_cache, + kv_seq_len=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + max_seq_len_in_batch=kv_seq_len, + output=output_tensor, + mid_output=fd_inter_tensor.mid_output, + mid_output_lse=fd_inter_tensor.mid_output_lse, + sm_scale=sm_scale, + ) + attn_output = attn_output.squeeze(1) + else: + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if is_prompts: + attn_output = PagedAttention.pad_context_forward( + query_states, + key_states, + value_states, + k_cache, + v_cache, + sequence_lengths, + block_tables, + attention_mask, ) else: - query_states = query_states.view(-1, self.num_heads, self.head_dim) - key_states = key_states.view(-1, self.num_heads, self.head_dim) - value_states = value_states.view(-1, self.num_heads, self.head_dim) - else: - query_states = query_states.squeeze(dim=1) - key_states = key_states.squeeze(dim=1) - value_states = value_states.squeeze(dim=1) + attn_output = PagedAttention.pad_decoding_forward( + query_states, + key_states, + value_states, + k_cache, + v_cache, + sequence_lengths, + block_tables, + attention_mask, + ) - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) - block_size = k_cache.size(-2) - - if is_prompts: - attn_output = context_attention_unpadded( - q=query_states, - k=key_states, - v=value_states, - k_cache=k_cache, - v_cache=v_cache, - context_lengths=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - output=output_tensor, - max_seq_len=kv_seq_len, - sm_scale=sm_scale, - ) - if attention_mask is not None: - attn_output = pad_input(attn_output, indices, bsz, q_len) - else: - copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables) - copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables) - attn_output = flash_decoding_attention( - q=query_states, - k_cache=k_cache, - v_cache=v_cache, - kv_seq_len=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - max_seq_len_in_batch=kv_seq_len, - output=output_tensor, - mid_output=fd_inter_tensor.mid_output, - mid_output_lse=fd_inter_tensor.mid_output_lse, - sm_scale=sm_scale, - ) - attn_output = attn_output.squeeze(1) - else: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if is_prompts: - attn_output = PagedAttention.pad_context_forward( - query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask - ) - else: - attn_output = PagedAttention.pad_decoding_forward( - query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask - ) - - attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - return attn_output + return attn_output @torch.no_grad() diff --git a/colossalai/inference/modeling/policy/nopadding_llama.py b/colossalai/inference/modeling/policy/nopadding_llama.py index 3eaa59f74..aed72ef73 100644 --- a/colossalai/inference/modeling/policy/nopadding_llama.py +++ b/colossalai/inference/modeling/policy/nopadding_llama.py @@ -1,25 +1,18 @@ from functools import partial import torch -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaFlashAttention2, - LlamaForCausalLM, - LlamaMLP, - LlamaModel, - LlamaRMSNorm, - LlamaSdpaAttention, -) +from torch.nn import Parameter +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm from colossalai.inference.modeling.models.nopadding_llama import ( - llama_attn_forward, + NopadLlamaAttention, + NopadLlamaMLP, llama_causal_lm_forward, llama_decoder_layer_forward, llama_model_forward, - nopad_mlp, ) from colossalai.inference.utils import init_to_get_rotary +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription # import colossalai from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy @@ -50,6 +43,27 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): def module_policy(self): policy = super().module_policy() + + decoder_attribute_replacement = { + "lm_head.weight": Parameter(self.model.lm_head.weight.transpose(0, 1), requires_grad=False), + } + policy[LlamaForCausalLM] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) + + policy[LlamaDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="mlp", + target_module=NopadLlamaMLP, + ), + SubModuleReplacementDescription( + suffix="self_attn", + target_module=NopadLlamaAttention, + ), + ] + ) + self.shard_config._infer() infer_forward = llama_causal_lm_forward @@ -68,28 +82,6 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): description=method_replacement, policy=policy, target_key=LlamaDecoderLayer ) - infer_forward = nopad_mlp - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaMLP) - - infer_forward = llama_attn_forward - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaAttention - ) - - infer_forward = llama_attn_forward - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaFlashAttention2 - ) - - infer_forward = llama_attn_forward - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaSdpaAttention - ) - infer_forward = None if HAS_TRITON_RMSNORM: infer_forward = get_triton_rmsnorm_forward() diff --git a/colossalai/inference/modeling/policy/padding_llama.py b/colossalai/inference/modeling/policy/padding_llama.py index 0c83189f8..9aa64f55b 100644 --- a/colossalai/inference/modeling/policy/padding_llama.py +++ b/colossalai/inference/modeling/policy/padding_llama.py @@ -1,18 +1,10 @@ from functools import partial import torch -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaFlashAttention2, - LlamaForCausalLM, - LlamaModel, - LlamaRMSNorm, - LlamaSdpaAttention, -) +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm from colossalai.inference.modeling.models.padding_llama import ( - llama_attn_forward, + PadLlamaAttention, llama_causal_lm_forward, llama_decoder_layer_forward, llama_model_forward, @@ -49,105 +41,16 @@ class PaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): def module_policy(self): policy = super().module_policy() - decoder_attribute_replacement = { - "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attn.num_key_value_heads": self.model.config.num_key_value_heads - // self.shard_config.tensor_parallel_size, - } - if self.shard_config.extra_kwargs.get("quant", None) == "gptq": - from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear - policy[LlamaDecoderLayer] = ModulePolicyDescription( - attribute_replacement=decoder_attribute_replacement, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.q_proj", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="self_attn.k_proj", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="self_attn.v_proj", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="self_attn.o_proj", - target_module=RowCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="mlp.gate_proj", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="mlp.up_proj", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="mlp.down_proj", - target_module=RowCaiQuantLinear, - kwargs={"split_num": 1}, - ), - ], - ) + policy[LlamaDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn", + target_module=PadLlamaAttention, + ), + ] + ) - elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant": - from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer - from colossalai.inference.quant.smoothquant.models.parallel_linear import ( - ColW8A8BFP32OFP32Linear, - RowW8A8B8O8Linear, - RowW8A8BFP32O32LinearSiLU, - RowW8A8BFP32OFP32Linear, - ) - - policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription( - attribute_replacement=decoder_attribute_replacement, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.q_proj", - target_module=RowW8A8B8O8Linear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="self_attn.k_proj", - target_module=RowW8A8B8O8Linear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="self_attn.v_proj", - target_module=RowW8A8B8O8Linear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="self_attn.o_proj", - target_module=ColW8A8BFP32OFP32Linear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="mlp.gate_proj", - target_module=RowW8A8BFP32O32LinearSiLU, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="mlp.up_proj", - target_module=RowW8A8BFP32OFP32Linear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="mlp.down_proj", - target_module=ColW8A8BFP32OFP32Linear, - kwargs={"split_num": 1}, - ), - ], - ) self.shard_config._infer() infer_forward = llama_causal_lm_forward @@ -166,24 +69,6 @@ class PaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): description=method_replacement, policy=policy, target_key=LlamaDecoderLayer ) - infer_forward = llama_attn_forward - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaAttention - ) - - infer_forward = llama_attn_forward - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaFlashAttention2 - ) - - infer_forward = llama_attn_forward - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaSdpaAttention - ) - infer_forward = None if HAS_TRITON_RMSNORM: infer_forward = get_triton_rmsnorm_forward() diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py index 4bba24503..37fcd504c 100644 --- a/colossalai/kernel/triton/flash_decoding.py +++ b/colossalai/kernel/triton/flash_decoding.py @@ -143,8 +143,7 @@ def _flash_decoding_fwd_reduce_kernel( stride_o_lset, stride_o_lseh, stride_o_lseb, - stride_ob, - stride_ol, + stride_ot, stride_oh, stride_od, BLOCK_KV: tl.constexpr, @@ -180,7 +179,7 @@ def _flash_decoding_fwd_reduce_kernel( m_i = m_ij acc = acc / l - offsets_O = cur_seq_idx * stride_ob + cur_head_idx * stride_oh + offsets_dmodel + offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel tl.store(O + offsets_O, acc.to(O.type.element_ty)) return @@ -212,7 +211,7 @@ def flash_decoding_attention( records the (kv) sequence lengths incorporating past kv sequence lengths. block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence] max_seq_len_in_batch (int): Maximum sequence length in the batch. - output (torch.Tensor): [bsz, 1, num_heads, head_dim] + output (torch.Tensor): [bsz, num_heads, head_dim] mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim] Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`. mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num] @@ -294,7 +293,7 @@ def flash_decoding_attention( HEAD_DIM=head_dim, ) - output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output + output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output grid = (triton.next_power_of_2(bsz), num_heads) @@ -314,7 +313,6 @@ def flash_decoding_attention( output.stride(0), output.stride(1), output.stride(2), - output.stride(3), BLOCK_KV=block_size, HEAD_DIM=head_dim, ) diff --git a/examples/inference/run_benchmark.sh b/examples/inference/run_benchmark.sh index bdd79836e..6870ed384 100755 --- a/examples/inference/run_benchmark.sh +++ b/examples/inference/run_benchmark.sh @@ -25,10 +25,20 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 # benchmark llama2-7b one single GPU for bsz in 16 32 64; do - python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt + python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_512_256.txt done for bsz in 16 32 64; do - python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt + python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_256.txt +done + + +for bsz in 16 32 64; do + python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256_128.txt +done + + +for bsz in 16 32 64; do + python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_128.txt done diff --git a/tests/test_infer_ops/triton/kernel_utils.py b/tests/test_infer_ops/triton/kernel_utils.py index 7c3bc5ca6..22167ded0 100644 --- a/tests/test_infer_ops/triton/kernel_utils.py +++ b/tests/test_infer_ops/triton/kernel_utils.py @@ -69,6 +69,7 @@ def torch_attn_ref( f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is" f" {out.size()}" ) out = out.transpose(1, 2).contiguous() + out = out.squeeze(1) return out diff --git a/tests/test_infer_ops/triton/test_decoding_attn.py b/tests/test_infer_ops/triton/test_decoding_attn.py index a49ee3146..5eac026bb 100644 --- a/tests/test_infer_ops/triton/test_decoding_attn.py +++ b/tests/test_infer_ops/triton/test_decoding_attn.py @@ -94,7 +94,7 @@ def test_flash_decoding( max_seq_len_in_b = kv_seq_lengths.max().item() # The maximum block length splitted on kv should be the kv cache block size kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size - output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device) + output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device) mid_output = torch.empty( size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device ) @@ -189,7 +189,7 @@ def bench_kernel( block_tables = block_tables.to(device=device) # the maximum block length splitted on kv should be the kv cache block size kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size - output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=dtype, device=device) + output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device) mid_output = torch.empty( size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device )