ColossalAI/applications/ColossalChat/examples/inference/web_chatbot/utils.py

79 lines
2.4 KiB
Python
Executable File

import copy
import json
from threading import Lock
from typing import List, Optional

import jieba
import torch
from coati.dataset.conversation import default_conversation
from pydantic import BaseModel, Field
def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
    """Refresh the model keyword arguments from the latest model outputs.

    Presumably called between decoding steps of a generation loop (confirm
    against the caller).

    Args:
        outputs: Mapping of model outputs; only ``"past_key_values"`` is read.
        **model_kwargs: Keyword arguments to update. ``"token_type_ids"`` and
            ``"attention_mask"``, when present, are indexed as 2-D tensors and
            each grows by one column along the last dimension.

    Returns:
        The updated ``model_kwargs`` dict, always containing a ``"past"`` entry.
    """
    # Forward the cached key/values when the model returned them, else reset to None.
    model_kwargs["past"] = outputs["past_key_values"] if "past_key_values" in outputs else None

    # Extend token_type_ids by repeating its last column.
    if "token_type_ids" in model_kwargs:
        type_ids = model_kwargs["token_type_ids"]
        last_column = type_ids[:, -1].unsqueeze(-1)
        model_kwargs["token_type_ids"] = torch.cat([type_ids, last_column], dim=-1)

    # Extend the attention mask with one column of ones (same dtype/device as the mask).
    if "attention_mask" in model_kwargs:
        mask = model_kwargs["attention_mask"]
        ones_column = mask.new_ones((mask.shape[0], 1))
        model_kwargs["attention_mask"] = torch.cat([mask, ones_column], dim=-1)

    return model_kwargs
class Dialogue(BaseModel):
    # Pydantic model describing one round of a conversation.
    # ChatPromptProcessor.preprocess_prompt appends `instruction` under the
    # first conversation role and `response` under the second.

    # Text for the first role; validation requires at least one character.
    instruction: str = Field(min_length=1, example="Count up from 1 to 500.")
    # Text for the second role; the schema example is the empty string.
    response: str = Field(example="")
class ChatPromptProcessor:
    """Build chat prompts from dialogue history and screen text for censored words.

    Each instance owns a private deep copy of ``default_conversation`` which is
    cleared and refilled on every ``preprocess_prompt`` call.
    """

    # Fixed message for censored content; not used inside this class
    # (presumably returned by callers when ``has_censored_words`` is true).
    SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt."

    def __init__(self, censored_words: Optional[List[str]] = None):
        """Store the lower-cased censored words and copy the conversation template.

        Args:
            censored_words: Words to screen for. ``None`` (the default) or an
                empty list disables screening.
        """
        # ``None`` replaces the former mutable ``[]`` default, which would be
        # shared between calls.
        self.censored_words = {word.lower() for word in censored_words or ()}
        self.conv = copy.deepcopy(default_conversation)

    # ``Dialogue`` is quoted so the annotation is a forward reference and is
    # not evaluated when the class is defined.
    def preprocess_prompt(self, history: List["Dialogue"]) -> str:
        """Rebuild the conversation from ``history`` and return its prompt.

        Args:
            history: Rounds of dialogue, oldest first.

        Returns:
            The string produced by the conversation's ``get_prompt()``.
        """
        self.conv.clear()
        for turn in history:
            self.conv.append_message(self.conv.roles[0], turn.instruction)
            # The response is appended whenever the instruction is non-empty,
            # even if the response itself is "".
            # NOTE(review): this may be what leaves the second role's turn open
            # for generation -- confirm before changing the guard.
            if turn.instruction:
                self.conv.append_message(self.conv.roles[1], turn.response)
        return self.conv.get_prompt()

    def postprocess_output(self, output: str) -> str:
        """Return ``output`` with leading and trailing whitespace removed."""
        return output.strip()

    def has_censored_words(self, text: str) -> bool:
        """Report whether any token of ``text`` is a censored word.

        The text is lower-cased and tokenized with ``jieba.cut``; each token is
        compared against the stored lower-cased censored words.

        Args:
            text: Text to screen.

        Returns:
            ``True`` if at least one token is censored, ``False`` otherwise
            (always ``False`` when no censored words are configured).
        """
        if not self.censored_words:
            return False
        # isdisjoint stops at the first matching token instead of building
        # the full token set and intersection.
        return not self.censored_words.isdisjoint(jieba.cut(text.lower()))
class LockedIterator:
    """Iterator wrapper that holds a lock around each ``next`` call on the wrapped iterator."""

    def __init__(self, it, lock: Lock) -> None:
        """Wrap ``iter(it)`` and remember ``lock`` for use in ``__next__``."""
        self.it = iter(it)
        self.lock = lock

    def __iter__(self):
        """Return this object; it is its own iterator."""
        return self

    def __next__(self):
        """Advance the wrapped iterator while holding the lock and return the item.

        ``StopIteration`` from the wrapped iterator propagates unchanged; the
        lock is released either way by the ``with`` block.
        """
        with self.lock:
            item = next(self.it)
        return item
def load_json(path: str):
    """Load and return the JSON document stored at ``path``.

    The file is read as UTF-8 rather than the platform's default encoding, so
    non-ASCII content parses the same way on every platform.

    Args:
        path: Filesystem path of the JSON file.

    Returns:
        The decoded JSON value (whatever ``json.load`` produces for the file).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)