mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Replace Attention layer and MLP layer by shardformer to optimize the weight transpose operation, add fused_qkv and fused linear_add (#5340)
* add fused qkv
* replace attn and mlp by shardformer
* fix bugs in mlp
* add docstrings
* fix test_inference_engine.py
* add optimized unbind
* add fused_addmm
* rm squeeze(1)
* refactor codes
* fix ci bugs
* rename ShardFormerLlamaMLP and ShardFormerLlamaAttention
* remove the dependency on LlamaFlashAttention2
* roll back test_inference_engine.py
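The fused_qkv and fused_addmm items above hinge on storing the projection weights pre-transposed, so the per-call transpose(0, 1) disappears from every forward pass. The following minimal sketch uses standalone tensors and assumed shapes (not the actual ColossalAI modules) to illustrate that idea:

import torch

token_num, embed_dim = 8, 64
hidden = torch.randn(token_num, embed_dim)      # flattened, no-padding input
w = torch.randn(embed_dim, embed_dim)           # nn.Linear-style weight (out_features, in_features)

# Before: the weight is transposed on every call.
out_old = torch.mm(hidden, w.transpose(0, 1))

# After: transpose once (e.g. when building the replacement module) and reuse it.
w_t = w.transpose(0, 1).contiguous()
out_new = torch.mm(hidden, w_t)

assert torch.allclose(out_old, out_new, atol=1e-5)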
parent
f8e456d202
commit
249644c23b
|
@ -2,8 +2,10 @@
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn import Parameter
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
LlamaAttention,
|
LlamaAttention,
|
||||||
|
LlamaConfig,
|
||||||
LlamaDecoderLayer,
|
LlamaDecoderLayer,
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
LlamaMLP,
|
LlamaMLP,
|
||||||
|
@ -39,6 +41,14 @@ def llama_causal_lm_forward(
|
||||||
k_caches: List[torch.Tensor] = None,
|
k_caches: List[torch.Tensor] = None,
|
||||||
v_caches: List[torch.Tensor] = None,
|
v_caches: List[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
"""This function will replace the forward function of LlamaForCausalLM.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
|
||||||
|
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
|
||||||
|
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
hidden_states = llama_model_forward(
|
hidden_states = llama_model_forward(
|
||||||
self.model,
|
self.model,
|
||||||
|
@ -46,7 +56,7 @@ def llama_causal_lm_forward(
|
||||||
k_caches=k_caches,
|
k_caches=k_caches,
|
||||||
v_caches=v_caches,
|
v_caches=v_caches,
|
||||||
)
|
)
|
||||||
logits = torch.mm(hidden_states, self.lm_head.weight.transpose(0, 1))
|
logits = torch.mm(hidden_states, self.lm_head.weight)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,6 +67,13 @@ def llama_model_forward(
|
||||||
k_caches: List[torch.Tensor] = None,
|
k_caches: List[torch.Tensor] = None,
|
||||||
v_caches: List[torch.Tensor] = None,
|
v_caches: List[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
"""This function will replace the forward function of LlamaModel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
|
||||||
|
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
|
||||||
|
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
|
||||||
|
"""
|
||||||
input_ids = batch.get_1D_inputs()
|
input_ids = batch.get_1D_inputs()
|
||||||
block_tables = batch.get_block_table_tensor()
|
block_tables = batch.get_block_table_tensor()
|
||||||
|
|
||||||
|
@ -74,7 +91,7 @@ def llama_model_forward(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output_tensor = torch.zeros(
|
output_tensor = torch.zeros(
|
||||||
(batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
|
(batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
|
||||||
)
|
)
|
||||||
sm_scale = 1.0 / (batch.head_dim**0.5)
|
sm_scale = 1.0 / (batch.head_dim**0.5)
|
||||||
|
|
||||||
|
@ -116,12 +133,30 @@ def llama_decoder_layer_forward(
|
||||||
output_tensor: torch.Tensor = None,
|
output_tensor: torch.Tensor = None,
|
||||||
sm_scale: int = None,
|
sm_scale: int = None,
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
"""This function will replace the forward function of LlamaDecoderLayer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
|
||||||
|
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
|
||||||
|
storing mapping of token_position_id -> block_id. Defaults to None.
|
||||||
|
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
|
||||||
|
v_cache (torch.Tensor, optional): It holds the GPU memory for the value cache. Defaults to None.
|
||||||
|
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
|
||||||
|
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
|
||||||
|
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
|
||||||
|
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
|
||||||
|
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
|
||||||
|
storing intermediate values in flash-decoding. Defaults to None.
|
||||||
|
output_tensor (torch.Tensor, optional): The intermediate tensor that holds the attention output. Defaults to None.
|
||||||
|
sm_scale (int, optional): Used for flash attention. Defaults to None.
|
||||||
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
# Self Attention
|
# Self Attention
|
||||||
hidden_states = self.self_attn(
|
hidden_states = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
|
residual=residual,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
k_cache=k_cache,
|
k_cache=k_cache,
|
||||||
v_cache=v_cache,
|
v_cache=v_cache,
|
||||||
|
@ -134,88 +169,213 @@ def llama_decoder_layer_forward(
|
||||||
sm_scale=sm_scale,
|
sm_scale=sm_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states, residual)
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
|
class NopadLlamaAttention(LlamaAttention):
|
||||||
@torch.no_grad()
|
def __init__(
|
||||||
def llama_attn_forward(
|
self,
|
||||||
self: LlamaAttention,
|
config: LlamaConfig,
|
||||||
hidden_states: torch.Tensor,
|
layer_idx: Optional[int] = None,
|
||||||
block_tables: torch.Tensor = None,
|
attn_qproj_w: torch.Tensor = None,
|
||||||
k_cache: torch.Tensor = None,
|
attn_kproj_w: torch.Tensor = None,
|
||||||
v_cache: torch.Tensor = None,
|
attn_vproj_w: torch.Tensor = None,
|
||||||
is_prompts: bool = True,
|
attn_oproj_w: torch.Tensor = None,
|
||||||
sequence_lengths: torch.Tensor = None,
|
):
|
||||||
kv_seq_len: int = 0,
|
"""This layer will replace the LlamaAttention.
|
||||||
cos_sin: Tuple[torch.Tensor] = None,
|
|
||||||
fd_inter_tensor: FDIntermTensors = None,
|
|
||||||
output_tensor: torch.Tensor = None,
|
|
||||||
sm_scale: int = None,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
query_states = torch.mm(hidden_states, self.q_proj.weight.transpose(0, 1)).view(-1, self.num_heads, self.head_dim)
|
|
||||||
key_states = torch.mm(hidden_states, self.k_proj.weight.transpose(0, 1)).view(
|
|
||||||
-1, self.num_key_value_heads, self.head_dim
|
|
||||||
)
|
|
||||||
value_states = torch.mm(hidden_states, self.v_proj.weight.transpose(0, 1)).view(
|
|
||||||
-1, self.num_key_value_heads, self.head_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
Args:
|
||||||
|
config (LlamaConfig): Holding the Llama model config.
|
||||||
|
layer_idx (Optional[int], optional): The decode layer id of this attention layer. Defaults to None.
|
||||||
|
attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
|
||||||
|
attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
|
||||||
|
attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
|
||||||
|
attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
|
||||||
|
"""
|
||||||
|
super().__init__(config, layer_idx)
|
||||||
|
self.q_proj.weight = Parameter(attn_qproj_w, requires_grad=False)
|
||||||
|
self.k_proj.weight = Parameter(attn_kproj_w, requires_grad=False)
|
||||||
|
self.v_proj.weight = Parameter(attn_vproj_w, requires_grad=False)
|
||||||
|
self.o_proj.weight = Parameter(attn_oproj_w, requires_grad=False)
|
||||||
|
if self.num_heads == self.num_key_value_heads:
|
||||||
|
qkv_weight_list = [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]
|
||||||
|
self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
|
||||||
|
self.q_proj = None
|
||||||
|
self.k_proj = None
|
||||||
|
self.v_proj = None
|
||||||
|
|
||||||
block_size = k_cache.size(-2)
|
@staticmethod
|
||||||
|
def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
|
||||||
|
"""Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention.
|
||||||
|
|
||||||
if is_prompts:
|
Args:
|
||||||
attn_output = context_attention_unpadded(
|
module (LlamaAttention): The origin LlamaAttention layer.
|
||||||
q=query_states,
|
"""
|
||||||
k=key_states,
|
config = module.config
|
||||||
v=value_states,
|
layer_idx = module.layer_idx
|
||||||
k_cache=k_cache,
|
|
||||||
v_cache=v_cache,
|
attn_qproj_w = module.q_proj.weight.transpose(0, 1)
|
||||||
context_lengths=sequence_lengths,
|
attn_kproj_w = module.k_proj.weight.transpose(0, 1)
|
||||||
block_tables=block_tables,
|
attn_vproj_w = module.v_proj.weight.transpose(0, 1)
|
||||||
block_size=block_size,
|
attn_oproj_w = module.o_proj.weight.transpose(0, 1)
|
||||||
output=output_tensor,
|
|
||||||
max_seq_len=kv_seq_len,
|
attn_layer = NopadLlamaAttention(
|
||||||
sm_scale=sm_scale,
|
config=config,
|
||||||
|
layer_idx=layer_idx,
|
||||||
|
attn_qproj_w=attn_qproj_w,
|
||||||
|
attn_kproj_w=attn_kproj_w,
|
||||||
|
attn_vproj_w=attn_vproj_w,
|
||||||
|
attn_oproj_w=attn_oproj_w,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
return attn_layer
|
||||||
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
|
||||||
attn_output = flash_decoding_attention(
|
# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
|
||||||
q=query_states,
|
@torch.no_grad()
|
||||||
k_cache=k_cache,
|
def forward(
|
||||||
v_cache=v_cache,
|
self,
|
||||||
kv_seq_len=sequence_lengths,
|
hidden_states: torch.Tensor,
|
||||||
block_tables=block_tables,
|
residual: torch.Tensor,
|
||||||
block_size=block_size,
|
block_tables: torch.Tensor = None,
|
||||||
max_seq_len_in_batch=kv_seq_len,
|
k_cache: torch.Tensor = None,
|
||||||
output=output_tensor,
|
v_cache: torch.Tensor = None,
|
||||||
mid_output=fd_inter_tensor.mid_output,
|
is_prompts: bool = True,
|
||||||
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
sequence_lengths: torch.Tensor = None,
|
||||||
sm_scale=sm_scale,
|
kv_seq_len: int = 0,
|
||||||
|
cos_sin: Tuple[torch.Tensor] = None,
|
||||||
|
fd_inter_tensor: FDIntermTensors = None,
|
||||||
|
output_tensor: torch.Tensor = None,
|
||||||
|
sm_scale: int = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`
|
||||||
|
residual (torch.Tensor): shape `(token_num, embed_dim)`, the residual to be added to the attention output in o_proj.
|
||||||
|
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
|
||||||
|
storing mapping of token_position_id -> block_id. Defaults to None.
|
||||||
|
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
|
||||||
|
v_cache (torch.Tensor, optional): It holds the GPU memory for the value cache. Defaults to None.
|
||||||
|
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
|
||||||
|
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
|
||||||
|
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
|
||||||
|
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
|
||||||
|
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
|
||||||
|
storing intermediate values in flash-decoding. Defaults to None.
|
||||||
|
output_tensor (torch.Tensor, optional): The intermediate tensor that holds the attention output. Defaults to None.
|
||||||
|
sm_scale (int, optional): Used for flash attention. Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.num_heads != self.num_key_value_heads:
|
||||||
|
query_states = torch.mm(hidden_states, self.q_proj.weight).view(-1, self.num_heads, self.head_dim)
|
||||||
|
key_states = torch.mm(hidden_states, self.k_proj.weight).view(-1, self.num_key_value_heads, self.head_dim)
|
||||||
|
value_states = torch.mm(hidden_states, self.v_proj.weight).view(-1, self.num_key_value_heads, self.head_dim)
|
||||||
|
else:
|
||||||
|
# fused qkv
|
||||||
|
token_nums = hidden_states.size(0)
|
||||||
|
hidden_states = hidden_states.expand(3, -1, -1)
|
||||||
|
query_states, key_states, value_states = (
|
||||||
|
torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
|
||||||
|
)
|
||||||
|
|
||||||
|
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||||
|
|
||||||
|
block_size = k_cache.size(-2)
|
||||||
|
|
||||||
|
if is_prompts:
|
||||||
|
attn_output = context_attention_unpadded(
|
||||||
|
q=query_states,
|
||||||
|
k=key_states,
|
||||||
|
v=value_states,
|
||||||
|
k_cache=k_cache,
|
||||||
|
v_cache=v_cache,
|
||||||
|
context_lengths=sequence_lengths,
|
||||||
|
block_tables=block_tables,
|
||||||
|
block_size=block_size,
|
||||||
|
output=output_tensor,
|
||||||
|
max_seq_len=kv_seq_len,
|
||||||
|
sm_scale=sm_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||||
|
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||||
|
attn_output = flash_decoding_attention(
|
||||||
|
q=query_states,
|
||||||
|
k_cache=k_cache,
|
||||||
|
v_cache=v_cache,
|
||||||
|
kv_seq_len=sequence_lengths,
|
||||||
|
block_tables=block_tables,
|
||||||
|
block_size=block_size,
|
||||||
|
max_seq_len_in_batch=kv_seq_len,
|
||||||
|
output=output_tensor,
|
||||||
|
mid_output=fd_inter_tensor.mid_output,
|
||||||
|
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
||||||
|
sm_scale=sm_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(-1, self.hidden_size)
|
||||||
|
attn_output = torch.addmm(residual, attn_output, self.o_proj.weight)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE This will cause the result to be different from the transformer in some cases.
|
||||||
|
class NopadLlamaMLP(LlamaMLP):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: LlamaConfig,
|
||||||
|
mlp_gproj_w: torch.Tensor = None,
|
||||||
|
mlp_uproj_w: torch.Tensor = None,
|
||||||
|
mlp_dproj_w: torch.Tensor = None,
|
||||||
|
):
|
||||||
|
"""This layer will replace the LlamaAttention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (LlamaConfig): Holding the Llama model config.
|
||||||
|
mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
|
||||||
|
mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
|
||||||
|
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
|
||||||
|
"""
|
||||||
|
super().__init__(config)
|
||||||
|
self.gate_proj.weight = Parameter(mlp_gproj_w, requires_grad=False)
|
||||||
|
self.up_proj.weight = Parameter(mlp_uproj_w, requires_grad=False)
|
||||||
|
self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
|
||||||
|
"""Used for initialize the weight of NopadLlamaMLP by origin LlamaMLP.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (LlamaMLP): The origin LlamaMLP layer.
|
||||||
|
"""
|
||||||
|
config = module.config
|
||||||
|
|
||||||
|
mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
|
||||||
|
mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
|
||||||
|
mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
|
||||||
|
|
||||||
|
mlp_layer = NopadLlamaMLP(
|
||||||
|
config=config,
|
||||||
|
mlp_gproj_w=mlp_gproj_w,
|
||||||
|
mlp_uproj_w=mlp_uproj_w,
|
||||||
|
mlp_dproj_w=mlp_dproj_w,
|
||||||
)
|
)
|
||||||
attn_output = attn_output.squeeze(1)
|
|
||||||
|
|
||||||
attn_output = attn_output.view(-1, self.num_heads, self.head_dim)
|
return mlp_layer
|
||||||
attn_output = attn_output.reshape(-1, self.hidden_size)
|
|
||||||
attn_output = torch.mm(attn_output, self.o_proj.weight.transpose(0, 1))
|
|
||||||
|
|
||||||
return attn_output
|
@torch.no_grad()
|
||||||
|
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
@torch.no_grad()
|
Args:
|
||||||
def nopad_mlp(self: LlamaMLP, hidden_states: torch.Tensor):
|
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
|
||||||
gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight.transpose(0, 1))
|
residual (torch.Tensor): shape `(token_num, embed_dim)`, the residual to be added to the MLP output in down_proj.
|
||||||
act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
|
"""
|
||||||
up_proj_out = torch.mm(hidden_states, self.up_proj.weight.transpose(0, 1))
|
gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight)
|
||||||
tmp_out = act_out * up_proj_out
|
act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
|
||||||
return torch.mm(tmp_out, self.down_proj.weight.transpose(0, 1))
|
up_proj_out = torch.mm(hidden_states, self.up_proj.weight)
|
||||||
|
tmp_out = act_out * up_proj_out
|
||||||
|
return torch.addmm(residual, tmp_out, self.down_proj.weight)
|
||||||
|
|
|
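To make the fused paths in NopadLlamaAttention and NopadLlamaMLP easier to follow outside the full module, here is a small self-contained sketch (hypothetical shapes and plain tensors, not the actual classes) of the two fusions used above: the stacked-weight QKV projection computed with a single torch.bmm, and the output projection fused with the residual add via torch.addmm:

import torch

token_num, num_heads, head_dim = 8, 4, 16
hidden_size = num_heads * head_dim

hidden_states = torch.randn(token_num, hidden_size)
residual = torch.randn(token_num, hidden_size)

# Pre-transposed projection weights of shape (in_features, out_features),
# as prepared once in from_native_module.
wq, wk, wv, wo = (torch.randn(hidden_size, hidden_size) for _ in range(4))

# Fused QKV: stack the three weights and run one batched matmul,
# then split the result with a single unbind instead of three separate mms.
qkv_weight = torch.stack([wq, wk, wv], dim=0)                # (3, in, out)
query, key, value = (
    torch.bmm(hidden_states.expand(3, -1, -1), qkv_weight)   # (3, token_num, out)
    .view(3, token_num, num_heads, head_dim)
    .unbind(0)
)

# ... the attention kernel runs here; use a stand-in for its flattened output.
attn_output = torch.randn(token_num, hidden_size)

# Fused linear + add: addmm computes residual + attn_output @ wo in one call,
# replacing the separate o_proj matmul followed by `residual + hidden_states`.
out = torch.addmm(residual, attn_output, wo)
assert torch.allclose(out, residual + attn_output @ wo, atol=1e-5)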
@ -2,7 +2,13 @@
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LlamaAttention,
|
||||||
|
LlamaConfig,
|
||||||
|
LlamaDecoderLayer,
|
||||||
|
LlamaForCausalLM,
|
||||||
|
LlamaModel,
|
||||||
|
)
|
||||||
|
|
||||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||||
from colossalai.inference.modeling.layers.attention import PagedAttention
|
from colossalai.inference.modeling.layers.attention import PagedAttention
|
||||||
|
@ -53,6 +59,14 @@ def llama_causal_lm_forward(
|
||||||
k_caches: List[torch.Tensor] = None,
|
k_caches: List[torch.Tensor] = None,
|
||||||
v_caches: List[torch.Tensor] = None,
|
v_caches: List[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
"""This function will replace the forward function of LlamaForCausalLM.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
|
||||||
|
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
|
||||||
|
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
hidden_states = llama_model_forward(
|
hidden_states = llama_model_forward(
|
||||||
self.model,
|
self.model,
|
||||||
|
@ -71,6 +85,13 @@ def llama_model_forward(
|
||||||
k_caches: List[torch.Tensor] = None,
|
k_caches: List[torch.Tensor] = None,
|
||||||
v_caches: List[torch.Tensor] = None,
|
v_caches: List[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
"""This function will replace the forward function of LlamaModel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
|
||||||
|
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
|
||||||
|
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
|
||||||
|
"""
|
||||||
input_ids = batch.get_batch_inputs()
|
input_ids = batch.get_batch_inputs()
|
||||||
block_tables = batch.get_block_table_tensor()
|
block_tables = batch.get_block_table_tensor()
|
||||||
attention_mask = batch.get_attn_mask()
|
attention_mask = batch.get_attn_mask()
|
||||||
|
@ -110,7 +131,7 @@ def llama_model_forward(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output_tensor = torch.zeros(
|
output_tensor = torch.zeros(
|
||||||
(batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
|
(batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
|
||||||
)
|
)
|
||||||
sm_scale = 1.0 / (batch.head_dim**0.5)
|
sm_scale = 1.0 / (batch.head_dim**0.5)
|
||||||
|
|
||||||
|
@ -131,7 +152,8 @@ def llama_model_forward(
|
||||||
sm_scale=sm_scale,
|
sm_scale=sm_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
|
if batch.is_prompts:
|
||||||
|
hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
@ -154,6 +176,23 @@ def llama_decoder_layer_forward(
|
||||||
output_tensor: torch.Tensor = None,
|
output_tensor: torch.Tensor = None,
|
||||||
sm_scale: int = None,
|
sm_scale: int = None,
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
"""This function will replace the forward function of LlamaDecoderLayer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states (torch.Tensor): input to the layer of shape `(bsz, seq_len, embed_dim)`.
|
||||||
|
position_ids (torch.LongTensor): The position ids of input sequences.
|
||||||
|
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
|
||||||
|
storing mapping of token_position_id -> block_id. Defaults to None.
|
||||||
|
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
|
||||||
|
v_cache (torch.Tensor, optional): It holds the GPU memory for the value cache. Defaults to None.
|
||||||
|
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
|
||||||
|
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
|
||||||
|
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
|
||||||
|
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
|
||||||
|
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None.
|
||||||
|
output_tensor (torch.Tensor, optional): The intermediate tensor that holds the attention output. Defaults to None.
|
||||||
|
sm_scale (int, optional): Used for flash attention. Defaults to None.
|
||||||
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
@ -185,108 +224,192 @@ def llama_decoder_layer_forward(
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
|
class PadLlamaAttention(LlamaAttention):
|
||||||
@torch.no_grad()
|
def __init__(
|
||||||
def llama_attn_forward(
|
self,
|
||||||
self: LlamaAttention,
|
config: LlamaConfig,
|
||||||
hidden_states: torch.Tensor,
|
layer_idx: Optional[int] = None,
|
||||||
position_ids: torch.LongTensor,
|
attn_qproj_w: torch.nn.Parameter = None,
|
||||||
block_tables: torch.Tensor = None,
|
attn_kproj_w: torch.nn.Parameter = None,
|
||||||
k_cache: torch.Tensor = None,
|
attn_vproj_w: torch.nn.Parameter = None,
|
||||||
v_cache: torch.Tensor = None,
|
attn_oproj_w: torch.nn.Parameter = None,
|
||||||
is_prompts: bool = True,
|
):
|
||||||
sequence_lengths: torch.Tensor = None,
|
"""This layer will replace the LlamaAttention.
|
||||||
attention_mask: torch.Tensor = None,
|
|
||||||
kv_seq_len: int = 0,
|
|
||||||
cos_sin: Tuple[torch.Tensor] = None,
|
|
||||||
fd_inter_tensor: FDIntermTensors = None,
|
|
||||||
output_tensor: torch.Tensor = None,
|
|
||||||
sm_scale: int = None,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
Args:
|
||||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
config (LlamaConfig): Holding the Llama model config.
|
||||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
layer_idx (Optional[int], optional): The decode layer id of this attention layer. Defaults to None.
|
||||||
|
attn_qproj_w (torch.nn.Parameter, optional): The q_proj weight. Defaults to None.
|
||||||
|
attn_kproj_w (torch.nn.Parameter, optional): The k_proj weight. Defaults to None.
|
||||||
|
attn_vproj_w (torch.nn.Parameter, optional): The v_proj weight. Defaults to None.
|
||||||
|
attn_oproj_w (torch.nn.Parameter, optional): The o_proj weight. Defaults to None.
|
||||||
|
"""
|
||||||
|
super().__init__(config, layer_idx)
|
||||||
|
self.q_proj.weight = attn_qproj_w
|
||||||
|
self.k_proj.weight = attn_kproj_w
|
||||||
|
self.v_proj.weight = attn_vproj_w
|
||||||
|
self.o_proj.weight = attn_oproj_w
|
||||||
|
|
||||||
if HAS_TRITON:
|
@staticmethod
|
||||||
if is_prompts:
|
def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
|
||||||
if attention_mask is not None:
|
"""Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention
|
||||||
query_states, key_states, value_states, indices = unpading_input(
|
|
||||||
query_states, key_states, value_states, attention_mask
|
Args:
|
||||||
|
module (LlamaAttention): The origin LlamaAttention layer.
|
||||||
|
"""
|
||||||
|
config = module.config
|
||||||
|
layer_idx = module.layer_idx
|
||||||
|
|
||||||
|
attn_qproj_w = module.q_proj.weight
|
||||||
|
attn_kproj_w = module.k_proj.weight
|
||||||
|
attn_vproj_w = module.v_proj.weight
|
||||||
|
attn_oproj_w = module.o_proj.weight
|
||||||
|
|
||||||
|
attn_layer = PadLlamaAttention(
|
||||||
|
config=config,
|
||||||
|
layer_idx=layer_idx,
|
||||||
|
attn_qproj_w=attn_qproj_w,
|
||||||
|
attn_kproj_w=attn_kproj_w,
|
||||||
|
attn_vproj_w=attn_vproj_w,
|
||||||
|
attn_oproj_w=attn_oproj_w,
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_layer
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_ids: torch.LongTensor,
|
||||||
|
block_tables: torch.Tensor = None,
|
||||||
|
k_cache: torch.Tensor = None,
|
||||||
|
v_cache: torch.Tensor = None,
|
||||||
|
is_prompts: bool = True,
|
||||||
|
sequence_lengths: torch.Tensor = None,
|
||||||
|
attention_mask: torch.Tensor = None,
|
||||||
|
kv_seq_len: int = 0,
|
||||||
|
cos_sin: Tuple[torch.Tensor] = None,
|
||||||
|
fd_inter_tensor: FDIntermTensors = None,
|
||||||
|
output_tensor: torch.Tensor = None,
|
||||||
|
sm_scale: int = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`
|
||||||
|
position_ids (torch.LongTensor): The position ids of input sequences.
|
||||||
|
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
|
||||||
|
storing mapping of token_position_id -> block_id. Defaults to None.
|
||||||
|
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
|
||||||
|
v_cache (torch.Tensor, optional): It holds the GPU memory for the value cache. Defaults to None.
|
||||||
|
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
|
||||||
|
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
|
||||||
|
attention_mask (torch.Tensor, optional): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)`
|
||||||
|
where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens.
|
||||||
|
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
|
||||||
|
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
|
||||||
|
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
|
||||||
|
storing intermediate values in flash-decoding. Defaults to None.
|
||||||
|
output_tensor (torch.Tensor, optional): The intermediate tensor that holds the attention output. Defaults to None.
|
||||||
|
sm_scale (int, optional): Used for flash attention. Defaults to None.
|
||||||
|
"""
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
||||||
|
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
||||||
|
|
||||||
|
if HAS_TRITON:
|
||||||
|
if is_prompts:
|
||||||
|
if attention_mask is not None:
|
||||||
|
query_states, key_states, value_states, indices = unpading_input(
|
||||||
|
query_states, key_states, value_states, attention_mask
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query_states = query_states.view(-1, self.num_heads, self.head_dim)
|
||||||
|
key_states = key_states.view(-1, self.num_heads, self.head_dim)
|
||||||
|
value_states = value_states.view(-1, self.num_heads, self.head_dim)
|
||||||
|
else:
|
||||||
|
query_states = query_states.squeeze(dim=1)
|
||||||
|
key_states = key_states.squeeze(dim=1)
|
||||||
|
value_states = value_states.squeeze(dim=1)
|
||||||
|
|
||||||
|
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||||
|
|
||||||
|
block_size = k_cache.size(-2)
|
||||||
|
|
||||||
|
if is_prompts:
|
||||||
|
attn_output = context_attention_unpadded(
|
||||||
|
q=query_states,
|
||||||
|
k=key_states,
|
||||||
|
v=value_states,
|
||||||
|
k_cache=k_cache,
|
||||||
|
v_cache=v_cache,
|
||||||
|
context_lengths=sequence_lengths,
|
||||||
|
block_tables=block_tables,
|
||||||
|
block_size=block_size,
|
||||||
|
output=output_tensor,
|
||||||
|
max_seq_len=kv_seq_len,
|
||||||
|
sm_scale=sm_scale,
|
||||||
|
)
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_output = pad_input(attn_output, indices, bsz, q_len)
|
||||||
|
else:
|
||||||
|
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||||
|
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||||
|
attn_output = flash_decoding_attention(
|
||||||
|
q=query_states,
|
||||||
|
k_cache=k_cache,
|
||||||
|
v_cache=v_cache,
|
||||||
|
kv_seq_len=sequence_lengths,
|
||||||
|
block_tables=block_tables,
|
||||||
|
block_size=block_size,
|
||||||
|
max_seq_len_in_batch=kv_seq_len,
|
||||||
|
output=output_tensor,
|
||||||
|
mid_output=fd_inter_tensor.mid_output,
|
||||||
|
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
||||||
|
sm_scale=sm_scale,
|
||||||
|
)
|
||||||
|
attn_output = attn_output.squeeze(1)
|
||||||
|
else:
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
if is_prompts:
|
||||||
|
attn_output = PagedAttention.pad_context_forward(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
sequence_lengths,
|
||||||
|
block_tables,
|
||||||
|
attention_mask,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
query_states = query_states.view(-1, self.num_heads, self.head_dim)
|
attn_output = PagedAttention.pad_decoding_forward(
|
||||||
key_states = key_states.view(-1, self.num_heads, self.head_dim)
|
query_states,
|
||||||
value_states = value_states.view(-1, self.num_heads, self.head_dim)
|
key_states,
|
||||||
else:
|
value_states,
|
||||||
query_states = query_states.squeeze(dim=1)
|
k_cache,
|
||||||
key_states = key_states.squeeze(dim=1)
|
v_cache,
|
||||||
value_states = value_states.squeeze(dim=1)
|
sequence_lengths,
|
||||||
|
block_tables,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
block_size = k_cache.size(-2)
|
return attn_output
|
||||||
|
|
||||||
if is_prompts:
|
|
||||||
attn_output = context_attention_unpadded(
|
|
||||||
q=query_states,
|
|
||||||
k=key_states,
|
|
||||||
v=value_states,
|
|
||||||
k_cache=k_cache,
|
|
||||||
v_cache=v_cache,
|
|
||||||
context_lengths=sequence_lengths,
|
|
||||||
block_tables=block_tables,
|
|
||||||
block_size=block_size,
|
|
||||||
output=output_tensor,
|
|
||||||
max_seq_len=kv_seq_len,
|
|
||||||
sm_scale=sm_scale,
|
|
||||||
)
|
|
||||||
if attention_mask is not None:
|
|
||||||
attn_output = pad_input(attn_output, indices, bsz, q_len)
|
|
||||||
else:
|
|
||||||
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
|
||||||
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
|
||||||
attn_output = flash_decoding_attention(
|
|
||||||
q=query_states,
|
|
||||||
k_cache=k_cache,
|
|
||||||
v_cache=v_cache,
|
|
||||||
kv_seq_len=sequence_lengths,
|
|
||||||
block_tables=block_tables,
|
|
||||||
block_size=block_size,
|
|
||||||
max_seq_len_in_batch=kv_seq_len,
|
|
||||||
output=output_tensor,
|
|
||||||
mid_output=fd_inter_tensor.mid_output,
|
|
||||||
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
|
||||||
sm_scale=sm_scale,
|
|
||||||
)
|
|
||||||
attn_output = attn_output.squeeze(1)
|
|
||||||
else:
|
|
||||||
query_states = query_states.transpose(1, 2)
|
|
||||||
key_states = key_states.transpose(1, 2)
|
|
||||||
value_states = value_states.transpose(1, 2)
|
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
||||||
|
|
||||||
query_states = query_states.transpose(1, 2)
|
|
||||||
key_states = key_states.transpose(1, 2)
|
|
||||||
value_states = value_states.transpose(1, 2)
|
|
||||||
|
|
||||||
if is_prompts:
|
|
||||||
attn_output = PagedAttention.pad_context_forward(
|
|
||||||
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
attn_output = PagedAttention.pad_decoding_forward(
|
|
||||||
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|
|
@ -1,25 +1,18 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers.models.llama.modeling_llama import (
|
from torch.nn import Parameter
|
||||||
LlamaAttention,
|
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
|
||||||
LlamaDecoderLayer,
|
|
||||||
LlamaFlashAttention2,
|
|
||||||
LlamaForCausalLM,
|
|
||||||
LlamaMLP,
|
|
||||||
LlamaModel,
|
|
||||||
LlamaRMSNorm,
|
|
||||||
LlamaSdpaAttention,
|
|
||||||
)
|
|
||||||
|
|
||||||
from colossalai.inference.modeling.models.nopadding_llama import (
|
from colossalai.inference.modeling.models.nopadding_llama import (
|
||||||
llama_attn_forward,
|
NopadLlamaAttention,
|
||||||
|
NopadLlamaMLP,
|
||||||
llama_causal_lm_forward,
|
llama_causal_lm_forward,
|
||||||
llama_decoder_layer_forward,
|
llama_decoder_layer_forward,
|
||||||
llama_model_forward,
|
llama_model_forward,
|
||||||
nopad_mlp,
|
|
||||||
)
|
)
|
||||||
from colossalai.inference.utils import init_to_get_rotary
|
from colossalai.inference.utils import init_to_get_rotary
|
||||||
|
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
||||||
|
|
||||||
# import colossalai
|
# import colossalai
|
||||||
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
||||||
|
@ -50,6 +43,27 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||||
|
|
||||||
def module_policy(self):
|
def module_policy(self):
|
||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
|
|
||||||
|
decoder_attribute_replacement = {
|
||||||
|
"lm_head.weight": Parameter(self.model.lm_head.weight.transpose(0, 1), requires_grad=False),
|
||||||
|
}
|
||||||
|
policy[LlamaForCausalLM] = ModulePolicyDescription(
|
||||||
|
attribute_replacement=decoder_attribute_replacement,
|
||||||
|
)
|
||||||
|
|
||||||
|
policy[LlamaDecoderLayer] = ModulePolicyDescription(
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="mlp",
|
||||||
|
target_module=NopadLlamaMLP,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn",
|
||||||
|
target_module=NopadLlamaAttention,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
self.shard_config._infer()
|
self.shard_config._infer()
|
||||||
|
|
||||||
infer_forward = llama_causal_lm_forward
|
infer_forward = llama_causal_lm_forward
|
||||||
|
@ -68,28 +82,6 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||||
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
|
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
|
||||||
)
|
)
|
||||||
|
|
||||||
infer_forward = nopad_mlp
|
|
||||||
method_replacement = {"forward": partial(infer_forward)}
|
|
||||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaMLP)
|
|
||||||
|
|
||||||
infer_forward = llama_attn_forward
|
|
||||||
method_replacement = {"forward": partial(infer_forward)}
|
|
||||||
self.append_or_create_method_replacement(
|
|
||||||
description=method_replacement, policy=policy, target_key=LlamaAttention
|
|
||||||
)
|
|
||||||
|
|
||||||
infer_forward = llama_attn_forward
|
|
||||||
method_replacement = {"forward": partial(infer_forward)}
|
|
||||||
self.append_or_create_method_replacement(
|
|
||||||
description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
|
|
||||||
)
|
|
||||||
|
|
||||||
infer_forward = llama_attn_forward
|
|
||||||
method_replacement = {"forward": partial(infer_forward)}
|
|
||||||
self.append_or_create_method_replacement(
|
|
||||||
description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
|
|
||||||
)
|
|
||||||
|
|
||||||
infer_forward = None
|
infer_forward = None
|
||||||
if HAS_TRITON_RMSNORM:
|
if HAS_TRITON_RMSNORM:
|
||||||
infer_forward = get_triton_rmsnorm_forward()
|
infer_forward = get_triton_rmsnorm_forward()
|
||||||
|
|
|
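The no-padding policy above stops monkey-patching individual forward functions (LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention, LlamaMLP) and instead swaps whole submodules of each decoder layer. The sketch below restates that pattern with the same ShardFormer descriptions used in this diff; it is illustrative only and assumes the imports shown above are available:

from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaAttention, NopadLlamaMLP
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription,
    SubModuleReplacementDescription,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def decoder_layer_replacement() -> dict:
    # Each target_module implements from_native_module(), which builds the
    # replacement layer from the original HF submodule and pre-transposes
    # (and, for attention, stacks and fuses) its projection weights.
    return {
        LlamaDecoderLayer: ModulePolicyDescription(
            sub_module_replacement=[
                SubModuleReplacementDescription(suffix="mlp", target_module=NopadLlamaMLP),
                SubModuleReplacementDescription(suffix="self_attn", target_module=NopadLlamaAttention),
            ]
        )
    }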
@ -1,18 +1,10 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
|
||||||
LlamaAttention,
|
|
||||||
LlamaDecoderLayer,
|
|
||||||
LlamaFlashAttention2,
|
|
||||||
LlamaForCausalLM,
|
|
||||||
LlamaModel,
|
|
||||||
LlamaRMSNorm,
|
|
||||||
LlamaSdpaAttention,
|
|
||||||
)
|
|
||||||
|
|
||||||
from colossalai.inference.modeling.models.padding_llama import (
|
from colossalai.inference.modeling.models.padding_llama import (
|
||||||
llama_attn_forward,
|
PadLlamaAttention,
|
||||||
llama_causal_lm_forward,
|
llama_causal_lm_forward,
|
||||||
llama_decoder_layer_forward,
|
llama_decoder_layer_forward,
|
||||||
llama_model_forward,
|
llama_model_forward,
|
||||||
|
@ -49,105 +41,16 @@ class PaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||||
|
|
||||||
def module_policy(self):
|
def module_policy(self):
|
||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
decoder_attribute_replacement = {
|
|
||||||
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
|
||||||
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
|
||||||
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
|
|
||||||
// self.shard_config.tensor_parallel_size,
|
|
||||||
}
|
|
||||||
if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
|
|
||||||
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
|
|
||||||
|
|
||||||
policy[LlamaDecoderLayer] = ModulePolicyDescription(
|
policy[LlamaDecoderLayer] = ModulePolicyDescription(
|
||||||
attribute_replacement=decoder_attribute_replacement,
|
sub_module_replacement=[
|
||||||
sub_module_replacement=[
|
SubModuleReplacementDescription(
|
||||||
SubModuleReplacementDescription(
|
suffix="self_attn",
|
||||||
suffix="self_attn.q_proj",
|
target_module=PadLlamaAttention,
|
||||||
target_module=ColCaiQuantLinear,
|
),
|
||||||
kwargs={"split_num": 1},
|
]
|
||||||
),
|
)
|
||||||
SubModuleReplacementDescription(
|
|
||||||
suffix="self_attn.k_proj",
|
|
||||||
target_module=ColCaiQuantLinear,
|
|
||||||
kwargs={"split_num": 1},
|
|
||||||
),
|
|
||||||
SubModuleReplacementDescription(
|
|
||||||
suffix="self_attn.v_proj",
|
|
||||||
target_module=ColCaiQuantLinear,
|
|
||||||
kwargs={"split_num": 1},
|
|
||||||
),
|
|
||||||
SubModuleReplacementDescription(
|
|
||||||
suffix="self_attn.o_proj",
|
|
||||||
target_module=RowCaiQuantLinear,
|
|
||||||
kwargs={"split_num": 1},
|
|
||||||
),
|
|
||||||
SubModuleReplacementDescription(
|
|
||||||
suffix="mlp.gate_proj",
|
|
||||||
target_module=ColCaiQuantLinear,
|
|
||||||
kwargs={"split_num": 1},
|
|
||||||
),
|
|
||||||
SubModuleReplacementDescription(
|
|
||||||
suffix="mlp.up_proj",
|
|
||||||
target_module=ColCaiQuantLinear,
|
|
||||||
kwargs={"split_num": 1},
|
|
||||||
),
|
|
||||||
SubModuleReplacementDescription(
|
|
||||||
suffix="mlp.down_proj",
|
|
||||||
target_module=RowCaiQuantLinear,
|
|
||||||
kwargs={"split_num": 1},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
|
|
||||||
from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
|
|
||||||
from colossalai.inference.quant.smoothquant.models.parallel_linear import (
|
|
||||||
ColW8A8BFP32OFP32Linear,
|
|
||||||
RowW8A8B8O8Linear,
|
|
||||||
RowW8A8BFP32O32LinearSiLU,
|
|
||||||
RowW8A8BFP32OFP32Linear,
|
|
||||||
)
|
|
||||||
|
|
||||||
policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
|
|
||||||
attribute_replacement=decoder_attribute_replacement,
|
|
||||||
sub_module_replacement=[
|
|
||||||
SubModuleReplacementDescription(
|
|
||||||
suffix="self_attn.q_proj",
|
|
||||||
target_module=RowW8A8B8O8Linear,
|
|
||||||
kwargs={"split_num": 1},
|
|
||||||
),
|
|
||||||
SubModuleReplacementDescription(
|
|
||||||
suffix="self_attn.k_proj",
|
|
||||||
target_module=RowW8A8B8O8Linear,
|
|
||||||
kwargs={"split_num": 1},
|
|
||||||
),
|
|
||||||
SubModuleReplacementDescription(
|
|
||||||
suffix="self_attn.v_proj",
|
|
||||||
target_module=RowW8A8B8O8Linear,
|
|
||||||
kwargs={"split_num": 1},
|
|
||||||
),
|
|
||||||
SubModuleReplacementDescription(
|
|
||||||
suffix="self_attn.o_proj",
|
|
||||||
target_module=ColW8A8BFP32OFP32Linear,
|
|
||||||
kwargs={"split_num": 1},
|
|
||||||
),
|
|
||||||
SubModuleReplacementDescription(
|
|
||||||
suffix="mlp.gate_proj",
|
|
||||||
target_module=RowW8A8BFP32O32LinearSiLU,
|
|
||||||
kwargs={"split_num": 1},
|
|
||||||
),
|
|
||||||
SubModuleReplacementDescription(
|
|
||||||
suffix="mlp.up_proj",
|
|
||||||
target_module=RowW8A8BFP32OFP32Linear,
|
|
||||||
kwargs={"split_num": 1},
|
|
||||||
),
|
|
||||||
SubModuleReplacementDescription(
|
|
||||||
suffix="mlp.down_proj",
|
|
||||||
target_module=ColW8A8BFP32OFP32Linear,
|
|
||||||
kwargs={"split_num": 1},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
self.shard_config._infer()
|
self.shard_config._infer()
|
||||||
|
|
||||||
infer_forward = llama_causal_lm_forward
|
infer_forward = llama_causal_lm_forward
|
||||||
|
@ -166,24 +69,6 @@ class PaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||||
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
|
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
|
||||||
)
|
)
|
||||||
|
|
||||||
infer_forward = llama_attn_forward
|
|
||||||
method_replacement = {"forward": partial(infer_forward)}
|
|
||||||
self.append_or_create_method_replacement(
|
|
||||||
description=method_replacement, policy=policy, target_key=LlamaAttention
|
|
||||||
)
|
|
||||||
|
|
||||||
infer_forward = llama_attn_forward
|
|
||||||
method_replacement = {"forward": partial(infer_forward)}
|
|
||||||
self.append_or_create_method_replacement(
|
|
||||||
description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
|
|
||||||
)
|
|
||||||
|
|
||||||
infer_forward = llama_attn_forward
|
|
||||||
method_replacement = {"forward": partial(infer_forward)}
|
|
||||||
self.append_or_create_method_replacement(
|
|
||||||
description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
|
|
||||||
)
|
|
||||||
|
|
||||||
infer_forward = None
|
infer_forward = None
|
||||||
if HAS_TRITON_RMSNORM:
|
if HAS_TRITON_RMSNORM:
|
||||||
infer_forward = get_triton_rmsnorm_forward()
|
infer_forward = get_triton_rmsnorm_forward()
|
||||||
|
|
|
@ -143,8 +143,7 @@ def _flash_decoding_fwd_reduce_kernel(
|
||||||
stride_o_lset,
|
stride_o_lset,
|
||||||
stride_o_lseh,
|
stride_o_lseh,
|
||||||
stride_o_lseb,
|
stride_o_lseb,
|
||||||
stride_ob,
|
stride_ot,
|
||||||
stride_ol,
|
|
||||||
stride_oh,
|
stride_oh,
|
||||||
stride_od,
|
stride_od,
|
||||||
BLOCK_KV: tl.constexpr,
|
BLOCK_KV: tl.constexpr,
|
||||||
|
@ -180,7 +179,7 @@ def _flash_decoding_fwd_reduce_kernel(
|
||||||
m_i = m_ij
|
m_i = m_ij
|
||||||
|
|
||||||
acc = acc / l
|
acc = acc / l
|
||||||
offsets_O = cur_seq_idx * stride_ob + cur_head_idx * stride_oh + offsets_dmodel
|
offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
|
||||||
tl.store(O + offsets_O, acc.to(O.type.element_ty))
|
tl.store(O + offsets_O, acc.to(O.type.element_ty))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -212,7 +211,7 @@ def flash_decoding_attention(
|
||||||
records the (kv) sequence lengths incorporating past kv sequence lengths.
|
records the (kv) sequence lengths incorporating past kv sequence lengths.
|
||||||
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
|
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
|
||||||
max_seq_len_in_batch (int): Maximum sequence length in the batch.
|
max_seq_len_in_batch (int): Maximum sequence length in the batch.
|
||||||
output (torch.Tensor): [bsz, 1, num_heads, head_dim]
|
output (torch.Tensor): [bsz, num_heads, head_dim]
|
||||||
mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
|
mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
|
||||||
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
|
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
|
||||||
mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
|
mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
|
||||||
|
@ -294,7 +293,7 @@ def flash_decoding_attention(
|
||||||
HEAD_DIM=head_dim,
|
HEAD_DIM=head_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
|
output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
|
||||||
|
|
||||||
grid = (triton.next_power_of_2(bsz), num_heads)
|
grid = (triton.next_power_of_2(bsz), num_heads)
|
||||||
|
|
||||||
|
@ -314,7 +313,6 @@ def flash_decoding_attention(
|
||||||
output.stride(0),
|
output.stride(0),
|
||||||
output.stride(1),
|
output.stride(1),
|
||||||
output.stride(2),
|
output.stride(2),
|
||||||
output.stride(3),
|
|
||||||
BLOCK_KV=block_size,
|
BLOCK_KV=block_size,
|
||||||
HEAD_DIM=head_dim,
|
HEAD_DIM=head_dim,
|
||||||
)
|
)
|
||||||
|
|
|
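Since the decode-stage kernels no longer keep the singleton sequence dimension (the rm squeeze(1) item in the commit message), callers now allocate the flash-decoding output as a 3D tensor. A minimal illustration of the changed allocation, with arbitrary shapes:

import torch

bsz, num_heads, head_dim = 16, 32, 128

# Previous layout, carrying a sequence dimension of length 1 that was squeezed away later:
# output = torch.empty((bsz, 1, num_heads, head_dim), dtype=torch.float16)

# Layout after this patch, matching what flash_decoding_attention writes into:
output = torch.empty((bsz, num_heads, head_dim), dtype=torch.float16)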
@ -25,10 +25,20 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
|
||||||
# benchmark llama2-7b one single GPU
|
# benchmark llama2-7b one single GPU
|
||||||
|
|
||||||
for bsz in 16 32 64; do
|
for bsz in 16 32 64; do
|
||||||
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt
|
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_512_256.txt
|
||||||
done
|
done
|
||||||
|
|
||||||
|
|
||||||
for bsz in 16 32 64; do
|
for bsz in 16 32 64; do
|
||||||
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt
|
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_256.txt
|
||||||
|
done
|
||||||
|
|
||||||
|
|
||||||
|
for bsz in 16 32 64; do
|
||||||
|
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256_128.txt
|
||||||
|
done
|
||||||
|
|
||||||
|
|
||||||
|
for bsz in 16 32 64; do
|
||||||
|
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_128.txt
|
||||||
done
|
done
|
||||||
|
|
|
@ -69,6 +69,7 @@ def torch_attn_ref(
|
||||||
f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is" f" {out.size()}"
|
f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is" f" {out.size()}"
|
||||||
)
|
)
|
||||||
out = out.transpose(1, 2).contiguous()
|
out = out.transpose(1, 2).contiguous()
|
||||||
|
out = out.squeeze(1)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -94,7 +94,7 @@ def test_flash_decoding(
|
||||||
max_seq_len_in_b = kv_seq_lengths.max().item()
|
max_seq_len_in_b = kv_seq_lengths.max().item()
|
||||||
# The maximum block length splitted on kv should be the kv cache block size
|
# The maximum block length splitted on kv should be the kv cache block size
|
||||||
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
|
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
|
||||||
output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
|
output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
|
||||||
mid_output = torch.empty(
|
mid_output = torch.empty(
|
||||||
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
|
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
|
||||||
)
|
)
|
||||||
|
@ -189,7 +189,7 @@ def bench_kernel(
|
||||||
block_tables = block_tables.to(device=device)
|
block_tables = block_tables.to(device=device)
|
||||||
# the maximum block length splitted on kv should be the kv cache block size
|
# the maximum block length splitted on kv should be the kv cache block size
|
||||||
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
|
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
|
||||||
output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
|
output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
|
||||||
mid_output = torch.empty(
|
mid_output = torch.empty(
|
||||||
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
|
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
|
||||||
)
|
)
|
||||||
|
|