mirror of https://github.com/hpcaitech/ColossalAI
[coati] inference supports profanity check (#3295)
parent
ce2cafae76
commit
73b542a124
|
@ -14,7 +14,7 @@ from slowapi.errors import RateLimitExceeded
|
|||
from slowapi.util import get_remote_address
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
|
||||
from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn
|
||||
from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn, load_json
|
||||
|
||||
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
|
||||
MAX_LEN = 512
|
||||
|
@ -111,6 +111,8 @@ def generate(data: GenerationTaskReq, request: Request):
|
|||
@limiter.limit('1/second')
|
||||
def generate_no_stream(data: GenerationTaskReq, request: Request):
|
||||
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
|
||||
if prompt_processor.has_censored_words(prompt):
|
||||
return prompt_processor.SAFE_RESPONSE
|
||||
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
|
||||
with running_lock:
|
||||
output = model.generate(**inputs, **data.dict(exclude={'history'}))
|
||||
|
@ -118,7 +120,10 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
|
|||
prompt_len = inputs['input_ids'].size(1)
|
||||
response = output[0, prompt_len:]
|
||||
out_string = tokenizer.decode(response, skip_special_tokens=True)
|
||||
return prompt_processor.postprocess_output(out_string)
|
||||
out_string = prompt_processor.postprocess_output(out_string)
|
||||
if prompt_processor.has_censored_words(out_string):
|
||||
return prompt_processor.SAFE_RESPONSE
|
||||
return out_string
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -140,13 +145,19 @@ if __name__ == '__main__':
|
|||
help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
|
||||
parser.add_argument('--http_host', default='0.0.0.0')
|
||||
parser.add_argument('--http_port', type=int, default=7070)
|
||||
parser.add_argument('--profanity_file', default=None, help='Path to profanity words list. It should be a JSON file containing a list of words.')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.quant == '4bit':
|
||||
assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
|
||||
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN)
|
||||
|
||||
if args.profanity_file is not None:
|
||||
censored_words = load_json(args.profanity_file)
|
||||
else:
|
||||
censored_words = []
|
||||
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
|
||||
|
||||
if args.quant == '4bit':
|
||||
model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import re
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, Generator, List, Optional
|
||||
import json
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -123,11 +124,16 @@ STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
|
|||
|
||||
|
||||
class ChatPromptProcessor:
|
||||
SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'
|
||||
|
||||
def __init__(self, tokenizer, context: str, max_len: int = 2048):
|
||||
def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str]=[]):
|
||||
self.tokenizer = tokenizer
|
||||
self.context = context
|
||||
self.max_len = max_len
|
||||
if len(censored_words) > 0:
|
||||
self.censored_pat = re.compile(f'({"|".join(map(re.escape, censored_words))})', flags=re.I)
|
||||
else:
|
||||
self.censored_pat = None
|
||||
# These will be initialized after the first call of preprocess_prompt()
|
||||
self.context_len: Optional[int] = None
|
||||
self.dialogue_placeholder_len: Optional[int] = None
|
||||
|
@ -172,6 +178,10 @@ class ChatPromptProcessor:
|
|||
output = STOP_PAT.sub('', output)
|
||||
return output.strip()
|
||||
|
||||
def has_censored_words(self, text: str) -> bool:
|
||||
if self.censored_pat is None:
|
||||
return False
|
||||
return self.censored_pat.search(text) is not None
|
||||
|
||||
class LockedIterator:
|
||||
|
||||
|
@ -185,3 +195,7 @@ class LockedIterator:
|
|||
def __next__(self):
|
||||
with self.lock:
|
||||
return next(self.it)
|
||||
|
||||
def load_json(path: str):
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
Loading…
Reference in New Issue