# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
from typing import List, Optional, Tuple

import torch
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
    repeat_kv,
)

from colossalai.inference.modeling.layers.attention import PagedAttention
from colossalai.inference.struct import BatchInfo

from flash_attn.bert_padding import index_first_axis, pad_input  # noqa


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
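
# Illustrative example (a sketch, not taken from the adapted transformers code): rotate_half
# splits the last dimension in half, negates the second half, and swaps the halves, e.g.
#
#   rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0]))  # -> tensor([-3., -4., 1., 2.])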


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
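
# Minimal usage sketch for apply_rotary_pos_emb. The shapes below are assumptions inferred from
# llama_attn_forward further down in this file, with dummy values only to show the broadcasting:
#
#   bs, num_heads, seq_len, head_dim = 2, 8, 16, 64
#   q = torch.randn(bs, num_heads, seq_len, head_dim)
#   k = torch.randn(bs, num_heads, seq_len, head_dim)
#   cos = torch.randn(1, 1, seq_len, head_dim)  # leading singleton dims are squeezed away above
#   sin = torch.randn(1, 1, seq_len, head_dim)
#   position_ids = torch.arange(seq_len).unsqueeze(0).expand(bs, -1)
#   q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
#   # q_rot and k_rot keep the [bs, num_heads, seq_len, head_dim] shape.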


def llama_causal_lm_forward(
    self: LlamaForCausalLM,
    batch: BatchInfo = None,
    k_caches: List[torch.Tensor] = None,
    v_caches: List[torch.Tensor] = None,
    padding_id: int = None,
):
    # Run the decoder stack, then project the final hidden states to vocabulary logits.
    hidden_states = llama_model_forward(
        self.model,
        batch=batch,
        k_caches=k_caches,
        v_caches=v_caches,
        padding_id=padding_id,
    )
    logits = self.lm_head(hidden_states)
    return logits


def llama_model_forward(
    self: LlamaModel,
    batch: BatchInfo = None,
    k_caches: List[torch.Tensor] = None,
    v_caches: List[torch.Tensor] = None,
    padding_id: int = None,
):
    input_ids = batch.get_batch_inputs()
    block_tables = batch.get_block_table_tensor()
    sequence_lengths = batch.get_sequence_lengths()
    attention_mask = batch.get_attn_mask(padding_id)

    if batch.is_prompts:
        # For prompt (prefill) batches, derive position_ids from the attention mask so the
        # outputs stay numerically aligned with the reference transformers implementation.
        position_ids = generate_padding_position_id(attention_mask)
    else:
        # During decoding, each sequence processes a single new token whose position is the
        # number of non-padding tokens seen so far minus one.
        position_ids = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1)

    hidden_states = self.embed_tokens(input_ids)

    for layer_id, decoder_layer in enumerate(self.layers):
        hidden_states = decoder_layer(
            hidden_states,
            position_ids=position_ids,
            block_tables=block_tables,
            k_cache=k_caches[layer_id],
            v_cache=v_caches[layer_id],
            is_prompts=batch.is_prompts,
            sequence_lengths=sequence_lengths,
            attention_mask=attention_mask,
        )

    hidden_states = self.norm(hidden_states)
    return hidden_states


def llama_decoder_layer_forward(
    self: LlamaDecoderLayer,
    hidden_states: torch.Tensor,
    position_ids: torch.LongTensor,
    block_tables: torch.Tensor = None,
    k_cache: torch.Tensor = None,
    v_cache: torch.Tensor = None,
    is_prompts: bool = True,
    sequence_lengths: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    residual = hidden_states

    hidden_states = self.input_layernorm(hidden_states)
    # Self Attention
    hidden_states = self.self_attn(
        hidden_states=hidden_states,
        position_ids=position_ids,
        block_tables=block_tables,
        k_cache=k_cache,
        v_cache=v_cache,
        is_prompts=is_prompts,
        sequence_lengths=sequence_lengths,
        attention_mask=attention_mask,
    )
    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states

    return hidden_states


# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
def llama_attn_forward(
    self: LlamaAttention,
    hidden_states: torch.Tensor,
    position_ids: torch.LongTensor,
    block_tables: torch.Tensor = None,
    k_cache: torch.Tensor = None,
    v_cache: torch.Tensor = None,
    is_prompts: bool = True,
    sequence_lengths: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = sequence_lengths[0].item()

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    # Repeat the key/value heads so grouped-query attention matches the number of query heads.
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    if is_prompts:
        attn_output = PagedAttention.pad_context_forward(
            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
        )
    else:
        attn_output = PagedAttention.pad_decoding_forward(
            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
        )

    attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    return attn_output


def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
    # The cumulative sum over the mask assigns 0-based positions to non-padding tokens;
    # padding slots are filled with the dummy value 1 (they are masked out during attention anyway).
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids
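

# A minimal sanity-check sketch for generate_padding_position_id, assuming a made-up left-padded
# attention mask; it only runs when the module is executed directly, so importing this module is
# unaffected.
if __name__ == "__main__":
    # Hypothetical mask for a batch of two sequences, the first with two padding slots.
    _mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])
    # Real tokens get 0-based positions; padding slots receive the dummy value 1.
    assert generate_padding_position_id(_mask).tolist() == [[1, 1, 0, 1, 2], [0, 1, 2, 3, 4]]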