# This code is adapted from the huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
from typing import Optional, Tuple

import torch
import torch.nn as nn

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaAttention
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.logging import get_dist_logger

inference_ops = InferenceOpsLoader().load()

logger = get_dist_logger(__name__)


class NopadBaichuanAttention(nn.Module):
    def __init__(
        self,
        config,
        attn_qproj_w: torch.Tensor = None,
        attn_kproj_w: torch.Tensor = None,
        attn_vproj_w: torch.Tensor = None,
        attn_oproj_w: torch.Tensor = None,
    ):
        """This layer will replace the BaichuanAttention.

        Args:
            config (BaichuanConfig): Holding the Baichuan model config.
            attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
            attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
            attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
            attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
        """
        super().__init__()
        self.o_proj_weight = attn_oproj_w

        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads

        # Used to adapt llama_base_attn_forward
        self.num_key_value_heads = self.num_heads

        qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
        self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
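        # qkv_weight has shape [3, hidden_size, hidden_size]: the transposed q/k/v projection
        # weights stacked along a new leading dim so the shared NopadLlamaAttention.forward can
        # apply all three projections together.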

    @staticmethod
    def from_native_module(module: nn.Module, *args, **kwargs) -> "NopadBaichuanAttention":
        """Used to initialize the weights of NopadBaichuanAttention from the origin BaichuanAttention.

        Args:
            module (nn.Module): The origin BaichuanAttention layer.
        """

        config = module.config

        # W_pack is the fused QKV projection of BaichuanAttention; its weight is viewed as
        # three [hidden_size, hidden_size] blocks corresponding to q_proj, k_proj, and v_proj.
        q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((3, module.hidden_size, module.hidden_size))

        attn_qproj_w = q_proj_w.transpose(0, 1)
        attn_kproj_w = k_proj_w.transpose(0, 1)
        attn_vproj_w = v_proj_w.transpose(0, 1)
        attn_oproj_w = module.o_proj.weight.transpose(0, 1)

        attn_layer = NopadBaichuanAttention(
            config=config,
            attn_qproj_w=attn_qproj_w,
            attn_kproj_w=attn_kproj_w,
            attn_vproj_w=attn_vproj_w,
            attn_oproj_w=attn_oproj_w,
        )

        return attn_layer

    def forward(
        self,
        hidden_states: torch.Tensor,
        block_tables: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        sequence_lengths: torch.Tensor,
        cos_sin: Tuple[torch.Tensor],
        fd_inter_tensor: FDIntermTensors,
        is_prompts: bool = True,
        is_verifier: bool = False,
        tokens_to_verify: int = None,
        kv_seq_len: int = 0,
        output_tensor: torch.Tensor = None,
        sm_scale: int = None,
        use_cuda_kernel: bool = True,
        cu_seqlens: torch.Tensor = None,
        high_precision: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Args:
            hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
            block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
                storing mapping of token_position_id -> block_id.
            k_cache (torch.Tensor): It holds the GPU memory for the key cache.
            v_cache (torch.Tensor): It holds the GPU memory for the value cache.
            sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence.
            cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin.
            fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
                storing intermediate values in flash-decoding.
            is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
            kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
            output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
            sm_scale (int, optional): Used for flash attention. Defaults to None.
            use_cuda_kernel (bool, optional): Whether to use the cuda kernel. Defaults to True.
            cu_seqlens (torch.Tensor, optional): Holding the cumulative sum of sequence length.
            high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision. Defaults to False.
        """
        return NopadLlamaAttention.forward(
            self,
            hidden_states=hidden_states,
            block_tables=block_tables,
            k_cache=k_cache,
            v_cache=v_cache,
            sequence_lengths=sequence_lengths,
            cos_sin=cos_sin,
            fd_inter_tensor=fd_inter_tensor,
            is_prompts=is_prompts,
            is_verifier=is_verifier,
            tokens_to_verify=tokens_to_verify,
            kv_seq_len=kv_seq_len,
            output_tensor=output_tensor,
            sm_scale=sm_scale,
            use_cuda_kernel=use_cuda_kernel,
            cu_seqlens=cu_seqlens,
            high_precision=high_precision,
        )
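

# A minimal usage sketch (illustrative only, not part of the original file): each native
# BaichuanAttention module is expected to be swapped for this layer via `from_native_module`,
# along the lines of
#
#     for layer in baichuan_model.model.layers:   # hypothetical traversal of the HF model
#         layer.self_attn = NopadBaichuanAttention.from_native_module(layer.self_attn)
#
# In practice the replacement is driven by ColossalAI's inference policy/shardformer machinery
# rather than a manual loop like this.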


# NOTE This will cause differences as the output length increases.
class NopadBaichuanMLP(nn.Module):
    def __init__(
        self,
        mlp_gproj_w: torch.Tensor = None,
        mlp_uproj_w: torch.Tensor = None,
        mlp_dproj_w: torch.Tensor = None,
    ):
        """This layer will replace the MLP (Baichuan) layer.

        Args:
            mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
            mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
            mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
        """
        super().__init__()
        # Stack the transposed gate_proj and up_proj weights so both projections can be
        # computed with a single batched matmul in forward.
        self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
        self.down_proj_weight = mlp_dproj_w

    @staticmethod
    def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
        """Used to initialize the weights of NopadBaichuanMLP from the origin MLP (Baichuan).

        Args:
            module (nn.Module): The origin MLP (Baichuan) layer.
        """

        mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
        mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
        mlp_dproj_w = module.down_proj.weight.transpose(0, 1)

        mlp_layer = NopadBaichuanMLP(
            mlp_gproj_w=mlp_gproj_w,
            mlp_uproj_w=mlp_uproj_w,
            mlp_dproj_w=mlp_dproj_w,
        )

        return mlp_layer

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
        """
        # Broadcast the [token_num, embed_dim] input to [2, token_num, embed_dim] so the gate
        # and up projections run in one batched matmul against the stacked gate_up_weight.
        hidden_states = hidden_states.expand(2, -1, -1)
        gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
        # Fused SiLU(gate) * up, producing a [token_num, intermediate_size] activation.
        act_out = inference_ops.silu_and_mul(gate_up_proj_out)
        return torch.mm(act_out, self.down_proj_weight)
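

# A minimal usage sketch (illustrative only, not part of the original file): the intended flow is
# that ColossalAI's model policy swaps the native Baichuan MLP for this layer, e.g.
#
#     nopad_mlp = NopadBaichuanMLP.from_native_module(baichuan_decoder_layer.mlp)
#     out = nopad_mlp(hidden_states)  # hidden_states: [token_num, hidden_size], fp16 on GPU
#
# Here `baichuan_decoder_layer` is a hypothetical reference to one decoder layer of the original
# HuggingFace Baichuan model; `inference_ops.silu_and_mul` requires ColossalAI's inference kernels
# to be built and loadable.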