From 58740b5f6872bc5a26dbf7c3112b86a1b66c083a Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Wed, 7 Feb 2024 17:11:43 +0800
Subject: [PATCH] [inference] added inference template (#5375)

---
 colossalai/inference/config.py            | 20 +++++++++++++++
 colossalai/inference/core/engine.py       | 24 ++++++++++++++++++
 tests/test_infer/test_inference_engine.py | 30 ++++++++++++++++-------
 3 files changed, 65 insertions(+), 9 deletions(-)

diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 6923d63e3..613afcacd 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -23,6 +23,12 @@ _DTYPE_MAPPING = {
 
 _ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
 
+_DEFAULT_PROMPT_TEMPLATES = {
+    "llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
+    "vicuna": "USER: {input_text}\n\nASSISTANT: ",
+}
+
+
 @dataclass
 class InferenceConfig:
     """The inference configuration.
@@ -44,6 +50,7 @@ class InferenceConfig:
         pad_input: Whether to pad all inputs to the max length.
         quant_mode (Optional[str]): Quantization mode.
         revision (Optional[str]): The specific version(a branch, name, a commit id, or a tag name) of model to use.
+        prompt_template (Optional[str]): The prompt template for formatting the input text. Some built-in templates include 'llama' and 'vicuna'. Otherwise, the template should contain '{input_text}' for formatting the input text.
     """

     micro_batch_size: int = 1
@@ -62,6 +69,7 @@ class InferenceConfig:
     pad_input: bool = False
     quant_mode: Optional[str] = None
     revision: Optional[str] = None
+    prompt_template: Optional[str] = None
 
     def __post_init__(self):
         self._verify_config()
@@ -85,3 +93,15 @@ class InferenceConfig:
         assert (
             self.tp_size * self.pp_size == dist.get_world_size()
         ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
+
+        # check prompt template
+        if self.prompt_template is None:
+            return
+
+        if self.prompt_template in _DEFAULT_PROMPT_TEMPLATES:
+            self.prompt_template = _DEFAULT_PROMPT_TEMPLATES[self.prompt_template]
+        else:
+            # make sure the template can be formatted with input_text
+            assert (
+                "{input_text}" in self.prompt_template
+            ), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '"
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 553c89018..d97d70ad5 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -170,6 +170,26 @@ class InferenceEngine:
 
         return output_str
 
+    @property
+    def has_prompt_template(self) -> bool:
+        """Whether a prompt template is set in the inference config."""
+        return self.inference_config.prompt_template is not None
+
+    def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
+        """
+        Format the input prompts according to the prompt template given in the InferenceConfig.
+        """
+        assert (
+            self.has_prompt_template
+        ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
+
+        if isinstance(prompts, (list, tuple)):
+            return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
+        elif isinstance(prompts, str):
+            return self.inference_config.prompt_template.format(input_text=prompts)
+        else:
+            raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
+
     def add_request(
         self,
         requests_id: List[int] = None,
@@ -185,6 +205,10 @@ class InferenceEngine:
             prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
         """
 
+        # apply the prompt template to the input prompts
+        if self.has_prompt_template and prompts is not None:
+            prompts = self.format_prompt(prompts)
+
         block_size = self.inference_config.block_size
 
         if prompts_token_ids is None:
diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py
index 8c8e864b0..2bc6d5436 100644
--- a/tests/test_infer/test_inference_engine.py
+++ b/tests/test_infer/test_inference_engine.py
@@ -6,9 +6,10 @@ import torch
 from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
 
 import colossalai
-from colossalai.inference.config import InferenceConfig
+from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
 def setup_seed(seed):
@@ -18,7 +19,7 @@ def setup_seed(seed):
     random.seed(seed)
 
 
-def check_inference_engine(test_cai=False):
+def check_inference_engine(use_engine=False, prompt_template=None):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     model = (
@@ -43,14 +44,17 @@ def check_inference_engine(test_cai=False):
     top_p = 0.5
     top_k = 50
 
-    if test_cai:
-        inference_config = InferenceConfig(max_output_len=output_len)
+    if use_engine:
+        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
         generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
         outputs = inference_engine.generate(generation_config=generation_config)
     else:
+        if prompt_template:
+            # apply prompt template
+            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.pad_token_id = tokenizer.eos_token_id
         inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
@@ -68,14 +72,22 @@ def check_inference_engine(test_cai=False):
     return outputs
 
 
-def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
-    cai_outputs = check_inference_engine(True)
-    transformer_outputs = check_inference_engine(False)
+@parameterize("prompt_template", [None, "llama"])
+def check_output_consistency(prompt_template):
+    cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
+    transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
 
     for s1, s2 in zip(cai_outputs, transformer_outputs):
         assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
 
+    # clear singleton flash decoding tensors
+    FDIntermTensors._instances = {}
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    check_output_consistency()
+
 
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
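
Usage sketch (not part of the diff): the snippet below shows how the prompt_template
option added by this patch is intended to be used, based only on the API visible above
(InferenceConfig, InferenceEngine, add_request, generate). The tiny random model, the
port, the example prompt, and the generation settings are illustrative placeholders.

    import colossalai
    from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM

    from colossalai.inference.config import InferenceConfig
    from colossalai.inference.core.engine import InferenceEngine

    # single-process launch; rank/world_size/port are placeholder values
    colossalai.launch(config={}, rank=0, world_size=1, port=29500, host="localhost")

    # a tiny randomly initialized model keeps the sketch self-contained;
    # any Llama-style checkpoint would be used the same way
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    model = LlamaForCausalLM(
        LlamaConfig(hidden_size=512, intermediate_size=1024, num_attention_heads=8, num_hidden_layers=2)
    ).cuda()

    # either a built-in template name ("llama" or "vicuna") ...
    inference_config = InferenceConfig(max_output_len=32, prompt_template="llama")
    # ... or a custom template, which must contain '{input_text}':
    # inference_config = InferenceConfig(max_output_len=32, prompt_template="USER: {input_text}\n\nASSISTANT: ")

    engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
    engine.add_request(prompts=["What is the capital of France?"])  # template is applied here
    outputs = engine.generate(generation_config=GenerationConfig(do_sample=False))
    print(outputs)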