2023-03-28 13:20:28 +00:00
|
|
|
import re
|
2023-03-28 12:25:36 +00:00
|
|
|
from threading import Lock
|
|
|
|
from typing import Any, Callable, Generator, List, Optional
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
import torch.nn as nn
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
try:
|
|
|
|
from transformers.generation_logits_process import (
|
|
|
|
LogitsProcessorList,
|
|
|
|
TemperatureLogitsWarper,
|
|
|
|
TopKLogitsWarper,
|
|
|
|
TopPLogitsWarper,
|
|
|
|
)
|
|
|
|
except ImportError:
|
|
|
|
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_logits_processor(top_k: Optional[int] = None,
                             top_p: Optional[float] = None,
                             temperature: Optional[float] = None) -> LogitsProcessorList:
    """Assemble the logits warpers used for sampling.

    A warper is added only when its argument would actually change the
    distribution, and warpers are applied in the order temperature, top-k, top-p.

    Args:
        top_k: keep only the ``top_k`` highest logits; ``None`` or ``0`` disables it.
        top_p: nucleus-sampling threshold; ``None`` or any value >= 1.0 disables it.
        temperature: logits divisor; ``None`` or exactly ``1.0`` disables it.

    Returns:
        A ``LogitsProcessorList`` (possibly empty) holding the enabled warpers.
    """
    warpers = LogitsProcessorList()

    # A temperature of 1.0 leaves the logits unchanged.
    if temperature is not None and temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(temperature))

    # top_k == 0 means "no top-k filtering".
    if top_k is not None and top_k != 0:
        warpers.append(TopKLogitsWarper(top_k))

    # top_p >= 1.0 would keep the whole distribution.
    if top_p is not None and top_p < 1.0:
        warpers.append(TopPLogitsWarper(top_p))

    return warpers
|
|
|
|
|
|
|
|
|
|
|
|
def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
|
|
|
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
|
|
|
# consider DP
|
|
|
|
unfinished_sequences = unfinished_sequences.clone()
|
|
|
|
dist.all_reduce(unfinished_sequences)
|
|
|
|
return unfinished_sequences.max() == 0
|
|
|
|
|
|
|
|
|
|
|
|
def sample_streamingly(model: nn.Module,
                       input_ids: torch.Tensor,
                       max_generate_tokens: int,
                       early_stopping: bool = False,
                       eos_token_id: Optional[int] = None,
                       pad_token_id: Optional[int] = None,
                       top_k: Optional[int] = None,
                       top_p: Optional[float] = None,
                       temperature: Optional[float] = None,
                       prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
                       update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
                       **model_kwargs) -> Generator:
    """Sample from ``model`` one token per step, yielding each step's tokens.

    Args:
        model: called as ``model(**model_inputs)``; its output must support
            ``outputs['logits']`` indexed as ``[:, -1, :]`` (last position).
        input_ids: prompt token ids; treated as (batch, seq), since new tokens are
            concatenated along the last dim.
        max_generate_tokens: maximum number of sampling steps.
        early_stopping: stop once every sequence has emitted ``eos_token_id``
            (on every rank when ``torch.distributed`` is initialized). This has
            no effect unless ``eos_token_id`` is given.
        eos_token_id: token that finishes a sequence; requires ``pad_token_id``.
        pad_token_id: token emitted for sequences that have already finished.
        top_k, top_p, temperature: forwarded to ``prepare_logits_processor``.
        prepare_inputs_fn: builds the model inputs from ``input_ids`` and
            ``model_kwargs``; defaults to ``{'input_ids': input_ids}``.
        update_model_kwargs_fn: returns the ``model_kwargs`` for the next step,
            given the model outputs and the current ``model_kwargs``.
        **model_kwargs: extra per-step state passed to the two callbacks above.

    Yields:
        A (batch,) tensor with the token chosen for each sequence at this step.

    Raises:
        ValueError: if ``eos_token_id`` is given without ``pad_token_id``. As
            with any generator, this surfaces on the first ``next()`` call,
            before the model is run.
    """
    # Validate once up front instead of on every step, and before any forward pass.
    if eos_token_id is not None and pad_token_id is None:
        raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")

    logits_processor = prepare_logits_processor(top_k, top_p, temperature)
    # 1 while a sequence is still generating, 0 once it has produced eos_token_id
    unfinished_sequences = input_ids.new_ones(input_ids.shape[0])

    for _ in range(max_generate_tokens):
        if prepare_inputs_fn is not None:
            model_inputs = prepare_inputs_fn(input_ids, **model_kwargs)
        else:
            model_inputs = {'input_ids': input_ids}
        outputs = model(**model_inputs)

        next_token_logits = outputs['logits'][:, -1, :]
        # pre-process distribution
        next_token_logits = logits_processor(input_ids, next_token_logits)
        # sample
        probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        yield next_tokens

        # update generated ids, model inputs for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if update_model_kwargs_fn is not None:
            model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs)

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        # stop when each sentence is finished if early_stopping=True
        if early_stopping and _is_sequence_finished(unfinished_sequences):
            break
|
|
|
|
|
|
|
|
|
|
|
|
def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
    """Prepare the keyword arguments for the next decoding step.

    - ``past`` becomes ``outputs["past_key_values"]`` when that key is present,
      otherwise ``None``.
    - ``token_type_ids``, if given, grows by one column that repeats its last value.
    - ``attention_mask``, if given, grows by one column of ones.

    Returns:
        The updated kwargs. This is always a fresh dict, since it is collected
        via ``**``, so the caller's mapping is not mutated.
    """
    model_kwargs["past"] = outputs["past_key_values"] if "past_key_values" in outputs else None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        type_ids = model_kwargs["token_type_ids"]
        repeated_last = type_ids[:, -1].unsqueeze(-1)
        model_kwargs["token_type_ids"] = torch.cat([type_ids, repeated_last], dim=-1)

    # update attention mask
    if "attention_mask" in model_kwargs:
        mask = model_kwargs["attention_mask"]
        new_column = mask.new_ones((mask.shape[0], 1))
        model_kwargs["attention_mask"] = torch.cat([mask, new_column], dim=-1)

    return model_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
class Dialogue(BaseModel):
    # One chat turn: a user instruction and the model's response to it.
    # NOTE(review): `example=` is passed to `Field` as an extra keyword; newer pydantic
    # releases deprecate extra Field kwargs in favour of `examples=` — confirm the installed version.

    # Must be non-empty (min_length=1).
    instruction: str = Field(min_length=1, example='Count up from 1 to 500.')
    # Required (no default). It is empty for the turn that still awaits a reply;
    # ChatPromptProcessor.preprocess_prompt asserts this for the last history entry.
    response: str = Field(example='')
|
|
|
|
|
|
|
|
|
|
|
|
def _format_dialogue(instruction: str, response: str = ''):
|
|
|
|
return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}'
|
|
|
|
|
|
|
|
|
2023-03-28 13:20:28 +00:00
|
|
|
# Matches from the first "###" or "instruction:" (any letter case) through the end
# of the text (DOTALL lets ".*" cross newlines). Substituting it away keeps only
# what precedes such a marker.
STOP_PAT = re.compile(r'(###|instruction:).*', flags=re.IGNORECASE | re.DOTALL)
|
|
|
|
|
|
|
|
|
2023-03-28 12:25:36 +00:00
|
|
|
class ChatPromptProcessor:
    """Build length-bounded chat prompts and clean up model outputs.

    A prompt is ``context``, then as many of the most recent past dialogues as
    fit, then the current instruction with an empty response. The prompt is
    kept within ``max_len - max_new_tokens`` tokens so generation has room left.
    """

    def __init__(self, tokenizer, context: str, max_len: int = 2048):
        """
        Args:
            tokenizer: called as ``tokenizer(text, ...)['input_ids']`` and
                ``tokenizer.decode(ids)``; presumably a Hugging Face tokenizer.
            context: text placed at the start of every prompt.
            max_len: token limit for the prompt plus the generated tokens.
        """
        self.tokenizer = tokenizer
        self.context = context
        self.max_len = max_len
        # These will be initialized after the first call of preprocess_prompt()
        self.context_len: Optional[int] = None
        self.dialogue_placeholder_len: Optional[int] = None

    def _num_tokens(self, text: str) -> int:
        """Return the number of token ids for ``text`` without special tokens."""
        return len(self.tokenizer(text, add_special_tokens=False)['input_ids'])

    def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str:
        """Return the prompt for the last dialogue in ``history``.

        Args:
            history: past dialogues followed by the current one, whose
                ``response`` must be empty. The list is not modified.
            max_new_tokens: number of tokens reserved for generation.

        Returns:
            ``context`` + the newest past dialogues that fit + the current
            instruction. If the current instruction alone does not fit, it is
            truncated and no past dialogue is included.

        Raises:
            IndexError: if ``history`` is empty.
            AssertionError: if the last dialogue already has a response.
        """
        if self.context_len is None:
            self.context_len = len(self.tokenizer(self.context)['input_ids'])
        if self.dialogue_placeholder_len is None:
            self.dialogue_placeholder_len = self._num_tokens(_format_dialogue(''))

        # The last dialogue must be in the prompt. Index rather than pop() so the
        # caller's history list is left untouched.
        last_dialogue = history[-1]
        past_dialogues = history[:-1]
        # the response of the last dialogue is empty
        # NOTE(review): assert is skipped under `python -O`.
        assert last_dialogue.response == ''

        # tokens available after the context, with room left for generation
        budget = self.max_len - max_new_tokens - self.context_len
        last_text = _format_dialogue(last_dialogue.instruction)
        last_len = self._num_tokens(last_text)

        if last_len >= budget:
            # To avoid truncating the placeholder, truncate the original instruction.
            # NOTE(review): if max_new_tokens leaves no room, this max_length is
            # non-positive; the tokenizer's behaviour in that case is unverified.
            instruction_ids = self.tokenizer(last_dialogue.instruction,
                                             add_special_tokens=False,
                                             truncation=True,
                                             max_length=budget - self.dialogue_placeholder_len)['input_ids']
            instruction_truncated = self.tokenizer.decode(instruction_ids).lstrip()
            return self.context + _format_dialogue(instruction_truncated)

        # What is left for past dialogues after reserving the current instruction;
        # without this reservation the final prompt could exceed the budget.
        res_len = budget - last_len
        rows = []
        # Walk from the newest past dialogue backwards and stop at the first that does not fit.
        for dialogue in reversed(past_dialogues):
            text = _format_dialogue(dialogue.instruction, dialogue.response)
            cur_len = self._num_tokens(text)
            if cur_len > res_len:
                break
            res_len -= cur_len
            rows.append(text)
        rows.reverse()  # restore chronological order
        return self.context + ''.join(rows) + last_text

    def postprocess_output(self, output: str) -> str:
        """Drop everything from the first STOP_PAT match onward, then strip whitespace."""
        output = STOP_PAT.sub('', output)
        return output.strip()
|
|
|
|
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
class LockedIterator:
    """Iterator wrapper that makes every ``next()`` call hold ``lock``.

    Several threads can then consume one shared underlying iterator without
    advancing it concurrently.
    """

    def __init__(self, it, lock: Lock) -> None:
        # Store the lock and normalise any iterable into an iterator.
        self.lock = lock
        self.it = iter(it)

    def __iter__(self):
        return self

    def __next__(self):
        # StopIteration propagates out of the `with`, which still releases the lock.
        with self.lock:
            item = next(self.it)
        return item
|