ColossalAI/applications/Chat/inference/utils.py

import json
import re
from threading import Lock
from typing import Any, Callable, Generator, List, Optional

import jieba
import torch
import torch.distributed as dist
import torch.nn as nn
from pydantic import BaseModel, Field

# older transformers releases expose the logits warpers in `generation_logits_process`;
# newer releases moved them to `transformers.generation`
try:
    from transformers.generation_logits_process import (
        LogitsProcessorList,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )
except ImportError:
    from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper


def prepare_logits_processor(
    top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
) -> LogitsProcessorList:
    processor_list = LogitsProcessorList()
    if temperature is not None and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if top_k is not None and top_k != 0:
        processor_list.append(TopKLogitsWarper(top_k))
    if top_p is not None and top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    return processor_list
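

# Illustrative sketch (not part of the original file): how the returned warper
# list is applied to one decoding step. The tensor shapes and sampling settings
# below are assumptions chosen only for demonstration.
def _example_prepare_logits_processor() -> torch.Tensor:
    processors = prepare_logits_processor(top_k=50, top_p=0.9, temperature=0.7)
    input_ids = torch.zeros(2, 4, dtype=torch.long)  # dummy (batch, seq_len) prompt ids
    next_token_logits = torch.randn(2, 100)  # dummy (batch, vocab_size) logits
    # each warper rescales or masks the logits before they are turned into probabilities
    return processors(input_ids, next_token_logits)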


def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
    if dist.is_initialized() and dist.get_world_size() > 1:
        # consider data parallelism: all ranks must agree that every sequence has finished
        unfinished_sequences = unfinished_sequences.clone()
        dist.all_reduce(unfinished_sequences)
    return unfinished_sequences.max() == 0


def sample_streamingly(
    model: nn.Module,
    input_ids: torch.Tensor,
    max_generate_tokens: int,
    early_stopping: bool = False,
    eos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
    update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
    **model_kwargs,
) -> Generator:
    logits_processor = prepare_logits_processor(top_k, top_p, temperature)
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

    for _ in range(max_generate_tokens):
        model_inputs = (
            prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
        )
        outputs = model(**model_inputs)
        next_token_logits = outputs["logits"][:, -1, :]
        # pre-process distribution
        next_token_logits = logits_processor(input_ids, next_token_logits)
        # sample
        probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        yield next_tokens

        # update generated ids, model inputs for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if update_model_kwargs_fn is not None:
            model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs)

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        # stop when each sentence is finished if early_stopping=True
        if early_stopping and _is_sequence_finished(unfinished_sequences):
            break
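

# Illustrative usage sketch (not part of the original file): streaming tokens from
# a Hugging Face causal LM. The "gpt2" checkpoint, budget of 16 tokens, and the
# sampling settings are assumptions chosen only for demonstration.
def _example_sample_streamingly() -> str:
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my name is", return_tensors="pt")["input_ids"]

    generated = []
    with torch.no_grad():
        for next_tokens in sample_streamingly(
            model,
            input_ids,
            max_generate_tokens=16,
            early_stopping=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,  # gpt2 has no pad token; reuse eos
            top_k=50,
            top_p=0.9,
            temperature=0.7,
        ):
            # each yielded tensor holds one newly sampled token per sequence in the batch
            generated.append(next_tokens)
    return tokenizer.decode(torch.cat([t[:, None] for t in generated], dim=-1)[0])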


def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs["past_key_values"]
    else:
        model_kwargs["past"] = None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

    # update attention mask
    if "attention_mask" in model_kwargs:
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
        )

    return model_kwargs
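

# Illustrative sketch (not part of the original file): update_model_kwargs_fn is
# intended to be passed to sample_streamingly() as `update_model_kwargs_fn`,
# together with a matching `prepare_inputs_fn`. The dummy tensors below are
# assumptions for demonstration only.
def _example_update_model_kwargs_fn() -> dict:
    fake_outputs = {"past_key_values": None}  # stand-in for a model forward output
    kwargs = {"attention_mask": torch.ones(2, 4, dtype=torch.long)}
    new_kwargs = update_model_kwargs_fn(fake_outputs, **kwargs)
    # the attention mask grows by one column for the token generated this step
    assert new_kwargs["attention_mask"].shape == (2, 5)
    return new_kwargs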


class Dialogue(BaseModel):
    instruction: str = Field(min_length=1, example="Count up from 1 to 500.")
    response: str = Field(example="")


def _format_dialogue(instruction: str, response: str = "") -> str:
    return f"\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}"


# matches everything from the first "###" or "instruction:" marker onward
# (case-insensitive, across newlines), so hallucinated follow-up turns can be
# stripped from the model output
STOP_PAT = re.compile(r"(###|instruction:).*", flags=(re.I | re.S))


class ChatPromptProcessor:
    SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt."

    def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []):
        self.tokenizer = tokenizer
        self.context = context
        self.max_len = max_len
        self.censored_words = set([word.lower() for word in censored_words])
        # These are initialized lazily on the first call to preprocess_prompt()
        self.context_len: Optional[int] = None
        self.dialogue_placeholder_len: Optional[int] = None

    def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str:
        if self.context_len is None:
            self.context_len = len(self.tokenizer(self.context)["input_ids"])
        if self.dialogue_placeholder_len is None:
            self.dialogue_placeholder_len = len(
                self.tokenizer(_format_dialogue(""), add_special_tokens=False)["input_ids"]
            )

        prompt = self.context
        # the last dialogue must be in the prompt
        last_dialogue = history.pop()
        # the response of the last dialogue is empty
        assert last_dialogue.response == ""
        if (
            len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)["input_ids"])
            + max_new_tokens
            + self.context_len
            >= self.max_len
        ):
            # to avoid truncating the placeholder, truncate the original instruction instead
            instruction_truncated = self.tokenizer(
                last_dialogue.instruction,
                add_special_tokens=False,
                truncation=True,
                max_length=(self.max_len - max_new_tokens - self.context_len - self.dialogue_placeholder_len),
            )["input_ids"]
            instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip()
            prompt += _format_dialogue(instruction_truncated)
            return prompt

        res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)["input_ids"])

        rows = []
        for dialogue in history[::-1]:
            text = _format_dialogue(dialogue.instruction, dialogue.response)
            cur_len = len(self.tokenizer(text, add_special_tokens=False)["input_ids"])
            if res_len - cur_len < 0:
                break
            res_len -= cur_len
            rows.insert(0, text)
        prompt += "".join(rows) + _format_dialogue(last_dialogue.instruction)
        return prompt

    def postprocess_output(self, output: str) -> str:
        output = STOP_PAT.sub("", output)
        return output.strip()

    def has_censored_words(self, text: str) -> bool:
        if len(self.censored_words) == 0:
            return False
        intersection = set(jieba.cut(text.lower())) & self.censored_words
        return len(intersection) > 0
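

# Illustrative usage sketch (not part of the original file): building a prompt
# from a short chat history. The "gpt2" tokenizer, context string, budget values
# and censored word are assumptions for demonstration only.
def _example_chat_prompt_processor() -> str:
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    context = "Below is an instruction that describes a task. Write a response that completes the request."
    processor = ChatPromptProcessor(tokenizer, context, max_len=512, censored_words=["badword"])

    history = [
        Dialogue(instruction="What is 2 + 2?", response="2 + 2 = 4."),
        Dialogue(instruction="Now double that.", response=""),  # pending turn: response must be empty
    ]
    # note: preprocess_prompt pops the pending turn off `history`
    prompt = processor.preprocess_prompt(history, max_new_tokens=64)
    # after generation, postprocess_output() strips any trailing "### ..." continuations
    return prompt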


class LockedIterator:
    """Wrap an iterable so that multiple threads can safely pull items from it."""

    def __init__(self, it, lock: Lock) -> None:
        self.lock = lock
        self.it = iter(it)

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)
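

# Illustrative sketch (not part of the original file): sharing one iterator across
# worker threads, each pulling items without racing on next(). The worker count
# and the range() source are assumptions for demonstration only.
def _example_locked_iterator() -> List[int]:
    from threading import Thread

    shared = LockedIterator(range(100), Lock())
    results: List[int] = []

    def worker() -> None:
        for item in shared:  # next() is serialized by the shared lock
            results.append(item)

    threads = [Thread(target=worker) for _ in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results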


def load_json(path: str):
    with open(path) as f:
        return json.load(f)