|
|
|
@ -2,7 +2,6 @@
|
|
|
|
|
from typing import List, Optional, Tuple |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
from torch.nn import Parameter |
|
|
|
|
from transformers.models.llama.modeling_llama import ( |
|
|
|
|
LlamaAttention, |
|
|
|
|
LlamaConfig, |
|
|
|
@ -82,19 +81,21 @@ def llama_model_forward(
|
|
|
|
|
|
|
|
|
|
if batch.is_prompts: |
|
|
|
|
output_tensor = torch.zeros( |
|
|
|
|
(sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device |
|
|
|
|
(sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
output_tensor = torch.zeros( |
|
|
|
|
(batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device |
|
|
|
|
(batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device |
|
|
|
|
) |
|
|
|
|
sm_scale = 1.0 / (batch.head_dim**0.5) |
|
|
|
|
|
|
|
|
|
norm_output = torch.empty_like(hidden_states) |
|
|
|
|
residual = None |
|
|
|
|
|
|
|
|
|
for layer_id, decoder_layer in enumerate(self.layers): |
|
|
|
|
hidden_states = decoder_layer( |
|
|
|
|
hidden_states, residual = decoder_layer( |
|
|
|
|
hidden_states, |
|
|
|
|
residual=residual, |
|
|
|
|
block_tables=block_tables, |
|
|
|
|
k_cache=k_caches[layer_id], |
|
|
|
|
v_cache=v_caches[layer_id], |
|
|
|
@ -111,8 +112,9 @@ def llama_model_forward(
|
|
|
|
|
if batch.is_prompts: |
|
|
|
|
last_token_indexs = sequence_lengths.cumsum(dim=-1) |
|
|
|
|
hidden_states = hidden_states[last_token_indexs - 1].contiguous() |
|
|
|
|
residual = residual[last_token_indexs - 1].contiguous() |
|
|
|
|
norm_output = torch.empty_like(hidden_states) |
|
|
|
|
hidden_states = self.norm(hidden_states, norm_output) |
|
|
|
|
hidden_states, _ = self.norm(hidden_states, norm_output, residual) |
|
|
|
|
|
|
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
@ -120,6 +122,7 @@ def llama_model_forward(
|
|
|
|
|
def llama_decoder_layer_forward( |
|
|
|
|
self: LlamaDecoderLayer, |
|
|
|
|
hidden_states: torch.Tensor, |
|
|
|
|
residual: torch.Tensor, |
|
|
|
|
block_tables: torch.Tensor = None, |
|
|
|
|
k_cache: torch.Tensor = None, |
|
|
|
|
v_cache: torch.Tensor = None, |
|
|
|
@ -136,6 +139,7 @@ def llama_decoder_layer_forward(
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. |
|
|
|
|
residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj. |
|
|
|
|
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], |
|
|
|
|
storing mapping of token_position_id -> block_id. Defaults to None. |
|
|
|
|
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. |
|
|
|
@ -151,12 +155,10 @@ def llama_decoder_layer_forward(
|
|
|
|
|
sm_scale (int, optional): Used for flash attention. Defaults to None. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
residual = hidden_states |
|
|
|
|
hidden_states = self.input_layernorm(hidden_states, norm_output) |
|
|
|
|
hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual) |
|
|
|
|
# Self Attention |
|
|
|
|
hidden_states = self.self_attn( |
|
|
|
|
hidden_states=hidden_states, |
|
|
|
|
residual=residual, |
|
|
|
|
block_tables=block_tables, |
|
|
|
|
k_cache=k_cache, |
|
|
|
|
v_cache=v_cache, |
|
|
|
@ -170,11 +172,10 @@ def llama_decoder_layer_forward(
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Fully Connected |
|
|
|
|
residual = hidden_states |
|
|
|
|
hidden_states = self.post_attention_layernorm(hidden_states, norm_output) |
|
|
|
|
hidden_states = self.mlp(hidden_states, residual) |
|
|
|
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual) |
|
|
|
|
hidden_states = self.mlp(hidden_states) |
|
|
|
|
|
|
|
|
|
return hidden_states |
|
|
|
|
return hidden_states, residual |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NopadLlamaAttention(LlamaAttention): |
|
|
|
@ -198,13 +199,15 @@ class NopadLlamaAttention(LlamaAttention):
|
|
|
|
|
attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None. |
|
|
|
|
""" |
|
|
|
|
super().__init__(config, layer_idx) |
|
|
|
|
self.q_proj.weight = Parameter(attn_qproj_w, requires_grad=False) |
|
|
|
|
self.k_proj.weight = Parameter(attn_kproj_w, requires_grad=False) |
|
|
|
|
self.v_proj.weight = Parameter(attn_vproj_w, requires_grad=False) |
|
|
|
|
self.o_proj.weight = Parameter(attn_oproj_w, requires_grad=False) |
|
|
|
|
self.q_proj_weight = attn_qproj_w |
|
|
|
|
self.k_proj_weight = attn_kproj_w |
|
|
|
|
self.v_proj_weight = attn_vproj_w |
|
|
|
|
self.o_proj_weight = attn_oproj_w |
|
|
|
|
|
|
|
|
|
if self.num_heads == self.num_key_value_heads: |
|
|
|
|
qkv_weight_list = [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight] |
|
|
|
|
qkv_weight_list = [self.q_proj_weight, self.k_proj_weight, self.v_proj_weight] |
|
|
|
|
self.qkv_weight = torch.stack(qkv_weight_list, dim=0) |
|
|
|
|
|
|
|
|
|
self.q_proj = None |
|
|
|
|
self.k_proj = None |
|
|
|
|
self.v_proj = None |
|
|
|
@ -239,7 +242,6 @@ class NopadLlamaAttention(LlamaAttention):
|
|
|
|
|
def forward( |
|
|
|
|
self, |
|
|
|
|
hidden_states: torch.Tensor, |
|
|
|
|
residual: torch.Tensor, |
|
|
|
|
block_tables: torch.Tensor = None, |
|
|
|
|
k_cache: torch.Tensor = None, |
|
|
|
|
v_cache: torch.Tensor = None, |
|
|
|
@ -254,7 +256,6 @@ class NopadLlamaAttention(LlamaAttention):
|
|
|
|
|
""" |
|
|
|
|
Args: |
|
|
|
|
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. |
|
|
|
|
residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj. |
|
|
|
|
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], |
|
|
|
|
storing mapping of token_position_id -> block_id. Defaults to None. |
|
|
|
|
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. |
|
|
|
@ -270,9 +271,9 @@ class NopadLlamaAttention(LlamaAttention):
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
if self.num_heads != self.num_key_value_heads: |
|
|
|
|
query_states = torch.mm(hidden_states, self.q_proj.weight).view(-1, self.num_heads, self.head_dim) |
|
|
|
|
key_states = torch.mm(hidden_states, self.k_proj.weight).view(-1, self.num_key_value_heads, self.head_dim) |
|
|
|
|
value_states = torch.mm(hidden_states, self.v_proj.weight).view(-1, self.num_key_value_heads, self.head_dim) |
|
|
|
|
query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim) |
|
|
|
|
key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim) |
|
|
|
|
value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim) |
|
|
|
|
else: |
|
|
|
|
# fused qkv |
|
|
|
|
token_nums = hidden_states.size(0) |
|
|
|
@ -324,8 +325,7 @@ class NopadLlamaAttention(LlamaAttention):
|
|
|
|
|
sm_scale=sm_scale, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
attn_output = attn_output.view(-1, self.hidden_size) |
|
|
|
|
attn_output = torch.addmm(residual, attn_output, self.o_proj.weight) |
|
|
|
|
attn_output = torch.mm(attn_output, self.o_proj_weight) |
|
|
|
|
|
|
|
|
|
return attn_output |
|
|
|
|
|
|
|
|
@ -348,10 +348,11 @@ class NopadLlamaMLP(LlamaMLP):
|
|
|
|
|
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None. |
|
|
|
|
""" |
|
|
|
|
super().__init__(config) |
|
|
|
|
self.gate_up_weight = Parameter(torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0), requires_grad=False) |
|
|
|
|
self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False) |
|
|
|
|
self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0) |
|
|
|
|
self.down_proj_weight = mlp_dproj_w |
|
|
|
|
self.gate_proj = None |
|
|
|
|
self.up_proj = None |
|
|
|
|
self.down_proj = None |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP: |
|
|
|
@ -375,14 +376,13 @@ class NopadLlamaMLP(LlamaMLP):
|
|
|
|
|
|
|
|
|
|
return mlp_layer |
|
|
|
|
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: |
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
|
|
|
""" |
|
|
|
|
Args: |
|
|
|
|
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. |
|
|
|
|
residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in down_proj. |
|
|
|
|
""" |
|
|
|
|
hidden_states = hidden_states.expand(2, -1, -1) |
|
|
|
|
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight) |
|
|
|
|
act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True) |
|
|
|
|
tmp_out = act_out * gate_up_proj_out[1] |
|
|
|
|
return torch.addmm(residual, tmp_out, self.down_proj.weight) |
|
|
|
|
return torch.mm(tmp_out, self.down_proj_weight) |
|
|
|
|