From 4905b21b945bd3a5ac72c0ece112392cbc5e7096 Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 28 Mar 2023 21:20:28 +0800 Subject: [PATCH] [coati] fix inference output (#3285) * [coati] fix inference requirements * [coati] add output postprocess * [coati] update inference readme * [coati] fix inference requirements --- applications/Chat/inference/README.md | 6 ++++++ applications/Chat/inference/requirements.txt | 4 +++- applications/Chat/inference/server.py | 4 ++-- applications/Chat/inference/utils.py | 8 ++++++++ 4 files changed, 19 insertions(+), 3 deletions(-) diff --git a/applications/Chat/inference/README.md b/applications/Chat/inference/README.md index 3fb330748..6c23bc73c 100644 --- a/applications/Chat/inference/README.md +++ b/applications/Chat/inference/README.md @@ -36,6 +36,12 @@ Tha data is from [LLaMA Int8 4bit ChatBot Guide v2](https://rentry.org/llama-tar | LLaMA-30B | 15.8GB | 20GB | 64GB | RTX 3080 20GB, A4500, A5000, 3090, 4090, 6000, Tesla V100 | | LLaMA-65B | 31.2GB | 40GB | 128GB | A100 40GB, 2x3090, 2x4090, A40, RTX A6000, 8000, Titan Ada | +## General setup + +```shell +pip install -r requirements.txt +``` + ## 8-bit setup 8-bit quantization is originally supported by the latest [transformers](https://github.com/huggingface/transformers). Please install it from source. diff --git a/applications/Chat/inference/requirements.txt b/applications/Chat/inference/requirements.txt index 67a9874e5..7b0ac18a3 100644 --- a/applications/Chat/inference/requirements.txt +++ b/applications/Chat/inference/requirements.txt @@ -1,5 +1,5 @@ fastapi -locustio +locust numpy pydantic safetensors @@ -8,3 +8,5 @@ sse_starlette torch uvicorn git+https://github.com/huggingface/transformers +accelerate +bitsandbytes diff --git a/applications/Chat/inference/server.py b/applications/Chat/inference/server.py index 46a8b9a05..bfcd89264 100644 --- a/applications/Chat/inference/server.py +++ b/applications/Chat/inference/server.py @@ -17,7 +17,7 @@ from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.' -MAX_LEN = 2048 +MAX_LEN = 512 running_lock = Lock() @@ -116,7 +116,7 @@ def generate_no_stream(data: GenerationTaskReq, request: Request): prompt_len = inputs['input_ids'].size(1) response = output[0, prompt_len:] out_string = tokenizer.decode(response, skip_special_tokens=True) - return out_string.lstrip() + return prompt_processor.postprocess_output(out_string) if __name__ == '__main__': diff --git a/applications/Chat/inference/utils.py b/applications/Chat/inference/utils.py index 3d04aa57d..a01983de3 100644 --- a/applications/Chat/inference/utils.py +++ b/applications/Chat/inference/utils.py @@ -1,3 +1,4 @@ +import re from threading import Lock from typing import Any, Callable, Generator, List, Optional @@ -118,6 +119,9 @@ def _format_dialogue(instruction: str, response: str = ''): return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}' +STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S)) + + class ChatPromptProcessor: def __init__(self, tokenizer, context: str, max_len: int = 2048): @@ -164,6 +168,10 @@ class ChatPromptProcessor: prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction) return prompt + def postprocess_output(self, output: str) -> str: + output = STOP_PAT.sub('', output) + return output.strip() + class LockedIterator: