diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 9d7c2c0ad..417ee8295 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -26,6 +26,7 @@ _ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
_DEFAULT_PROMPT_TEMPLATES = {
"llama": "[INST] <>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<>\n{input_text}[/INST]",
+ "baichuan": "{input_text}",
"vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
}
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index f6b5a6e79..466f6749b 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -27,6 +27,7 @@ PP_AXIS, TP_AXIS = 0, 1
_supported_models = [
"LlamaForCausalLM",
+ "BaichuanForCausalLM",
]
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py
new file mode 100644
index 000000000..893d45c1f
--- /dev/null
+++ b/colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -0,0 +1,183 @@
+# This code is adapted from the HuggingFace Baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaAttention
+from colossalai.kernel.kernel_loader import InferenceOpsLoader
+from colossalai.logging import get_dist_logger
+
+inference_ops = InferenceOpsLoader().load()
+
+logger = get_dist_logger(__name__)
+
+
+class NopadBaichuanAttention(nn.Module):
+ def __init__(
+ self,
+ config,
+ attn_qproj_w: torch.Tensor = None,
+ attn_kproj_w: torch.Tensor = None,
+ attn_vproj_w: torch.Tensor = None,
+ attn_oproj_w: torch.Tensor = None,
+ ):
+ """This layer will replace the BaichuanAttention.
+
+ Args:
+ config (BaichuanConfig): Holding the Baichuan model config.
+ attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
+ attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
+ attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
+ attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
+ """
+ super().__init__()
+ self.o_proj_weight = attn_oproj_w
+
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+
+        # Baichuan does not use grouped-query attention, so num_key_value_heads equals
+        # num_heads; this keeps the module compatible with the shared llama attention forward.
+ self.num_key_value_heads = self.num_heads
+
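+        # The transposed q/k/v projection weights are stacked into a single
+        # [3, hidden_size, hidden_size] tensor so the shared llama attention forward can
+        # compute all three projections together.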
+ qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
+ self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
+
+ @staticmethod
+ def from_native_module(module: nn.Module, *args, **kwargs) -> "NopadBaichuanAttention":
+ """Used for initialize the weight of NopadBaichuanAttention by origin BaichuanAttention.
+
+ Args:
+ module (nn.Module): The origin BaichuanAttention layer.
+ """
+
+ config = module.config
+
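+        # Baichuan fuses the q/k/v projections into a single W_pack linear whose weight has
+        # shape [3 * hidden_size, hidden_size]; viewing it as [3, hidden_size, hidden_size]
+        # recovers the three projection weights, which are then transposed for right-multiplication.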
+ q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((3, module.hidden_size, module.hidden_size))
+
+ attn_qproj_w = q_proj_w.transpose(0, 1)
+ attn_kproj_w = k_proj_w.transpose(0, 1)
+ attn_vproj_w = v_proj_w.transpose(0, 1)
+ attn_oproj_w = module.o_proj.weight.transpose(0, 1)
+
+ attn_layer = NopadBaichuanAttention(
+ config=config,
+ attn_qproj_w=attn_qproj_w,
+ attn_kproj_w=attn_kproj_w,
+ attn_vproj_w=attn_vproj_w,
+ attn_oproj_w=attn_oproj_w,
+ )
+
+ return attn_layer
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ block_tables: torch.Tensor,
+ k_cache: torch.Tensor,
+ v_cache: torch.Tensor,
+ sequence_lengths: torch.Tensor,
+ cos_sin: Tuple[torch.Tensor],
+ fd_inter_tensor: FDIntermTensors,
+ is_prompts: bool = True,
+ is_verifier: bool = False,
+ tokens_to_verify: int = None,
+ kv_seq_len: int = 0,
+ output_tensor: torch.Tensor = None,
+ sm_scale: int = None,
+ use_cuda_kernel: bool = True,
+ cu_seqlens: torch.Tensor = None,
+ high_precision: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """
+ Args:
+ hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
+ block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
+ storing mapping of token_position_id -> block_id.
+ k_cache (torch.Tensor): It holds the GPU memory for the key cache.
+            v_cache (torch.Tensor): It holds the GPU memory for the value cache.
+ sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence.
+ cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin.
+ fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
+ storing intermediate values in flash-decoding.
+ is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
+ kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
+            output_tensor (torch.Tensor, optional): The intermediate tensor that holds the attention output. Defaults to None.
+            sm_scale (int, optional): Used for flash attention. Defaults to None.
+            use_cuda_kernel (bool, optional): Whether to use the CUDA kernel. Defaults to True.
+            cu_seqlens (torch.Tensor, optional): Holding the cumulative sum of the sequence lengths.
+            high_precision (Optional[bool]): Whether to use float32 for the underlying computation of float16 data to achieve higher precision. Defaults to False.
+ """
+
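+        # The attention computation of Baichuan(2)-7B follows the llama layout, so this
+        # simply delegates to the shared no-padding llama attention forward.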
+ return NopadLlamaAttention.forward(
+ self,
+ hidden_states=hidden_states,
+ block_tables=block_tables,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ sequence_lengths=sequence_lengths,
+ cos_sin=cos_sin,
+ fd_inter_tensor=fd_inter_tensor,
+ is_prompts=is_prompts,
+ is_verifier=is_verifier,
+ tokens_to_verify=tokens_to_verify,
+ kv_seq_len=kv_seq_len,
+ output_tensor=output_tensor,
+ sm_scale=sm_scale,
+ use_cuda_kernel=use_cuda_kernel,
+ cu_seqlens=cu_seqlens,
+ high_precision=high_precision,
+ )
+
+
+# NOTE This may cause the output to diverge from the transformers implementation as the output length increases.
+class NopadBaichuanMLP(nn.Module):
+ def __init__(
+ self,
+ mlp_gproj_w: torch.Tensor = None,
+ mlp_uproj_w: torch.Tensor = None,
+ mlp_dproj_w: torch.Tensor = None,
+ ):
+ """This layer will replace the BaichuanAttention.
+
+ Args:
+ mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
+ mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
+ mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
+ """
+ super().__init__()
+ self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
+ self.down_proj_weight = mlp_dproj_w
+
+ @staticmethod
+ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
+ """Used for initialize the weight of NopadBaichuanMLP by origin MLP(Baichuan).
+
+ Args:
+ module (nn.Module): The origin MLP(Baichuan) layer.
+ """
+
+ mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
+ mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
+ mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
+
+ mlp_layer = NopadBaichuanMLP(
+ mlp_gproj_w=mlp_gproj_w,
+ mlp_uproj_w=mlp_uproj_w,
+ mlp_dproj_w=mlp_dproj_w,
+ )
+
+ return mlp_layer
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
+ """
+ hidden_states = hidden_states.expand(2, -1, -1)
+ gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
+ act_out = inference_ops.silu_and_mul(gate_up_proj_out)
+ return torch.mm(act_out, self.down_proj_weight)
diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index 2b14190da..010abc1db 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -479,7 +479,7 @@ class NopadLlamaAttention(LlamaAttention):
return attn_output
-# NOTE This will cause the result to be different from the transformer in some cases.
+# NOTE This may cause the output to diverge from the transformers implementation as the output length increases.
class NopadLlamaMLP(LlamaMLP):
def __init__(
self,
diff --git a/colossalai/inference/modeling/policy/__init__.py b/colossalai/inference/modeling/policy/__init__.py
index 54852751a..fa0395590 100644
--- a/colossalai/inference/modeling/policy/__init__.py
+++ b/colossalai/inference/modeling/policy/__init__.py
@@ -1,9 +1,16 @@
from .glide_llama import GlideLlamaModelPolicy
+from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
model_policy_map = {
"nopadding_llama": NoPaddingLlamaModelInferPolicy,
+ "nopadding_baichuan": NoPaddingBaichuanModelInferPolicy,
"glide_llama": GlideLlamaModelPolicy,
}
-__all__ = ["NoPaddingLlamaModelInferPolicy", "GlideLlamaModelPolicy", "model_polic_map"]
+__all__ = [
+ "NoPaddingLlamaModelInferPolicy",
+ "NoPaddingBaichuanModelInferPolicy",
+ "GlideLlamaModelPolicy",
+ "model_polic_map",
+]
diff --git a/colossalai/inference/modeling/policy/nopadding_baichuan.py b/colossalai/inference/modeling/policy/nopadding_baichuan.py
new file mode 100644
index 000000000..64dc40dbc
--- /dev/null
+++ b/colossalai/inference/modeling/policy/nopadding_baichuan.py
@@ -0,0 +1,62 @@
+import torch.nn as nn
+from torch.nn import Parameter
+
+from colossalai.inference.modeling.models.nopadding_baichuan import NopadBaichuanAttention, NopadBaichuanMLP
+from colossalai.inference.modeling.models.nopadding_llama import (
+ llama_causal_lm_forward,
+ llama_decoder_layer_forward,
+ llama_model_forward,
+ llama_rmsnorm_forward,
+)
+from colossalai.inference.utils import init_to_get_rotary
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
+from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
+
+
+class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+
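+        # Baichuan2 uses a NormHead for the lm_head: its weight is L2-normalized at inference
+        # time, so the normalized, transposed weight is baked in here and the generic llama
+        # causal-lm forward can apply it with a plain matmul.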
+ decoder_attribute_replacement = {
+ "lm_head.weight": Parameter(
+ nn.functional.normalize(self.model.lm_head.weight).transpose(0, 1), requires_grad=False
+ ),
+ }
+ policy["BaichuanForCausalLM"] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ )
+
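+        # The string keys ("DecoderLayer", "RMSNorm", "BaichuanModel", ...) match the class
+        # names defined in Baichuan's remote modeling code loaded via trust_remote_code.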
+ policy["DecoderLayer"] = ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="mlp",
+ target_module=NopadBaichuanMLP,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn",
+ target_module=NopadBaichuanAttention,
+ ),
+ ]
+ )
+
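+        # Baichuan's decoder stack mirrors llama, so the no-padding llama forwards are reused
+        # for the causal LM, the base model, each DecoderLayer, and RMSNorm.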
+ self.append_or_create_method_replacement(
+ description={"forward": llama_causal_lm_forward}, policy=policy, target_key="BaichuanForCausalLM"
+ )
+ self.append_or_create_method_replacement(
+ description={"forward": llama_model_forward}, policy=policy, target_key="BaichuanModel"
+ )
+ self.append_or_create_method_replacement(
+ description={"forward": llama_decoder_layer_forward}, policy=policy, target_key="DecoderLayer"
+ )
+ self.append_or_create_method_replacement(
+ description={"forward": llama_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
+ )
+
+ return policy
+
+ def postprocess(self):
+ init_to_get_rotary(self.model.model)
+ return self.model
diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py
index 448a84c6f..8128ce9f3 100644
--- a/examples/inference/benchmark_llama.py
+++ b/examples/inference/benchmark_llama.py
@@ -117,6 +117,7 @@ def benchmark_inference(args):
max_output_len=args.output_len,
prefill_ratio=1.2,
block_size=32,
+ use_cuda_kernel=True,
)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
elif args.mode == "vllm":
diff --git a/tests/test_infer/test_models/test_baichuan.py b/tests/test_infer/test_models/test_baichuan.py
new file mode 100644
index 000000000..5ca67c5be
--- /dev/null
+++ b/tests/test_infer/test_models/test_baichuan.py
@@ -0,0 +1,97 @@
+import os
+import random
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+
+import colossalai
+from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+
+BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
+
+
+def setup_seed(seed):
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+
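+# Runs greedy decoding either through the ColossalAI InferenceEngine or directly through
+# HF transformers, so that the two outputs can be compared token-for-token.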
+def check_inference_engine(use_engine=False, prompt_template=None):
+ setup_seed(20)
+ tokenizer = AutoTokenizer.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, use_fast=False, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+ BAICHUAN_MODEL_NAME_OR_PATH, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True
+ ).cuda()
+ model = model.eval()
+
+ inputs = [
+ "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
+ ]
+
+ output_len = 38
+ do_sample = False
+
+ if use_engine:
+ inference_config = InferenceConfig(
+ max_output_len=output_len, prompt_template=prompt_template, dtype="fp32", use_cuda_kernel=True
+ )
+ inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+ assert inference_engine.generation_config.max_new_tokens == output_len
+ inference_engine.add_request(prompts=inputs)
+ assert inference_engine.request_handler._has_waiting()
+ generation_config = GenerationConfig(do_sample=do_sample)
+ outputs = inference_engine.generate(generation_config=generation_config)
+ else:
+ if prompt_template:
+ # apply prompt template
+ inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+ inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
+ inputs = inputs.cuda()
+ generation_config = GenerationConfig(
+ do_sample=do_sample,
+ pad_token_id=tokenizer.pad_token_id,
+ max_new_tokens=output_len,
+ )
+ outputs = model.generate(inputs, generation_config=generation_config)
+ outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+ return outputs
+
+
+@parameterize("prompt_template", [None, "baichuan"])
+def check_output_consistency(prompt_template):
+ cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
+ transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
+
+ for s1, s2 in zip(cai_outputs, transformer_outputs):
+ assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
+
+ # clear singleton flash decoding tensors
+ FDIntermTensors._instances = {}
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+ check_output_consistency()
+
+
+@pytest.mark.skipif(
+ not os.path.exists(BAICHUAN_MODEL_NAME_OR_PATH),
+ reason="There is no local model address included, please replace this address with a valid one.",
+)
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_inference_engine():
+ spawn(run_dist, 1)
+
+
+if __name__ == "__main__":
+ test_inference_engine()