mirror of https://github.com/hpcaitech/ColossalAI
[coati] fix inference output (#3285)
* [coati] fix inference requirements * [coati] add output postprocess * [coati] update inference readme * [coati] fix inference requirementspull/3286/head
parent
bb6196e71a
commit
4905b21b94
|
@ -36,6 +36,12 @@ Tha data is from [LLaMA Int8 4bit ChatBot Guide v2](https://rentry.org/llama-tar
|
|||
| LLaMA-30B | 15.8GB | 20GB | 64GB | RTX 3080 20GB, A4500, A5000, 3090, 4090, 6000, Tesla V100 |
|
||||
| LLaMA-65B | 31.2GB | 40GB | 128GB | A100 40GB, 2x3090, 2x4090, A40, RTX A6000, 8000, Titan Ada |
|
||||
|
||||
## General setup
|
||||
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## 8-bit setup
|
||||
|
||||
8-bit quantization is originally supported by the latest [transformers](https://github.com/huggingface/transformers). Please install it from source.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
fastapi
|
||||
locustio
|
||||
locust
|
||||
numpy
|
||||
pydantic
|
||||
safetensors
|
||||
|
@ -8,3 +8,5 @@ sse_starlette
|
|||
torch
|
||||
uvicorn
|
||||
git+https://github.com/huggingface/transformers
|
||||
accelerate
|
||||
bitsandbytes
|
||||
|
|
|
@ -17,7 +17,7 @@ from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
|
|||
from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn
|
||||
|
||||
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
|
||||
MAX_LEN = 2048
|
||||
MAX_LEN = 512
|
||||
running_lock = Lock()
|
||||
|
||||
|
||||
|
@ -116,7 +116,7 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
|
|||
prompt_len = inputs['input_ids'].size(1)
|
||||
response = output[0, prompt_len:]
|
||||
out_string = tokenizer.decode(response, skip_special_tokens=True)
|
||||
return out_string.lstrip()
|
||||
return prompt_processor.postprocess_output(out_string)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import re
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, Generator, List, Optional
|
||||
|
||||
|
@ -118,6 +119,9 @@ def _format_dialogue(instruction: str, response: str = ''):
|
|||
return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}'
|
||||
|
||||
|
||||
STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
|
||||
|
||||
|
||||
class ChatPromptProcessor:
|
||||
|
||||
def __init__(self, tokenizer, context: str, max_len: int = 2048):
|
||||
|
@ -164,6 +168,10 @@ class ChatPromptProcessor:
|
|||
prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction)
|
||||
return prompt
|
||||
|
||||
def postprocess_output(self, output: str) -> str:
|
||||
output = STOP_PAT.sub('', output)
|
||||
return output.strip()
|
||||
|
||||
|
||||
class LockedIterator:
|
||||
|
||||
|
|
Loading…
Reference in New Issue