mirror of https://github.com/hpcaitech/ColossalAI
[coati] fix inference output (#3285)
* [coati] fix inference requirements
* [coati] add output postprocess
* [coati] update inference readme
* [coati] fix inference requirements

Branch: pull/3286/head
parent bb6196e71a
commit 4905b21b94
The inference README (per the commit message's "update inference readme") gains a "General setup" section:

````diff
@@ -36,6 +36,12 @@ Tha data is from [LLaMA Int8 4bit ChatBot Guide v2](https://rentry.org/llama-tar
 | LLaMA-30B | 15.8GB | 20GB | 64GB | RTX 3080 20GB, A4500, A5000, 3090, 4090, 6000, Tesla V100 |
 | LLaMA-65B | 31.2GB | 40GB | 128GB | A100 40GB, 2x3090, 2x4090, A40, RTX A6000, 8000, Titan Ada |
 
+## General setup
+
+```shell
+pip install -r requirements.txt
+```
+
 ## 8-bit setup
 
 8-bit quantization is originally supported by the latest [transformers](https://github.com/huggingface/transformers). Please install it from source.
````
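The README's "Please install it from source" concretely means a VCS install of transformers, which matches the `git+` pin already present in the requirements file below; the one-liner equivalent:

```shell
pip install git+https://github.com/huggingface/transformers
```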
The inference `requirements.txt`: the PyPI package `locustio` was renamed to `locust`, and `accelerate` plus `bitsandbytes` are added for the 8-bit path:

```diff
@@ -1,5 +1,5 @@
 fastapi
-locustio
+locust
 numpy
 pydantic
 safetensors
@@ -8,3 +8,5 @@ sse_starlette
 torch
 uvicorn
 git+https://github.com/huggingface/transformers
+accelerate
+bitsandbytes
```
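The new `accelerate` and `bitsandbytes` entries are what the README's 8-bit setup relies on. A minimal, hedged sketch of an 8-bit LLaMA load with the transformers source build; the checkpoint path is a placeholder, not part of this commit:

```python
from transformers import AutoTokenizer, LlamaForCausalLM

MODEL_PATH = 'path/to/llama-hf'  # placeholder; supply your converted LLaMA weights

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# load_in_8bit=True routes linear layers through bitsandbytes int8 kernels;
# device_map='auto' lets accelerate place the quantized weights across GPUs.
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    load_in_8bit=True,
    device_map='auto',
)
model.eval()
```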
The inference server script (its file name is not shown in this mirror; it imports the helpers from `utils`): `MAX_LEN` drops from 2048 to 512, and the non-streaming endpoint now routes the decoded string through the new `postprocess_output` instead of merely left-stripping it:

```diff
@@ -17,7 +17,7 @@ from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
 from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn
 
 CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
-MAX_LEN = 2048
+MAX_LEN = 512
 running_lock = Lock()
 
 
@@ -116,7 +116,7 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
     prompt_len = inputs['input_ids'].size(1)
     response = output[0, prompt_len:]
     out_string = tokenizer.decode(response, skip_special_tokens=True)
-    return out_string.lstrip()
+    return prompt_processor.postprocess_output(out_string)
 
 
 if __name__ == '__main__':
```
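Why halve `MAX_LEN`? It is the token budget handed to `ChatPromptProcessor` (whose `max_len` parameter appears in the utils diff below), so a smaller value trims dialogue history more aggressively and keeps prompts comfortably inside the model's context window. A hedged sketch of the wiring, since the mirror does not show the construction site itself and the checkpoint path is a placeholder:

```python
from transformers import AutoTokenizer

from utils import ChatPromptProcessor  # the class patched in the next diff

CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
MAX_LEN = 512

tokenizer = AutoTokenizer.from_pretrained('path/to/llama-hf')  # placeholder path
# max_len caps the tokenized prompt, so halving it mainly shortens how much
# chat history survives the processor's truncation.
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, max_len=MAX_LEN)
```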
`utils.py` gains `STOP_PAT` and `ChatPromptProcessor.postprocess_output`, which cut a reply at the first `###` or `instruction:` marker so the model cannot tack fabricated instructions onto its answer:

```diff
@@ -1,3 +1,4 @@
+import re
 from threading import Lock
 from typing import Any, Callable, Generator, List, Optional
 
@@ -118,6 +119,9 @@ def _format_dialogue(instruction: str, response: str = ''):
     return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}'
 
 
+STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
+
+
 class ChatPromptProcessor:
 
     def __init__(self, tokenizer, context: str, max_len: int = 2048):
 
@@ -164,6 +168,10 @@ class ChatPromptProcessor:
         prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction)
         return prompt
 
+    def postprocess_output(self, output: str) -> str:
+        output = STOP_PAT.sub('', output)
+        return output.strip()
+
 
 class LockedIterator:
 
```
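To see what the new postprocessing buys, here is a self-contained replay of the added regex on a reply where the model drifts into inventing a follow-up instruction; the sample strings are made up for illustration:

```python
import re

STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))

def postprocess_output(output: str) -> str:
    # Drop everything from the first '###' or 'instruction:' marker onward
    # (case-insensitive, spanning newlines), then trim surrounding whitespace.
    return STOP_PAT.sub('', output).strip()

raw = "Paris is the capital of France.\n\n### Instruction:\nNow write a poem."
print(postprocess_output(raw))  # -> Paris is the capital of France.
```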