[coati] inference supports profanity check (#3295)

pull/3296/head^2
ver217 2023-03-29 02:14:35 +08:00 committed by GitHub
parent ce2cafae76
commit 73b542a124
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 4 deletions

View File

@ -14,7 +14,7 @@ from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn
from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn, load_json
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
MAX_LEN = 512
@ -111,6 +111,8 @@ def generate(data: GenerationTaskReq, request: Request):
@limiter.limit('1/second')
def generate_no_stream(data: GenerationTaskReq, request: Request):
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
if prompt_processor.has_censored_words(prompt):
return prompt_processor.SAFE_RESPONSE
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
with running_lock:
output = model.generate(**inputs, **data.dict(exclude={'history'}))
@ -118,7 +120,10 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
prompt_len = inputs['input_ids'].size(1)
response = output[0, prompt_len:]
out_string = tokenizer.decode(response, skip_special_tokens=True)
return prompt_processor.postprocess_output(out_string)
out_string = prompt_processor.postprocess_output(out_string)
if prompt_processor.has_censored_words(out_string):
return prompt_processor.SAFE_RESPONSE
return out_string
if __name__ == '__main__':
@ -140,13 +145,19 @@ if __name__ == '__main__':
help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
parser.add_argument('--http_host', default='0.0.0.0')
parser.add_argument('--http_port', type=int, default=7070)
parser.add_argument('--profanity_file', default=None, help='Path to profanity words list. It should be a JSON file containing a list of words.')
args = parser.parse_args()
if args.quant == '4bit':
assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN)
if args.profanity_file is not None:
censored_words = load_json(args.profanity_file)
else:
censored_words = []
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
if args.quant == '4bit':
model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)

View File

@ -1,6 +1,7 @@
import re
from threading import Lock
from typing import Any, Callable, Generator, List, Optional
import json
import torch
import torch.distributed as dist
@ -123,11 +124,16 @@ STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
class ChatPromptProcessor:
SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'
def __init__(self, tokenizer, context: str, max_len: int = 2048):
def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str]=[]):
self.tokenizer = tokenizer
self.context = context
self.max_len = max_len
if len(censored_words) > 0:
self.censored_pat = re.compile(f'({"|".join(map(re.escape, censored_words))})', flags=re.I)
else:
self.censored_pat = None
# These will be initialized after the first call of preprocess_prompt()
self.context_len: Optional[int] = None
self.dialogue_placeholder_len: Optional[int] = None
@ -172,6 +178,10 @@ class ChatPromptProcessor:
output = STOP_PAT.sub('', output)
return output.strip()
def has_censored_words(self, text: str) -> bool:
    """Report whether ``text`` matches the compiled censored-word pattern.

    Args:
        text: The string to scan.

    Returns:
        ``False`` when ``self.censored_pat`` is ``None`` (no pattern was
        compiled); otherwise ``True`` if the pattern matches anywhere in
        ``text`` and ``False`` if it does not.
    """
    pattern = self.censored_pat
    # Short-circuit keeps the "no pattern configured" case a plain False.
    return pattern is not None and pattern.search(text) is not None
class LockedIterator:
@ -185,3 +195,7 @@ class LockedIterator:
def __next__(self):
    """Fetch the next item from ``self.it`` while holding ``self.lock``.

    The lock (presumably a ``threading.Lock`` -- confirm against the
    constructor) is held only for the duration of the ``next()`` call and
    is released even if the underlying iterator raises, e.g.
    ``StopIteration`` on exhaustion, which propagates to the caller.
    """
    with self.lock:
        item = next(self.it)
    return item
def load_json(path: str):
    """Parse the JSON file at ``path`` and return the decoded object.

    The file is opened explicitly as UTF-8 rather than with the platform's
    locale encoding, so word lists containing non-ASCII entries load the
    same way on every system.

    Args:
        path: Filesystem path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the document (list, dict, ...).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, encoding='utf-8') as f:
        return json.load(f)