[inference/model]Adapted to the baichuan2-7B model (#5591)

* Adapted to the baichuan2-7B model

* modified according to the review comments.

* Modified the method of obtaining random weights.

* modified according to the review comments.

* change mlp layewr 'NOTE'
pull/5563/head
yuehuayingxueluo 2024-04-15 16:53:02 +08:00 committed by GitHub
parent d4cb023b62
commit 56b222eff8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 354 additions and 2 deletions

View File

@ -26,6 +26,7 @@ _ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
_DEFAULT_PROMPT_TEMPLATES = {
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
"baichuan": "<reserved_106>{input_text}<reserved_107>",
"vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
}

View File

@ -27,6 +27,7 @@ PP_AXIS, TP_AXIS = 0, 1
_supported_models = [
"LlamaForCausalLM",
"BaichuanForCausalLM",
]
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]

View File

@ -0,0 +1,183 @@
# This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
from typing import Optional, Tuple
import torch
import torch.nn as nn
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaAttention
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.logging import get_dist_logger
inference_ops = InferenceOpsLoader().load()
logger = get_dist_logger(__name__)
class NopadBaichuanAttention(nn.Module):
def __init__(
self,
config,
attn_qproj_w: torch.Tensor = None,
attn_kproj_w: torch.Tensor = None,
attn_vproj_w: torch.Tensor = None,
attn_oproj_w: torch.Tensor = None,
):
"""This layer will replace the BaichuanAttention.
Args:
config (BaichuanConfig): Holding the Baichuan model config.
attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
"""
super().__init__()
self.o_proj_weight = attn_oproj_w
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
# Used to adapt llama_base_attn_forward
self.num_key_value_heads = self.num_heads
qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
@staticmethod
def from_native_module(module: nn.Module, *args, **kwargs) -> "NopadBaichuanAttention":
"""Used for initialize the weight of NopadBaichuanAttention by origin BaichuanAttention.
Args:
module (nn.Module): The origin BaichuanAttention layer.
"""
config = module.config
q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((3, module.hidden_size, module.hidden_size))
attn_qproj_w = q_proj_w.transpose(0, 1)
attn_kproj_w = k_proj_w.transpose(0, 1)
attn_vproj_w = v_proj_w.transpose(0, 1)
attn_oproj_w = module.o_proj.weight.transpose(0, 1)
attn_layer = NopadBaichuanAttention(
config=config,
attn_qproj_w=attn_qproj_w,
attn_kproj_w=attn_kproj_w,
attn_vproj_w=attn_vproj_w,
attn_oproj_w=attn_oproj_w,
)
return attn_layer
def forward(
self,
hidden_states: torch.Tensor,
block_tables: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
sequence_lengths: torch.Tensor,
cos_sin: Tuple[torch.Tensor],
fd_inter_tensor: FDIntermTensors,
is_prompts: bool = True,
is_verifier: bool = False,
tokens_to_verify: int = None,
kv_seq_len: int = 0,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
use_cuda_kernel: bool = True,
cu_seqlens: torch.Tensor = None,
high_precision: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id.
k_cache (torch.Tensor): It holds the GPU memory for the key cache.
v_cache (torch.Tensor): It holds the GPU memory for the key cache.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin.
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
storing intermediate values in flash-decoding.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length.
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
return NopadLlamaAttention.forward(
self,
hidden_states=hidden_states,
block_tables=block_tables,
k_cache=k_cache,
v_cache=v_cache,
sequence_lengths=sequence_lengths,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
is_prompts=is_prompts,
is_verifier=is_verifier,
tokens_to_verify=tokens_to_verify,
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
cu_seqlens=cu_seqlens,
high_precision=high_precision,
)
# NOTE This will cause difference as out length increases.
class NopadBaichuanMLP(nn.Module):
def __init__(
self,
mlp_gproj_w: torch.Tensor = None,
mlp_uproj_w: torch.Tensor = None,
mlp_dproj_w: torch.Tensor = None,
):
"""This layer will replace the BaichuanAttention.
Args:
mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
"""
super().__init__()
self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
self.down_proj_weight = mlp_dproj_w
@staticmethod
def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
"""Used for initialize the weight of NopadBaichuanMLP by origin MLP(Baichuan).
Args:
module (nn.Module): The origin MLP(Baichuan) layer.
"""
mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
mlp_layer = NopadBaichuanMLP(
mlp_gproj_w=mlp_gproj_w,
mlp_uproj_w=mlp_uproj_w,
mlp_dproj_w=mlp_dproj_w,
)
return mlp_layer
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
"""
hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
return torch.mm(act_out, self.down_proj_weight)

View File

@ -479,7 +479,7 @@ class NopadLlamaAttention(LlamaAttention):
return attn_output
# NOTE This will cause the result to be different from the transformer in some cases.
# NOTE This will cause difference as out length increases.
class NopadLlamaMLP(LlamaMLP):
def __init__(
self,

View File

@ -1,9 +1,16 @@
from .glide_llama import GlideLlamaModelPolicy
from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
model_policy_map = {
"nopadding_llama": NoPaddingLlamaModelInferPolicy,
"nopadding_baichuan": NoPaddingBaichuanModelInferPolicy,
"glide_llama": GlideLlamaModelPolicy,
}
__all__ = ["NoPaddingLlamaModelInferPolicy", "GlideLlamaModelPolicy", "model_polic_map"]
__all__ = [
"NoPaddingLlamaModelInferPolicy",
"NoPaddingBaichuanModelInferPolicy",
"GlideLlamaModelPolicy",
"model_polic_map",
]

View File

@ -0,0 +1,62 @@
import torch.nn as nn
from torch.nn import Parameter
from colossalai.inference.modeling.models.nopadding_baichuan import NopadBaichuanAttention, NopadBaichuanMLP
from colossalai.inference.modeling.models.nopadding_llama import (
llama_causal_lm_forward,
llama_decoder_layer_forward,
llama_model_forward,
llama_rmsnorm_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
decoder_attribute_replacement = {
"lm_head.weight": Parameter(
nn.functional.normalize(self.model.lm_head.weight).transpose(0, 1), requires_grad=False
),
}
policy["BaichuanForCausalLM"] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
policy["DecoderLayer"] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="mlp",
target_module=NopadBaichuanMLP,
),
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadBaichuanAttention,
),
]
)
self.append_or_create_method_replacement(
description={"forward": llama_causal_lm_forward}, policy=policy, target_key="BaichuanForCausalLM"
)
self.append_or_create_method_replacement(
description={"forward": llama_model_forward}, policy=policy, target_key="BaichuanModel"
)
self.append_or_create_method_replacement(
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key="DecoderLayer"
)
self.append_or_create_method_replacement(
description={"forward": llama_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
)
return policy
def postprocess(self):
init_to_get_rotary(self.model.model)
return self.model

View File

@ -117,6 +117,7 @@ def benchmark_inference(args):
max_output_len=args.output_len,
prefill_ratio=1.2,
block_size=32,
use_cuda_kernel=True,
)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
elif args.mode == "vllm":

View File

@ -0,0 +1,97 @@
import os
import random
import numpy as np
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def check_inference_engine(use_engine=False, prompt_template=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
BAICHUAN_MODEL_NAME_OR_PATH, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True
).cuda()
model = model.eval()
inputs = [
"介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
]
output_len = 38
do_sample = False
if use_engine:
inference_config = InferenceConfig(
max_output_len=output_len, prompt_template=prompt_template, dtype="fp32", use_cuda_kernel=True
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template:
# apply prompt template
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
inputs = inputs.cuda()
generation_config = GenerationConfig(
do_sample=do_sample,
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=output_len,
)
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
return outputs
@parameterize("prompt_template", [None, "baichuan"])
def check_output_consistency(prompt_template):
cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
for s1, s2 in zip(cai_outputs, transformer_outputs):
assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
# clear singleton flash decoding tensors
FDIntermTensors._instances = {}
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_output_consistency()
@pytest.mark.skipif(
not os.path.exists(BAICHUAN_MODEL_NAME_OR_PATH),
reason="There is no local model address included, please replace this address with a valid one.",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_inference_engine():
spawn(run_dist, 1)
if __name__ == "__main__":
test_inference_engine()