mirror of https://github.com/hpcaitech/ColossalAI
[coati] fix inference profanity check (#3299)
parent 5134ad5d1a
commit 62f7156131
@@ -10,3 +10,4 @@ uvicorn
 git+https://github.com/huggingface/transformers
 accelerate
 bitsandbytes
+jieba
@@ -2,6 +2,7 @@ import re
 from threading import Lock
 from typing import Any, Callable, Generator, List, Optional
 import json
+import jieba
 
 import torch
 import torch.distributed as dist
@@ -130,10 +131,7 @@ class ChatPromptProcessor:
         self.tokenizer = tokenizer
         self.context = context
         self.max_len = max_len
-        if len(censored_words) > 0:
-            self.censored_pat = re.compile(f'({"|".join(map(re.escape, censored_words))})', flags=re.I)
-        else:
-            self.censored_pat = None
+        self.censored_words = set([word.lower() for word in censored_words])
         # These will be initialized after the first call of preprocess_prompt()
         self.context_len: Optional[int] = None
         self.dialogue_placeholder_len: Optional[int] = None
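For reference, the removed lines compiled one case-insensitive alternation regex over the censored words, while the replacement keeps a plain lowercase set that the reworked has_censored_words() below intersects with jieba tokens. A minimal sketch of both constructions, using an illustrative word list rather than anything from the repository:

import re

# Illustrative censored words (hypothetical, for demonstration only).
censored_words = ["foo", "Bar"]

# Removed approach: a single case-insensitive alternation pattern,
# later used for substring search.
censored_pat = (re.compile(f'({"|".join(map(re.escape, censored_words))})', flags=re.I)
                if len(censored_words) > 0 else None)

# New approach: a lowercase set, later intersected with jieba-segmented tokens.
censored_words_set = set(word.lower() for word in censored_words)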
@@ -179,9 +177,10 @@ class ChatPromptProcessor:
         return output.strip()
 
     def has_censored_words(self, text: str) -> bool:
-        if self.censored_pat is None:
+        if len(self.censored_words) == 0:
             return False
-        return self.censored_pat.search(text) is not None
+        intersection = set(jieba.cut(text.lower())) & self.censored_words
+        return len(intersection) > 0
 
 
 class LockedIterator:
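A minimal standalone sketch of the new check, assuming only the jieba dependency added above; the function name mirrors the diff, but the word list and sample texts are illustrative:

import jieba


def has_censored_words(text: str, censored_words: set) -> bool:
    # Mirrors the new method body: segment the lowercased text with jieba
    # and test for any overlap with the censored-word set.
    if len(censored_words) == 0:
        return False
    intersection = set(jieba.cut(text.lower())) & censored_words
    return len(intersection) > 0


censored = {"badword"}
print(has_censored_words("this prompt mentions BadWord explicitly", censored))  # expected True
print(has_censored_words("this prompt is clean", censored))                     # expected False

Unlike the removed regex search, this flags only whole jieba-segmented tokens, so a censored word embedded inside a longer token is not matched.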