[coati] fix inference profanity check (#3299)

ver217 authored 2023-03-29 04:26:35 +08:00, committed by GitHub
parent 5134ad5d1a
commit 62f7156131
2 changed files with 6 additions and 6 deletions


@@ -10,3 +10,4 @@ uvicorn
 git+https://github.com/huggingface/transformers
 accelerate
 bitsandbytes
+jieba
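
Note: jieba is a Chinese word-segmentation library; the check below switches from regex matching to jieba tokenization. A minimal sketch of the tokenizer (the sample sentence and its expected segmentation are illustrative, taken from jieba's standard example):

    import jieba

    # jieba.cut returns a generator of segmented tokens; Chinese text has
    # no whitespace word boundaries, so tokens must come from segmentation.
    tokens = list(jieba.cut("我来到北京清华大学"))
    print(tokens)  # e.g. ['我', '来到', '北京', '清华大学']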


@@ -2,6 +2,7 @@ import re
 from threading import Lock
 from typing import Any, Callable, Generator, List, Optional
 import json
+import jieba
 import torch
 import torch.distributed as dist
@@ -130,10 +131,7 @@ class ChatPromptProcessor:
         self.tokenizer = tokenizer
         self.context = context
         self.max_len = max_len
-        if len(censored_words) > 0:
-            self.censored_pat = re.compile(f'({"|".join(map(re.escape, censored_words))})', flags=re.I)
-        else:
-            self.censored_pat = None
+        self.censored_words = set([word.lower() for word in censored_words])
         # These will be initialized after the first call of preprocess_prompt()
         self.context_len: Optional[int] = None
         self.dialogue_placeholder_len: Optional[int] = None
@@ -179,9 +177,10 @@ class ChatPromptProcessor:
         return output.strip()

     def has_censored_words(self, text: str) -> bool:
-        if self.censored_pat is None:
+        if len(self.censored_words) == 0:
             return False
-        return self.censored_pat.search(text) is not None
+        intersection = set(jieba.cut(text.lower())) & self.censored_words
+        return len(intersection) > 0

 class LockedIterator:
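
For context: the old implementation compiled the censored words into one case-insensitive regex and flagged any substring hit; the new one lowercases the input, segments it with jieba, and tests set intersection against the stored word set. A minimal standalone sketch of the new logic, using a hypothetical word list (and assuming jieba passes space-separated ASCII words through as single tokens, which it does by default):

    import jieba

    censored_words = {"badword"}  # hypothetical censored-word list

    def has_censored_words(text: str) -> bool:
        # Segment the lowercased text and test for overlap with the word set.
        return len(set(jieba.cut(text.lower())) & censored_words) > 0

    print(has_censored_words("this contains badword here"))  # True
    print(has_censored_words("badwordish"))  # False: whole-token match only

One behavioral consequence of the change: the regex version would also have flagged "badwordish" as a substring hit, while the token-based version only flags exact segmented words.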