import copy
import json
from threading import Lock
from typing import List, Optional

import jieba
import torch
from coati.dataset.conversation import default_conversation
from pydantic import BaseModel, Field


def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
    """Carry the decoding state of one generation step over to the next."""
    # reuse the key/value cache returned by the previous forward pass
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs["past_key_values"]
    else:
        model_kwargs["past"] = None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

    # extend the attention mask by one position for the newly generated token
    if "attention_mask" in model_kwargs:
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
        )

    return model_kwargs
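

def _example_greedy_decode(model, input_ids, attention_mask, max_new_tokens: int = 32):
    # Illustrative sketch, not part of the upstream file: one way
    # update_model_kwargs_fn can drive a greedy decoding loop. Assumes `model`
    # is a causal LM that returns an output exposing `logits` and
    # `past_key_values` and that accepts the `past` keyword for its cache; a
    # real loop (as in coati's generation utilities) would also shrink
    # `input_ids` to the last token once a cache is present.
    model_kwargs = {"attention_mask": attention_mask}
    for _ in range(max_new_tokens):
        outputs = model(input_ids, **model_kwargs)
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs)
    return input_ids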


class Dialogue(BaseModel):
    instruction: str = Field(min_length=1, example="Count up from 1 to 500.")
    response: str = Field(example="")
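

# Illustrative note, not part of the upstream file: pydantic enforces the
# field constraints at construction time, e.g.
#
#     Dialogue(instruction="Count up from 1 to 500.", response="")  # valid
#     Dialogue(instruction="", response="")  # raises pydantic.ValidationError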


class ChatPromptProcessor:
    SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt."

    def __init__(self, censored_words: Optional[List[str]] = None):
        # avoid a mutable default argument; fall back to an empty blocklist
        self.censored_words = {word.lower() for word in (censored_words or [])}
        self.conv = copy.deepcopy(default_conversation)

    def preprocess_prompt(self, history: List[Dialogue]) -> str:
        """Render the chat history into a single prompt string."""
        self.conv.clear()
        for dialogue in history:
            self.conv.append_message(self.conv.roles[0], dialogue.instruction)
            # `instruction` is validated to be non-empty, so this always holds:
            # the (possibly empty) response is appended so the conversation
            # template can close the round, leaving the final, yet-unanswered
            # round to end with the assistant prefix.
            if len(dialogue.instruction) > 0:
                self.conv.append_message(self.conv.roles[1], dialogue.response)
        return self.conv.get_prompt()

    def postprocess_output(self, output: str) -> str:
        return output.strip()

    def has_censored_words(self, text: str) -> bool:
        if len(self.censored_words) == 0:
            return False
        # jieba segments Chinese text into words (ASCII tokens pass through
        # mostly intact), so the blocklist check works for mixed-language text
        intersection = set(jieba.cut(text.lower())) & self.censored_words
        return len(intersection) > 0
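

def _example_chat_round():
    # Illustrative sketch, not part of the upstream file: build a prompt from a
    # short history and screen it against a (placeholder) blocklist.
    processor = ChatPromptProcessor(censored_words=["badword"])
    history = [
        Dialogue(instruction="Hi, who are you?", response="I am an assistant."),
        Dialogue(instruction="Count up from 1 to 500.", response=""),
    ]
    prompt = processor.preprocess_prompt(history)
    if processor.has_censored_words(prompt):
        return processor.SAFE_RESPONSE
    # the model's raw completion would then go through postprocess_output()
    return prompt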


class LockedIterator:
    """Wrap an iterable so multiple threads can consume it safely."""

    def __init__(self, it, lock: Lock) -> None:
        self.lock = lock
        self.it = iter(it)

    def __iter__(self):
        return self

    def __next__(self):
        # serialize access so two threads never advance the iterator at once
        with self.lock:
            return next(self.it)
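

def _example_shared_iterator():
    # Illustrative sketch, not part of the upstream file: share one iterator
    # between worker threads; the lock makes each next() call atomic, so every
    # item is consumed exactly once.
    from threading import Thread

    it = LockedIterator(range(100), Lock())
    consumed = []
    workers = [Thread(target=lambda: consumed.extend(it)) for _ in range(4)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    return consumed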


def load_json(path: str):
    with open(path) as f:
        return json.load(f)