diff --git a/applications/Chat/inference/requirements.txt b/applications/Chat/inference/requirements.txt
index 7b0ac18a3..511fe1a4f 100644
--- a/applications/Chat/inference/requirements.txt
+++ b/applications/Chat/inference/requirements.txt
@@ -10,3 +10,4 @@ uvicorn
 git+https://github.com/huggingface/transformers
 accelerate
 bitsandbytes
+jieba
\ No newline at end of file
diff --git a/applications/Chat/inference/utils.py b/applications/Chat/inference/utils.py
index 1bb0e82ba..37944be70 100644
--- a/applications/Chat/inference/utils.py
+++ b/applications/Chat/inference/utils.py
@@ -2,6 +2,7 @@ import re
 from threading import Lock
 from typing import Any, Callable, Generator, List, Optional
 import json
+import jieba
 
 import torch
 import torch.distributed as dist
@@ -130,10 +131,7 @@ class ChatPromptProcessor:
         self.tokenizer = tokenizer
         self.context = context
         self.max_len = max_len
-        if len(censored_words) > 0:
-            self.censored_pat = re.compile(f'({"|".join(map(re.escape, censored_words))})', flags=re.I)
-        else:
-            self.censored_pat = None
+        self.censored_words = set([word.lower() for word in censored_words])
         # These will be initialized after the first call of preprocess_prompt()
         self.context_len: Optional[int] = None
         self.dialogue_placeholder_len: Optional[int] = None
@@ -179,9 +177,10 @@ class ChatPromptProcessor:
         return output.strip()
 
     def has_censored_words(self, text: str) -> bool:
-        if self.censored_pat is None:
+        if len(self.censored_words) == 0:
             return False
-        return self.censored_pat.search(text) is not None
+        intersection = set(jieba.cut(text.lower())) & self.censored_words
+        return len(intersection) > 0
 

 class LockedIterator:
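
The diff replaces the regex-based censored-word check with jieba word segmentation plus a set intersection. A minimal standalone sketch of that logic is below; the sample word list and input text are illustrative assumptions, not taken from the repository.

```python
import jieba

# Hypothetical censored-word list, stored lowercased as in the patched __init__.
censored_words = {w.lower() for w in ["badword", "敏感词"]}

def has_censored_words(text: str) -> bool:
    """Return True if any jieba-segmented token of `text` is a censored word."""
    if not censored_words:
        return False
    # jieba.cut segments Chinese (and mixed Chinese/English) text into word
    # tokens, so a set intersection gives whole-word matching without the
    # hand-built regex the old implementation used.
    return bool(set(jieba.cut(text.lower())) & censored_words)

# Expected to match: jieba keeps plain ASCII word runs as single tokens,
# and lowercasing makes the check case-insensitive.
print(has_censored_words("this text mentions a badword here"))
```

One behavioral difference worth noting: the old regex matched censored words as substrings anywhere in the text, while the jieba-based check only matches whole segmented tokens, which avoids false positives inside longer words.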