mirror of https://github.com/hpcaitech/ColossalAI
[inference] added inference template (#5375)
parent
8106ede07f
commit
58740b5f68
|
@ -23,6 +23,12 @@ _DTYPE_MAPPING = {
|
|||
_ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
|
||||
_DEFAULT_PROMPT_TEMPLATES = {
|
||||
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
|
||||
"vicuna": "USER: {input_text}\n\nASSISTANT: ",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceConfig:
|
||||
"""The inference configuration.
|
||||
|
@ -44,6 +50,7 @@ class InferenceConfig:
|
|||
pad_input: Whether to pad all inputs to the max length.
|
||||
quant_mode (Optional[str]): Quantization mode.
|
||||
revision (Optional[str]): The specific version(a branch, name, a commit id, or a tag name) of model to use.
|
||||
prompt_template (Optional[str]): The prompt template for formatting the input text. Some built-in templates include 'llama' and 'vicuna'. Otherwise, the template should contain '{input_text}' for formatting the input text.
|
||||
"""
|
||||
|
||||
micro_batch_size: int = 1
|
||||
|
@ -62,6 +69,7 @@ class InferenceConfig:
|
|||
pad_input: bool = False
|
||||
quant_mode: Optional[str] = None
|
||||
revision: Optional[str] = None
|
||||
prompt_template: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._verify_config()
|
||||
|
@ -85,3 +93,15 @@ class InferenceConfig:
|
|||
assert (
|
||||
self.tp_size * self.pp_size == dist.get_world_size()
|
||||
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
|
||||
|
||||
# check prompt template
|
||||
if self.prompt_template is None:
|
||||
return
|
||||
|
||||
if self.prompt_template in _DEFAULT_PROMPT_TEMPLATES:
|
||||
self.prompt_template = _DEFAULT_PROMPT_TEMPLATES[self.prompt_template]
|
||||
else:
|
||||
# make sure the template can be formatted with input_text
|
||||
assert (
|
||||
"{input_text}" in self.prompt_template
|
||||
), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '"
|
||||
|
|
|
@ -170,6 +170,26 @@ class InferenceEngine:
|
|||
|
||||
return output_str
|
||||
|
||||
@property
|
||||
def has_prompt_template(self) -> bool:
|
||||
""" """
|
||||
return self.inference_config.prompt_template is not None
|
||||
|
||||
def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
|
||||
"""
|
||||
This method will format the input prompt according to the prompt template given to the InferenceConfig.
|
||||
"""
|
||||
assert (
|
||||
self.has_prompt_template
|
||||
), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
|
||||
|
||||
if isinstance(prompts, (list, tuple)):
|
||||
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
|
||||
elif isinstance(prompts, str):
|
||||
return self.inference_config.rompt_template.format(input_text=prompts)
|
||||
else:
|
||||
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
requests_id: List[int] = None,
|
||||
|
@ -185,6 +205,10 @@ class InferenceEngine:
|
|||
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
|
||||
"""
|
||||
|
||||
# apply the prompt template to the input prompts
|
||||
if self.has_prompt_template and prompts is not None:
|
||||
prompts = self.format_prompt(prompts)
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
|
||||
if prompts_token_ids is None:
|
||||
|
|
|
@ -6,9 +6,10 @@ import torch
|
|||
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
|
@ -18,7 +19,7 @@ def setup_seed(seed):
|
|||
random.seed(seed)
|
||||
|
||||
|
||||
def check_inference_engine(test_cai=False):
|
||||
def check_inference_engine(use_engine=False, prompt_template=None):
|
||||
setup_seed(20)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
model = (
|
||||
|
@ -43,14 +44,17 @@ def check_inference_engine(test_cai=False):
|
|||
top_p = 0.5
|
||||
top_k = 50
|
||||
|
||||
if test_cai:
|
||||
inference_config = InferenceConfig(max_output_len=output_len)
|
||||
if use_engine:
|
||||
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
|
||||
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
||||
inference_engine.add_request(prompts=inputs)
|
||||
assert inference_engine.request_handler._has_waiting()
|
||||
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
|
||||
outputs = inference_engine.generate(generation_config=generation_config)
|
||||
else:
|
||||
if prompt_template:
|
||||
# apply prompt template
|
||||
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
|
||||
|
@ -68,14 +72,22 @@ def check_inference_engine(test_cai=False):
|
|||
return outputs
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
cai_outputs = check_inference_engine(True)
|
||||
transformer_outputs = check_inference_engine(False)
|
||||
@parameterize("prompt_template", [None, "llama"])
|
||||
def check_output_consistency(prompt_template):
|
||||
cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
|
||||
transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
|
||||
|
||||
for s1, s2 in zip(cai_outputs, transformer_outputs):
|
||||
assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
|
||||
|
||||
# clear singleton flash decoding tensors
|
||||
FDIntermTensors._instances = {}
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_output_consistency()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
|
|
Loading…
Reference in New Issue