Browse Source

[coati] fix inference output (#3285)

* [coati] fix inference requirements

* [coati] add output postprocess

* [coati] update inference readme

* [coati] fix inference requirements
pull/3286/head
ver217 2 years ago committed by GitHub
parent
commit
4905b21b94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 6
      applications/Chat/inference/README.md
  2. 4
      applications/Chat/inference/requirements.txt
  3. 4
      applications/Chat/inference/server.py
  4. 8
      applications/Chat/inference/utils.py

6
applications/Chat/inference/README.md

@ -36,6 +36,12 @@ Tha data is from [LLaMA Int8 4bit ChatBot Guide v2](https://rentry.org/llama-tar
| LLaMA-30B | 15.8GB | 20GB | 64GB | RTX 3080 20GB, A4500, A5000, 3090, 4090, 6000, Tesla V100 |
| LLaMA-65B | 31.2GB | 40GB | 128GB | A100 40GB, 2x3090, 2x4090, A40, RTX A6000, 8000, Titan Ada |
## General setup
```shell
pip install -r requirements.txt
```
## 8-bit setup
8-bit quantization is originally supported by the latest [transformers](https://github.com/huggingface/transformers). Please install it from source.

4
applications/Chat/inference/requirements.txt

@ -1,5 +1,5 @@
fastapi
locustio
locust
numpy
pydantic
safetensors
@ -8,3 +8,5 @@ sse_starlette
torch
uvicorn
git+https://github.com/huggingface/transformers
accelerate
bitsandbytes

4
applications/Chat/inference/server.py

@ -17,7 +17,7 @@ from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
MAX_LEN = 2048
MAX_LEN = 512
running_lock = Lock()
@ -116,7 +116,7 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
prompt_len = inputs['input_ids'].size(1)
response = output[0, prompt_len:]
out_string = tokenizer.decode(response, skip_special_tokens=True)
return out_string.lstrip()
return prompt_processor.postprocess_output(out_string)
if __name__ == '__main__':

8
applications/Chat/inference/utils.py

@ -1,3 +1,4 @@
import re
from threading import Lock
from typing import Any, Callable, Generator, List, Optional
@ -118,6 +119,9 @@ def _format_dialogue(instruction: str, response: str = ''):
return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}'
STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
class ChatPromptProcessor:
def __init__(self, tokenizer, context: str, max_len: int = 2048):
@ -164,6 +168,10 @@ class ChatPromptProcessor:
prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction)
return prompt
def postprocess_output(self, output: str) -> str:
output = STOP_PAT.sub('', output)
return output.strip()
class LockedIterator:

Loading…
Cancel
Save