2023-03-28 13:20:28 +00:00
import re
2023-03-28 12:25:36 +00:00
from threading import Lock
from typing import Any, Callable, Generator, List, Optional
2023-03-28 18:14:35 +00:00
import json
2023-03-28 12:25:36 +00:00
import torch
import torch.distributed as dist
import torch.nn as nn
from pydantic import BaseModel, Field
# The logits-processor classes moved between transformers releases: try the
# legacy module path first and fall back to the `transformers.generation`
# package when that module is not available.
try:
    from transformers.generation_logits_process import (
        LogitsProcessorList,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )
except ImportError:
    from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
def prepare_logits_processor(top_k: Optional[int] = None,
                             top_p: Optional[float] = None,
                             temperature: Optional[float] = None) -> LogitsProcessorList:
    """Build the list of logits warpers applied before sampling.

    A warper is added only when its parameter would actually alter the
    distribution, so the returned list may be empty.

    Args:
        top_k: Top-k filtering size; ``None`` or ``0`` disables the warper.
        top_p: Nucleus (top-p) threshold; ``None`` or any value ``>= 1.0``
            disables the warper.
        temperature: Sampling temperature; ``None`` or ``1.0`` disables the
            warper.

    Returns:
        A ``LogitsProcessorList`` holding, in this order, the temperature,
        top-k and top-p warpers that were enabled.
    """
    processor_list = LogitsProcessorList()
    if temperature is not None and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if top_k is not None and top_k != 0:
        processor_list.append(TopKLogitsWarper(top_k))
    if top_p is not None and top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    return processor_list
def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
    """Return ``True`` when no sequence is still being generated.

    Args:
        unfinished_sequences: Integer flags, one per sequence, where a
            non-zero entry means the sequence is still generating.

    Returns:
        ``True`` if every flag is zero. When ``torch.distributed`` is
        initialised with more than one process, the flags are summed across
        all ranks first, so the result is ``True`` only once every rank has
        finished.
    """
    if dist.is_initialized() and dist.get_world_size() > 1:
        # consider DP: all_reduce works in place, so reduce a copy and leave
        # the caller's tensor untouched.
        unfinished_sequences = unfinished_sequences.clone()
        dist.all_reduce(unfinished_sequences)
    # Convert the 0-dim tensor into a real bool to honour the annotation.
    return bool(unfinished_sequences.max() == 0)
def sample_streamingly(model: nn.Module,
                       input_ids: torch.Tensor,
                       max_generate_tokens: int,
                       early_stopping: bool = False,
                       eos_token_id: Optional[int] = None,
                       pad_token_id: Optional[int] = None,
                       top_k: Optional[int] = None,
                       top_p: Optional[float] = None,
                       temperature: Optional[float] = None,
                       prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
                       update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
                       **model_kwargs) -> Generator:
    """Sample tokens one step at a time, yielding each step's tokens.

    Args:
        model: Module called as ``model(**model_inputs)``; its output must
            expose ``outputs['logits']`` indexable as ``[:, -1, :]``.
        input_ids: Prompt token ids; the first dimension is the batch and new
            tokens are appended along the last dimension.
        max_generate_tokens: Maximum number of sampling steps.
        early_stopping: Stop as soon as every sequence has produced
            ``eos_token_id``.
        eos_token_id: Token id that marks a sequence as finished.
        pad_token_id: Token id emitted for sequences that already finished;
            required whenever ``eos_token_id`` is given.
        top_k: Forwarded to ``prepare_logits_processor``.
        top_p: Forwarded to ``prepare_logits_processor``.
        temperature: Forwarded to ``prepare_logits_processor``.
        prepare_inputs_fn: Optional hook building the model inputs from
            ``input_ids`` and ``model_kwargs``; defaults to
            ``{'input_ids': input_ids}``.
        update_model_kwargs_fn: Optional hook producing the next step's
            ``model_kwargs`` from the model outputs.
        **model_kwargs: Extra state passed to the two hooks.

    Yields:
        A 1-D tensor with one sampled token per sequence for each step.

    Raises:
        ValueError: During iteration, if ``eos_token_id`` is set but
            ``pad_token_id`` is ``None``.
    """
    logits_processor = prepare_logits_processor(top_k, top_p, temperature)
    # One flag per sequence: 1 while it is still generating, 0 once it hit EOS.
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

    for _ in range(max_generate_tokens):
        model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
            'input_ids': input_ids
        }
        outputs = model(**model_inputs)

        # Only the logits of the last position are needed to pick the next token.
        next_token_logits = outputs['logits'][:, -1, :]
        # pre-process distribution
        next_token_logits = logits_processor(input_ids, next_token_logits)
        # sample
        probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        yield next_tokens

        # update generated ids, model inputs for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if update_model_kwargs_fn is not None:
            model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs)

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        # stop when each sentence is finished if early_stopping=True
        if early_stopping and _is_sequence_finished(unfinished_sequences):
            break
def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
    """Prepare ``model_kwargs`` for the next generation step.

    Args:
        outputs: Model outputs of the current step; ``past_key_values`` is
            read from it when present.
        **model_kwargs: Keyword state of the current step.

    Returns:
        The updated ``model_kwargs``: ``past`` holds the new key/value cache
        (or ``None`` when the model returned none), and ``token_type_ids`` /
        ``attention_mask``, if present, are extended by one position.
    """
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs["past_key_values"]
    else:
        model_kwargs["past"] = None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

    # update attention mask: the freshly generated token is always attended to
    if "attention_mask" in model_kwargs:
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

    return model_kwargs
class Dialogue(BaseModel):
    # One exchange of a conversation: the user's instruction and the reply to it.

    # Instruction text; validation rejects an empty string (min_length=1).
    instruction: str = Field(
        min_length=1,
        example='Count up from 1 to 500.',
    )
    # Reply text; no length constraint is declared for it.
    response: str = Field(
        example='',
    )
def _format_dialogue(instruction: str, response: str = ''):
    """Render one dialogue turn using the instruction/response template.

    The block starts with two newlines, followed by an ``### Instruction:``
    section and a ``### Response:`` section. With the default empty
    ``response`` the text ends right after the response header.
    """
    rendered = f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}'
    return rendered
2023-03-28 13:20:28 +00:00
# Matches from the first "###" or "instruction:" (case-insensitive) through the
# end of the text, newlines included (re.S). postprocess_output() removes this
# match from the generated text.
STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
2023-03-28 12:25:36 +00:00
class ChatPromptProcessor:
    """Build chat prompts, clean up generated text and screen censored words.

    The tokenizer is expected to be callable on a string and return a mapping
    with an ``'input_ids'`` entry, and to offer ``decode`` (assumed to be a
    Hugging Face style tokenizer; confirm against the caller).
    """

    SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'

    def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: Optional[List[str]] = None):
        """Store the configuration and compile the censored-word pattern.

        Args:
            tokenizer: Tokenizer used to measure and truncate text.
            context: Fixed text placed at the start of every prompt.
            max_len: Token budget shared by the prompt and the new tokens.
            censored_words: Words to screen for, matched literally and
                case-insensitively. ``None`` or an empty list disables
                screening.
        """
        self.tokenizer = tokenizer
        self.context = context
        self.max_len = max_len
        # Drop empty entries: an empty alternative would match every text.
        words = [word for word in (censored_words or []) if word]
        if words:
            self.censored_pat = re.compile(f'({"|".join(map(re.escape, words))})', flags=re.I)
        else:
            self.censored_pat = None
        # These will be initialized after the first call of preprocess_prompt()
        self.context_len: Optional[int] = None
        self.dialogue_placeholder_len: Optional[int] = None

    def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str:
        """Assemble the prompt for the newest instruction in ``history``.

        The prompt is the context, then as many of the most recent earlier
        dialogues as fit the token budget, then the newest instruction with
        an empty response. If the newest instruction alone does not fit, it
        is truncated and no earlier dialogue is included.

        Args:
            history: Dialogues in chronological order; the last one carries
                the new instruction and an empty response. NOTE: the last
                element is removed from this list (``history.pop()``).
            max_new_tokens: Number of tokens reserved for generation.

        Returns:
            The prompt text.

        Raises:
            IndexError: If ``history`` is empty.
            AssertionError: If the last dialogue's response is not empty.
        """
        if self.context_len is None:
            self.context_len = len(self.tokenizer(self.context)['input_ids'])
        if self.dialogue_placeholder_len is None:
            self.dialogue_placeholder_len = len(
                self.tokenizer(_format_dialogue(''), add_special_tokens=False)['input_ids'])

        # the last dialogue must be in the prompt
        last_dialogue = history.pop()
        # the response of the last dialogue is empty
        assert last_dialogue.response == ''

        last_text = _format_dialogue(last_dialogue.instruction)
        last_len = len(self.tokenizer(last_text, add_special_tokens=False)['input_ids'])
        if last_len + max_new_tokens + self.context_len >= self.max_len:
            # to avoid truncating the placeholder, truncate the original
            # instruction and re-apply the template afterwards
            instruction_truncated = self.tokenizer(last_dialogue.instruction,
                                                   add_special_tokens=False,
                                                   truncation=True,
                                                   max_length=(self.max_len - max_new_tokens - self.context_len -
                                                               self.dialogue_placeholder_len))['input_ids']
            instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip()
            return self.context + _format_dialogue(instruction_truncated)

        # Tokens left for earlier dialogues once the context, the new tokens
        # and the newest instruction have all been accounted for.
        res_len = self.max_len - max_new_tokens - self.context_len - last_len

        # Walk from the most recent dialogue backwards and stop at the first
        # one that no longer fits.
        rows = []
        for dialogue in reversed(history):
            text = _format_dialogue(dialogue.instruction, dialogue.response)
            cur_len = len(self.tokenizer(text, add_special_tokens=False)['input_ids'])
            if res_len - cur_len < 0:
                break
            res_len -= cur_len
            rows.append(text)
        rows.reverse()    # back to chronological order
        return self.context + ''.join(rows) + last_text

    def postprocess_output(self, output: str) -> str:
        """Remove everything matched by ``STOP_PAT`` and strip whitespace."""
        output = STOP_PAT.sub('', output)
        return output.strip()

    def has_censored_words(self, text: str) -> bool:
        """Return whether ``text`` contains any configured censored word."""
        if self.censored_pat is None:
            return False
        return self.censored_pat.search(text) is not None
2023-03-28 12:25:36 +00:00
class LockedIterator:
    """Iterator wrapper that fetches each item while holding a lock."""

    def __init__(self, it, lock: Lock) -> None:
        # Turn any iterable into an iterator once, up front.
        self.it = iter(it)
        self.lock = lock

    def __iter__(self):
        """Return the wrapper itself; it is its own iterator."""
        return self

    def __next__(self):
        """Advance the wrapped iterator while ``self.lock`` is held."""
        with self.lock:
            item = next(self.it)
        return item
2023-03-28 18:14:35 +00:00
def load_json(path: str):
    """Read the JSON document stored at ``path`` and return the parsed value.

    The file is decoded as UTF-8 regardless of the platform's default
    encoding, so non-ASCII content loads the same way everywhere.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, encoding='utf-8') as f:
        return json.load(f)