[Inference]Repalce Attention layer and MLP layer by shardformer to optimize the weight transpose operation,add fused_qkv and fused linear_add (#5340)

* add fused qkv

* replace attn and mlp by shardformer

* fix bugs in mlp

* add docstrings

* fix test_inference_engine.py

* add optimize unbind

* add fused_addmm

* rm squeeze(1)

* refactor codes

* fix ci bugs

* rename ShardFormerLlamaMLP and ShardFormerLlamaAttention

* Removed the dependency on LlamaFlashAttention2

* rollback test_inference_engine.py
pull/5349/head
yuehuayingxueluo 2024-02-01 15:49:39 +08:00 committed by GitHub
parent f8e456d202
commit 249644c23b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 510 additions and 341 deletions

View File

@ -2,8 +2,10 @@
from typing import List, Optional, Tuple
import torch
from torch.nn import Parameter
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaMLP,
@ -39,6 +41,14 @@ def llama_causal_lm_forward(
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
):
"""This function will replace the forward function of LlamaForCausalLM.
Args:
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
"""
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = llama_model_forward(
self.model,
@ -46,7 +56,7 @@ def llama_causal_lm_forward(
k_caches=k_caches,
v_caches=v_caches,
)
logits = torch.mm(hidden_states, self.lm_head.weight.transpose(0, 1))
logits = torch.mm(hidden_states, self.lm_head.weight)
return logits
@ -57,6 +67,13 @@ def llama_model_forward(
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
):
"""This function will replace the forward function of LlamaModel.
Args:
batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None.
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
"""
input_ids = batch.get_1D_inputs()
block_tables = batch.get_block_table_tensor()
@ -74,7 +91,7 @@ def llama_model_forward(
)
else:
output_tensor = torch.zeros(
(batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
(batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
sm_scale = 1.0 / (batch.head_dim**0.5)
@ -116,12 +133,30 @@ def llama_decoder_layer_forward(
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""This function will replace the forward function of LlamaDecoderLayer.
Args:
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
storing intermediate values in flash-decoding. Defaults to None.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
residual=residual,
block_tables=block_tables,
k_cache=k_cache,
v_cache=v_cache,
@ -134,88 +169,213 @@ def llama_decoder_layer_forward(
sm_scale=sm_scale,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
hidden_states = self.mlp(hidden_states, residual)
return hidden_states
# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
@torch.no_grad()
def llama_attn_forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
is_prompts: bool = True,
sequence_lengths: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
query_states = torch.mm(hidden_states, self.q_proj.weight.transpose(0, 1)).view(-1, self.num_heads, self.head_dim)
key_states = torch.mm(hidden_states, self.k_proj.weight.transpose(0, 1)).view(
-1, self.num_key_value_heads, self.head_dim
)
value_states = torch.mm(hidden_states, self.v_proj.weight.transpose(0, 1)).view(
-1, self.num_key_value_heads, self.head_dim
)
class NopadLlamaAttention(LlamaAttention):
def __init__(
self,
config: LlamaConfig,
layer_idx: Optional[int] = None,
attn_qproj_w: torch.Tensor = None,
attn_kproj_w: torch.Tensor = None,
attn_vproj_w: torch.Tensor = None,
attn_oproj_w: torch.Tensor = None,
):
"""This layer will replace the LlamaAttention.
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
Args:
config (LlamaConfig): Holding the Llama model config.
layer_idx (Optional[int], optional): The decode layer id of this attention layer. Defaults to None.
attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
"""
super().__init__(config, layer_idx)
self.q_proj.weight = Parameter(attn_qproj_w, requires_grad=False)
self.k_proj.weight = Parameter(attn_kproj_w, requires_grad=False)
self.v_proj.weight = Parameter(attn_vproj_w, requires_grad=False)
self.o_proj.weight = Parameter(attn_oproj_w, requires_grad=False)
if self.num_heads == self.num_key_value_heads:
qkv_weight_list = [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]
self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
self.q_proj = None
self.k_proj = None
self.v_proj = None
block_size = k_cache.size(-2)
@staticmethod
def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
"""Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention.
if is_prompts:
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
Args:
module (LlamaAttention): The origin LlamaAttention layer.
"""
config = module.config
layer_idx = module.layer_idx
attn_qproj_w = module.q_proj.weight.transpose(0, 1)
attn_kproj_w = module.k_proj.weight.transpose(0, 1)
attn_vproj_w = module.v_proj.weight.transpose(0, 1)
attn_oproj_w = module.o_proj.weight.transpose(0, 1)
attn_layer = NopadLlamaAttention(
config=config,
layer_idx=layer_idx,
attn_qproj_w=attn_qproj_w,
attn_kproj_w=attn_kproj_w,
attn_vproj_w=attn_vproj_w,
attn_oproj_w=attn_oproj_w,
)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
return attn_layer
# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
@torch.no_grad()
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
is_prompts: bool = True,
sequence_lengths: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`
residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in out_proj.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
storing intermediate values in flash-decoding. Defaults to None.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
"""
if self.num_heads != self.num_key_value_heads:
query_states = torch.mm(hidden_states, self.q_proj.weight).view(-1, self.num_heads, self.head_dim)
key_states = torch.mm(hidden_states, self.k_proj.weight).view(-1, self.num_key_value_heads, self.head_dim)
value_states = torch.mm(hidden_states, self.v_proj.weight).view(-1, self.num_key_value_heads, self.head_dim)
else:
# fused qkv
token_nums = hidden_states.size(0)
hidden_states = hidden_states.expand(3, -1, -1)
query_states, key_states, value_states = (
torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
)
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
block_size = k_cache.size(-2)
if is_prompts:
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
)
attn_output = attn_output.reshape(-1, self.hidden_size)
attn_output = torch.addmm(residual, attn_output, self.o_proj.weight)
return attn_output
# NOTE This will cause the result to be different from the transformer in some cases.
class NopadLlamaMLP(LlamaMLP):
def __init__(
self,
config: LlamaConfig,
mlp_gproj_w: torch.Tensor = None,
mlp_uproj_w: torch.Tensor = None,
mlp_dproj_w: torch.Tensor = None,
):
"""This layer will replace the LlamaAttention.
Args:
config (LlamaConfig): Holding the Llama model config.
mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
"""
super().__init__(config)
self.gate_proj.weight = Parameter(mlp_gproj_w, requires_grad=False)
self.up_proj.weight = Parameter(mlp_uproj_w, requires_grad=False)
self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False)
@staticmethod
def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
"""Used for initialize the weight of NopadLlamaMLP by origin LlamaMLP.
Args:
module (LlamaMLP): The origin LlamaMLP layer.
"""
config = module.config
mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
mlp_layer = NopadLlamaMLP(
config=config,
mlp_gproj_w=mlp_gproj_w,
mlp_uproj_w=mlp_uproj_w,
mlp_dproj_w=mlp_dproj_w,
)
attn_output = attn_output.squeeze(1)
attn_output = attn_output.view(-1, self.num_heads, self.head_dim)
attn_output = attn_output.reshape(-1, self.hidden_size)
attn_output = torch.mm(attn_output, self.o_proj.weight.transpose(0, 1))
return mlp_layer
return attn_output
@torch.no_grad()
def nopad_mlp(self: LlamaMLP, hidden_states: torch.Tensor):
gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight.transpose(0, 1))
act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
up_proj_out = torch.mm(hidden_states, self.up_proj.weight.transpose(0, 1))
tmp_out = act_out * up_proj_out
return torch.mm(tmp_out, self.down_proj.weight.transpose(0, 1))
@torch.no_grad()
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in down_proj.
"""
gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight)
act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
up_proj_out = torch.mm(hidden_states, self.up_proj.weight)
tmp_out = act_out * up_proj_out
return torch.addmm(residual, tmp_out, self.down_proj.weight)

View File

@ -2,7 +2,13 @@
from typing import List, Optional, Tuple
import torch
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
)
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.layers.attention import PagedAttention
@ -53,6 +59,14 @@ def llama_causal_lm_forward(
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
):
"""This function will replace the forward function of LlamaForCausalLM.
Args:
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
"""
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = llama_model_forward(
self.model,
@ -71,6 +85,13 @@ def llama_model_forward(
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
):
"""This function will replace the forward function of LlamaModel.
Args:
batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None.
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
"""
input_ids = batch.get_batch_inputs()
block_tables = batch.get_block_table_tensor()
attention_mask = batch.get_attn_mask()
@ -110,7 +131,7 @@ def llama_model_forward(
)
else:
output_tensor = torch.zeros(
(batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
(batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
sm_scale = 1.0 / (batch.head_dim**0.5)
@ -131,7 +152,8 @@ def llama_model_forward(
sm_scale=sm_scale,
)
hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
if batch.is_prompts:
hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
hidden_states = self.norm(hidden_states)
return hidden_states
@ -154,6 +176,23 @@ def llama_decoder_layer_forward(
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""This function will replace the forward function of LlamaDecoderLayer.
Args:
hidden_states (torch.Tensor): _description_
position_ids (torch.LongTensor), The position ids of input sequences.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@ -185,108 +224,192 @@ def llama_decoder_layer_forward(
return hidden_states
# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
@torch.no_grad()
def llama_attn_forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
position_ids: torch.LongTensor,
block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
is_prompts: bool = True,
sequence_lengths: torch.Tensor = None,
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
class PadLlamaAttention(LlamaAttention):
def __init__(
self,
config: LlamaConfig,
layer_idx: Optional[int] = None,
attn_qproj_w: torch.nn.Parameter = None,
attn_kproj_w: torch.nn.Parameter = None,
attn_vproj_w: torch.nn.Parameter = None,
attn_oproj_w: torch.nn.Parameter = None,
):
"""This layer will replace the LlamaAttention.
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
Args:
config (LlamaConfig): Holding the Llama model config.
layer_idx (Optional[int], optional): The decode layer id of this attention layer. Defaults to None.
attn_qproj_w (torch.nn.Parameter, optional): The q_proj weight. Defaults to None.
attn_kproj_w (torch.nn.Parameter, optional): The k_proj weight. Defaults to None.
attn_vproj_w (torch.nn.Parameter, optional): The v_proj weight. Defaults to None.
attn_oproj_w (torch.nn.Parameter, optional): The o_proj weight. Defaults to None.
"""
super().__init__(config, layer_idx)
self.q_proj.weight = attn_qproj_w
self.k_proj.weight = attn_kproj_w
self.v_proj.weight = attn_vproj_w
self.o_proj.weight = attn_oproj_w
if HAS_TRITON:
if is_prompts:
if attention_mask is not None:
query_states, key_states, value_states, indices = unpading_input(
query_states, key_states, value_states, attention_mask
@staticmethod
def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
"""Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention
Args:
module (LlamaAttention): The origin LlamaAttention layer.
"""
config = module.config
layer_idx = module.layer_idx
attn_qproj_w = module.q_proj.weight
attn_kproj_w = module.k_proj.weight
attn_vproj_w = module.v_proj.weight
attn_oproj_w = module.o_proj.weight
attn_layer = PadLlamaAttention(
config=config,
layer_idx=layer_idx,
attn_qproj_w=attn_qproj_w,
attn_kproj_w=attn_kproj_w,
attn_vproj_w=attn_vproj_w,
attn_oproj_w=attn_oproj_w,
)
return attn_layer
@torch.no_grad()
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.LongTensor,
block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
is_prompts: bool = True,
sequence_lengths: torch.Tensor = None,
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`
position_ids (torch.LongTensor), The position ids of input sequences.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
attention_mask (torch.Tensor, optional): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)`
where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
storing intermediate values in flash-decoding. Defaults to None.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
"""
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
if HAS_TRITON:
if is_prompts:
if attention_mask is not None:
query_states, key_states, value_states, indices = unpading_input(
query_states, key_states, value_states, attention_mask
)
else:
query_states = query_states.view(-1, self.num_heads, self.head_dim)
key_states = key_states.view(-1, self.num_heads, self.head_dim)
value_states = value_states.view(-1, self.num_heads, self.head_dim)
else:
query_states = query_states.squeeze(dim=1)
key_states = key_states.squeeze(dim=1)
value_states = value_states.squeeze(dim=1)
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
block_size = k_cache.size(-2)
if is_prompts:
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
if attention_mask is not None:
attn_output = pad_input(attn_output, indices, bsz, q_len)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
)
attn_output = attn_output.squeeze(1)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if is_prompts:
attn_output = PagedAttention.pad_context_forward(
query_states,
key_states,
value_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
attention_mask,
)
else:
query_states = query_states.view(-1, self.num_heads, self.head_dim)
key_states = key_states.view(-1, self.num_heads, self.head_dim)
value_states = value_states.view(-1, self.num_heads, self.head_dim)
else:
query_states = query_states.squeeze(dim=1)
key_states = key_states.squeeze(dim=1)
value_states = value_states.squeeze(dim=1)
attn_output = PagedAttention.pad_decoding_forward(
query_states,
key_states,
value_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
attention_mask,
)
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
block_size = k_cache.size(-2)
if is_prompts:
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
if attention_mask is not None:
attn_output = pad_input(attn_output, indices, bsz, q_len)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
)
attn_output = attn_output.squeeze(1)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if is_prompts:
attn_output = PagedAttention.pad_context_forward(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
)
else:
attn_output = PagedAttention.pad_decoding_forward(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
)
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output
return attn_output
@torch.no_grad()

View File

@ -1,25 +1,18 @@
from functools import partial
import torch
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaFlashAttention2,
LlamaForCausalLM,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
LlamaSdpaAttention,
)
from torch.nn import Parameter
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
from colossalai.inference.modeling.models.nopadding_llama import (
llama_attn_forward,
NopadLlamaAttention,
NopadLlamaMLP,
llama_causal_lm_forward,
llama_decoder_layer_forward,
llama_model_forward,
nopad_mlp,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
@ -50,6 +43,27 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
def module_policy(self):
policy = super().module_policy()
decoder_attribute_replacement = {
"lm_head.weight": Parameter(self.model.lm_head.weight.transpose(0, 1), requires_grad=False),
}
policy[LlamaForCausalLM] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
policy[LlamaDecoderLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="mlp",
target_module=NopadLlamaMLP,
),
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadLlamaAttention,
),
]
)
self.shard_config._infer()
infer_forward = llama_causal_lm_forward
@ -68,28 +82,6 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
)
infer_forward = nopad_mlp
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaMLP)
infer_forward = llama_attn_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaAttention
)
infer_forward = llama_attn_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
)
infer_forward = llama_attn_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
)
infer_forward = None
if HAS_TRITON_RMSNORM:
infer_forward = get_triton_rmsnorm_forward()

View File

@ -1,18 +1,10 @@
from functools import partial
import torch
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaFlashAttention2,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
LlamaSdpaAttention,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
from colossalai.inference.modeling.models.padding_llama import (
llama_attn_forward,
PadLlamaAttention,
llama_causal_lm_forward,
llama_decoder_layer_forward,
llama_model_forward,
@ -49,105 +41,16 @@ class PaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
def module_policy(self):
policy = super().module_policy()
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
}
if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=RowCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=RowCaiQuantLinear,
kwargs={"split_num": 1},
),
],
)
policy[LlamaDecoderLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn",
target_module=PadLlamaAttention,
),
]
)
elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
from colossalai.inference.quant.smoothquant.models.parallel_linear import (
ColW8A8BFP32OFP32Linear,
RowW8A8B8O8Linear,
RowW8A8BFP32O32LinearSiLU,
RowW8A8BFP32OFP32Linear,
)
policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=RowW8A8BFP32O32LinearSiLU,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=RowW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
],
)
self.shard_config._infer()
infer_forward = llama_causal_lm_forward
@ -166,24 +69,6 @@ class PaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
)
infer_forward = llama_attn_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaAttention
)
infer_forward = llama_attn_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
)
infer_forward = llama_attn_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
)
infer_forward = None
if HAS_TRITON_RMSNORM:
infer_forward = get_triton_rmsnorm_forward()

View File

@ -143,8 +143,7 @@ def _flash_decoding_fwd_reduce_kernel(
stride_o_lset,
stride_o_lseh,
stride_o_lseb,
stride_ob,
stride_ol,
stride_ot,
stride_oh,
stride_od,
BLOCK_KV: tl.constexpr,
@ -180,7 +179,7 @@ def _flash_decoding_fwd_reduce_kernel(
m_i = m_ij
acc = acc / l
offsets_O = cur_seq_idx * stride_ob + cur_head_idx * stride_oh + offsets_dmodel
offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
tl.store(O + offsets_O, acc.to(O.type.element_ty))
return
@ -212,7 +211,7 @@ def flash_decoding_attention(
records the (kv) sequence lengths incorporating past kv sequence lengths.
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
max_seq_len_in_batch (int): Maximum sequence length in the batch.
output (torch.Tensor): [bsz, 1, num_heads, head_dim]
output (torch.Tensor): [bsz, num_heads, head_dim]
mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
@ -294,7 +293,7 @@ def flash_decoding_attention(
HEAD_DIM=head_dim,
)
output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
grid = (triton.next_power_of_2(bsz), num_heads)
@ -314,7 +313,6 @@ def flash_decoding_attention(
output.stride(0),
output.stride(1),
output.stride(2),
output.stride(3),
BLOCK_KV=block_size,
HEAD_DIM=head_dim,
)

View File

@ -25,10 +25,20 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
# benchmark llama2-7b one single GPU
for bsz in 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_512_256.txt
done
for bsz in 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_256.txt
done
for bsz in 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256_128.txt
done
for bsz in 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_128.txt
done

View File

@ -69,6 +69,7 @@ def torch_attn_ref(
f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is" f" {out.size()}"
)
out = out.transpose(1, 2).contiguous()
out = out.squeeze(1)
return out

View File

@ -94,7 +94,7 @@ def test_flash_decoding(
max_seq_len_in_b = kv_seq_lengths.max().item()
# The maximum block length splitted on kv should be the kv cache block size
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
mid_output = torch.empty(
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
)
@ -189,7 +189,7 @@ def bench_kernel(
block_tables = block_tables.to(device=device)
# the maximum block length splitted on kv should be the kv cache block size
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
mid_output = torch.empty(
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
)