# This code is adapted from huggingface transformers:
# https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaConfig,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
    LlamaModel,
    LlamaRMSNorm,
)

from colossalai.inference.config import InputMetaData
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
    context_attention_unpadded,
    decoding_fused_rotary_embedding,
    flash_decoding_attention,
    get_xine_cache,
    rms_layernorm,
    rotary_embedding,
)
from colossalai.logging import get_dist_logger

inference_ops = InferenceOpsLoader().load()

logger = get_dist_logger(__name__)

try:
    from flash_attn import flash_attn_varlen_func

    use_flash_attn2 = True
except ImportError:
    use_flash_attn2 = False
    logger.warning("flash_attn2 has not been installed yet, we will use triton flash attn instead.")


def llama_causal_lm_forward(
    self: LlamaForCausalLM,
    input_tokens_ids: torch.Tensor,
    output_tensor: torch.Tensor,
    inputmetadata: InputMetaData,
    k_caches: List[torch.Tensor] = None,
    v_caches: List[torch.Tensor] = None,
) -> torch.Tensor:
    """This function will replace the forward function of LlamaForCausalLM.

    Args:
        input_tokens_ids (torch.Tensor): The flattened input token ids of the batch.
        output_tensor (torch.Tensor): The mid tensor holds the output of attention.
        inputmetadata (InputMetaData): It stores the necessary input information for this inference.
        k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
        v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
    """
    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    hidden_states = llama_model_forward(
        self.model,
        input_tokens_ids=input_tokens_ids,
        output_tensor=output_tensor,
        inputmetadata=inputmetadata,
        k_caches=k_caches,
        v_caches=v_caches,
        use_cuda_kernel=inputmetadata.use_cuda_kernel,  # NOTE: currently the CUDA kernels for layernorm and rotary_embedding_and_cache_copy cannot pass the unit tests, but the Triton kernels can.
        high_precision=inputmetadata.high_precision,
    )
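    # NOTE: torch.mm below assumes lm_head.weight has already been transposed to [hidden_size, vocab_size]
    # (hidden_states is [num_seqs, hidden_size]); this mirrors how the Nopad layers store transposed
    # projection weights and is presumed to be handled by the policy that installs these forwards.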
    logits = torch.mm(hidden_states, self.lm_head.weight)
    return logits


def llama_model_forward(
    self: LlamaModel,
    input_tokens_ids: torch.Tensor,
    output_tensor: torch.Tensor,
    inputmetadata: InputMetaData,
    k_caches: List[torch.Tensor] = None,
    v_caches: List[torch.Tensor] = None,
    use_cuda_kernel: Optional[bool] = True,
    high_precision: bool = False,
) -> torch.Tensor:
    """This function will replace the forward function of LlamaModel.

    Args:
        input_tokens_ids (torch.Tensor): The flattened input token ids of the batch.
        output_tensor (torch.Tensor): The mid tensor holds the output of attention.
        inputmetadata (InputMetaData): It stores the necessary input information for this inference.
        k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
        v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
        use_cuda_kernel (Optional[bool]): Whether to use CUDA kernels. Defaults to True.
        high_precision (bool): Whether to use float32 for underlying calculations of float16 data to achieve higher precision. Defaults to False.
    """
    block_tables = inputmetadata.block_tables
    sequence_lengths = inputmetadata.sequence_lengths
    batch_size = inputmetadata.batch_size
    kv_seq_len = inputmetadata.kv_seq_len

    # NOTE: After testing, the performance of this configuration is relatively good. With updates
    # and optimizations to the CUDA kernel implementation, a more detailed analysis of this
    # configuration's selection should be conducted.
    if batch_size >= 32 and kv_seq_len > 512:
        use_cuda_kernel = False

    hidden_states = self.embed_tokens(input_tokens_ids)

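    # flash_attn_varlen_func expects cumulative sequence lengths (a leading zero followed by the running
    # sum over the batch), so F.pad prepends that zero to the cumsum of the per-sequence lengths.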
    if use_cuda_kernel and inputmetadata.dtype != torch.float32 and use_flash_attn2:
        cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
    else:
        cu_seqlens = None

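    # get_xine_cache gathers the precomputed cos/sin rotary values for the current token positions
    # (per-token during prefill, last-position-only during decoding).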
    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)

    sm_scale = 1.0 / (inputmetadata.head_dim**0.5)

    norm_output = torch.empty_like(hidden_states)
    residual = None

    for layer_id, decoder_layer in enumerate(self.layers):
        hidden_states, residual = decoder_layer(
            hidden_states,
            residual=residual,
            block_tables=block_tables,
            k_cache=k_caches[layer_id],
            v_cache=v_caches[layer_id],
            is_prompts=inputmetadata.is_prompts,
            sequence_lengths=sequence_lengths,
            cos_sin=cos_sin,
            fd_inter_tensor=inputmetadata.fd_inter_tensor,
            kv_seq_len=kv_seq_len,
            output_tensor=output_tensor,
            norm_output=norm_output,
            sm_scale=sm_scale,
            use_cuda_kernel=use_cuda_kernel,
            cu_seqlens=cu_seqlens,
            high_precision=high_precision,
        )

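    # During prefill, only each sequence's last hidden state is needed for the next-token logits, so the
    # flattened token dimension is gathered down to one row per sequence before the final norm.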
    if inputmetadata.is_prompts:
        last_token_indexs = sequence_lengths.cumsum(dim=-1)
        hidden_states = hidden_states[last_token_indexs - 1].contiguous()
        residual = residual[last_token_indexs - 1].contiguous()
        norm_output = torch.empty_like(hidden_states)
    hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)

    return hidden_states


def llama_decoder_layer_forward(
    self: LlamaDecoderLayer,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    block_tables: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    sequence_lengths: torch.Tensor,
    cos_sin: Tuple[torch.Tensor],
    fd_inter_tensor: FDIntermTensors,
    is_prompts: bool = True,
    kv_seq_len: int = 0,
    output_tensor: torch.Tensor = None,
    norm_output: torch.Tensor = None,
    sm_scale: float = None,
    use_cuda_kernel: bool = True,
    cu_seqlens: torch.Tensor = None,
    high_precision: bool = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """This function will replace the forward function of LlamaDecoderLayer.

    Args:
        hidden_states (torch.Tensor): Input to the layer of shape [token_num, embed_dim].
        residual (torch.Tensor): Shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
        block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
            storing the mapping of token_position_id -> block_id.
        k_cache (torch.Tensor): It holds the GPU memory for the key cache.
        v_cache (torch.Tensor): It holds the GPU memory for the value cache.
        sequence_lengths (torch.Tensor): Holding the sequence length of each sequence.
        cos_sin (Tuple[torch.Tensor]): Holding cos and sin.
        fd_inter_tensor (FDIntermTensors): Holding tensors used for
            storing intermediate values in flash-decoding.
        is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
        kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
        output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
        norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
        sm_scale (float, optional): Used for flash attention. Defaults to None.
        use_cuda_kernel (bool, optional): Whether to use CUDA kernels. Defaults to True.
        cu_seqlens (torch.Tensor, optional): Holding the cumulative sum of sequence lengths.
        high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision. Defaults to False.
    """
    hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
    # Self Attention
    hidden_states = self.self_attn(
        hidden_states=hidden_states,
        block_tables=block_tables,
        k_cache=k_cache,
        v_cache=v_cache,
        sequence_lengths=sequence_lengths,
        cos_sin=cos_sin,
        fd_inter_tensor=fd_inter_tensor,
        is_prompts=is_prompts,
        kv_seq_len=kv_seq_len,
        output_tensor=output_tensor,
        sm_scale=sm_scale,
        use_cuda_kernel=use_cuda_kernel,
        cu_seqlens=cu_seqlens,
        high_precision=high_precision,
    )

    # Fully Connected
    hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
    hidden_states = self.mlp(hidden_states)

    return hidden_states, residual


def llama_rmsnorm_forward(
    self: LlamaRMSNorm,
    hidden_states: torch.Tensor,
    norm_output: torch.Tensor,
    residual: torch.Tensor = None,
    use_cuda_kernel: bool = True,
):
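    # On the CUDA path the ops work in place: fused_add_rms_layernorm fuses the residual add with RMSNorm
    # in a single kernel, while rms_layernorm writes the normalized result into the preallocated norm_output.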
    if use_cuda_kernel:
        if residual is not None:
            inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon)
            return hidden_states, residual

        if norm_output is None:
            norm_output = torch.empty_like(hidden_states)
        inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, self.variance_epsilon)
        return norm_output, hidden_states
    else:
        return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)


class NopadLlamaAttention(LlamaAttention):
    def __init__(
        self,
        config: LlamaConfig,
        layer_idx: Optional[int] = None,
        attn_qproj_w: torch.Tensor = None,
        attn_kproj_w: torch.Tensor = None,
        attn_vproj_w: torch.Tensor = None,
        attn_oproj_w: torch.Tensor = None,
    ):
        """This layer will replace the LlamaAttention.

        Args:
            config (LlamaConfig): Holding the Llama model config.
            layer_idx (Optional[int], optional): The decode layer id of this attention layer. Defaults to None.
            attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
            attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
            attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
            attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
        """
        super().__init__(config, layer_idx)
        self.q_proj_weight = attn_qproj_w
        self.k_proj_weight = attn_kproj_w
        self.v_proj_weight = attn_vproj_w
        self.o_proj_weight = attn_oproj_w

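        # Without GQA (equal query and KV head counts), the three transposed projection weights are stacked
        # into a single [3, hidden_size, hidden_size] tensor so QKV can be computed with one bmm in forward().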
        if self.num_heads == self.num_key_value_heads:
            qkv_weight_list = [self.q_proj_weight, self.k_proj_weight, self.v_proj_weight]
            self.qkv_weight = torch.stack(qkv_weight_list, dim=0)

            self.q_proj = None
            self.k_proj = None
            self.v_proj = None

    @staticmethod
    def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
        """Used to initialize the weight of NopadLlamaAttention from the origin LlamaAttention.

        Args:
            module (LlamaAttention): The origin LlamaAttention layer.
        """
        config = module.config
        layer_idx = module.layer_idx
        attn_qproj_w = module.q_proj.weight.transpose(0, 1)
        attn_kproj_w = module.k_proj.weight.transpose(0, 1)
        attn_vproj_w = module.v_proj.weight.transpose(0, 1)
        attn_oproj_w = module.o_proj.weight.transpose(0, 1)
        attn_layer = NopadLlamaAttention(
            config=config,
            layer_idx=layer_idx,
            attn_qproj_w=attn_qproj_w,
            attn_kproj_w=attn_kproj_w,
            attn_vproj_w=attn_vproj_w,
            attn_oproj_w=attn_oproj_w,
        )

        return attn_layer

    # Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        block_tables: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        sequence_lengths: torch.Tensor,
        cos_sin: Tuple[torch.Tensor],
        fd_inter_tensor: FDIntermTensors,
        is_prompts: bool = True,
        kv_seq_len: int = 0,
        output_tensor: torch.Tensor = None,
        sm_scale: float = None,
        use_cuda_kernel: bool = True,
        cu_seqlens: torch.Tensor = None,
        high_precision: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Args:
            hidden_states (torch.Tensor): Input to the layer of shape [token_num, embed_dim].
            block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
                storing the mapping of token_position_id -> block_id.
            k_cache (torch.Tensor): It holds the GPU memory for the key cache.
            v_cache (torch.Tensor): It holds the GPU memory for the value cache.
            sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence.
            cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin.
            fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
                storing intermediate values in flash-decoding.
            is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
            kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
            output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
            sm_scale (float, optional): Used for flash attention. Defaults to None.
            use_cuda_kernel (bool, optional): Whether to use CUDA kernels. Defaults to True.
            cu_seqlens (torch.Tensor, optional): Holding the cumulative sum of sequence lengths.
            high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision. Defaults to False.
        """
        token_nums = hidden_states.size(0)

        if self.num_heads != self.num_key_value_heads:
            query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim)
            key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim)
            value_states = torch.mm(hidden_states, self.v_proj_weight).view(
                -1, self.num_key_value_heads, self.head_dim
            )
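        # Fused QKV: hidden_states is broadcast (without copying) to [3, token_num, hidden_size] and
        # multiplied against the stacked [3, hidden_size, hidden_size] qkv_weight in a single bmm;
        # unbind then splits out the query, key and value projections.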
        else:
            # fused qkv
            hidden_states = hidden_states.expand(3, -1, -1)
            query_states, key_states, value_states = (
                torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
            )

        block_size = k_cache.size(-2)

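        # Prefill (context) phase: apply RoPE, copy the prompt K/V into the paged cache, and attend over the
        # whole unpadded prompt. The decoding phase below instead appends one new K/V token per sequence and
        # runs flash-decoding over the cache.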
        if is_prompts:
            if use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
                # flash attn 2 currently only supports FP16/BF16.
                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                inference_ops.context_kv_cache_memcpy(
                    key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
                )
                attn_output = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens,
                    cu_seqlens_k=cu_seqlens,
                    max_seqlen_q=kv_seq_len,
                    max_seqlen_k=kv_seq_len,
                    dropout_p=0.0,
                    softmax_scale=sm_scale,
                    causal=True,
                )
                attn_output = attn_output.view(token_nums, -1)
            else:
                rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                attn_output = context_attention_unpadded(
                    q=query_states,
                    k=key_states,
                    v=value_states,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    context_lengths=sequence_lengths,
                    block_tables=block_tables,
                    block_size=block_size,
                    output=output_tensor,
                    max_seq_len=kv_seq_len,
                    sm_scale=sm_scale,
                )
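        # Decoding phase: both branches fuse RoPE with copying the new key/value token into the paged KV
        # cache; they differ only in whether the CUDA or the Triton kernel is used.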
        else:
            if use_cuda_kernel:
                inference_ops.rotary_embedding_and_cache_copy(
                    query_states,
                    key_states,
                    value_states,
                    cos_sin[0],
                    cos_sin[1],
                    k_cache,
                    v_cache,
                    sequence_lengths,
                    block_tables,
                    high_precision,
                )
            else:
                decoding_fused_rotary_embedding(
                    query_states,
                    key_states,
                    value_states,
                    cos_sin[0],
                    cos_sin[1],
                    k_cache,
                    v_cache,
                    block_tables,
                    sequence_lengths,
                )

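            # flash_decoding_attention splits each sequence's cached KV into chunks, writes partial results
            # into fd_inter_tensor.mid_output / mid_output_lse, and reduces them, keeping the single-query
            # decoding step parallel across the sequence length.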
            attn_output = flash_decoding_attention(
                q=query_states,
                k_cache=k_cache,
                v_cache=v_cache,
                kv_seq_len=sequence_lengths,
                block_tables=block_tables,
                block_size=block_size,
                max_seq_len_in_batch=kv_seq_len,
                output=output_tensor,
                mid_output=fd_inter_tensor.mid_output,
                mid_output_lse=fd_inter_tensor.mid_output_lse,
                sm_scale=sm_scale,
            )

        attn_output = torch.mm(attn_output, self.o_proj_weight)

        return attn_output


# NOTE: This will cause the result to differ from the transformers implementation in some cases.
class NopadLlamaMLP(LlamaMLP):
    def __init__(
        self,
        config: LlamaConfig,
        mlp_gproj_w: torch.Tensor = None,
        mlp_uproj_w: torch.Tensor = None,
        mlp_dproj_w: torch.Tensor = None,
    ):
        """This layer will replace the LlamaMLP.

        Args:
            config (LlamaConfig): Holding the Llama model config.
            mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
            mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
            mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
        """
        super().__init__(config)
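        # gate_proj and up_proj weights are stacked into a single [2, hidden_size, intermediate_size] tensor
        # so both projections run as one bmm in forward(); the original Linear modules are then dropped.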
        self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
        self.down_proj_weight = mlp_dproj_w
        self.gate_proj = None
        self.up_proj = None
        self.down_proj = None

    @staticmethod
    def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
        """Used to initialize the weight of NopadLlamaMLP from the origin LlamaMLP.

        Args:
            module (LlamaMLP): The origin LlamaMLP layer.
        """
        config = module.config
        mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
        mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
        mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
        mlp_layer = NopadLlamaMLP(
            config=config,
            mlp_gproj_w=mlp_gproj_w,
            mlp_uproj_w=mlp_uproj_w,
            mlp_dproj_w=mlp_dproj_w,
        )
        return mlp_layer

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
        """
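        # The stacked gate/up projection is computed with one bmm; silu_and_mul then fuses the SiLU
        # activation of the gate output with the element-wise multiplication by the up output.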
        hidden_states = hidden_states.expand(2, -1, -1)
        gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
        act_out = inference_ops.silu_and_mul(gate_up_proj_out)
        return torch.mm(act_out, self.down_proj_weight)