mirror of https://github.com/hpcaitech/ColossalAI
Browse Source
* add stream chat for chat version * remove os.system clear * modify function namepull/5430/head
Camille Zhong
9 months ago
committed by
GitHub
2 changed files with 302 additions and 0 deletions
@ -0,0 +1,247 @@
|
||||
from copy import deepcopy |
||||
from typing import Optional, List, Dict, Tuple, Callable, Any |
||||
|
||||
import torch |
||||
from torch import nn |
||||
|
||||
from transformers import PreTrainedTokenizer |
||||
from transformers.utils import logging |
||||
from transformers.generation.utils import GenerationConfig, LogitsProcessorList, StoppingCriteriaList |
||||
|
||||
logger = logging.get_logger(__name__) |
||||
|
||||
|
||||
def get_prompt_template( |
||||
input_query:str, |
||||
history:List[Dict]= None, |
||||
roles:list = ["", "Human", "Assistant"], |
||||
) -> str: |
||||
""" |
||||
Generates a prompt template for chat models based on input and history. |
||||
|
||||
Args: |
||||
input_query (str): User's current input query. |
||||
history (List[Dict], optional): List of past conversations, each a dict with 'role' and 'message'. |
||||
roles (list): Specifies the roles in the conversation, defaults to ["", "Human", "Assistant"]. |
||||
|
||||
Returns: |
||||
str: A formatted prompt including the input query and history. |
||||
""" |
||||
prompt = "" |
||||
if history is None: |
||||
new_history = [] |
||||
else: |
||||
new_history = deepcopy(history) |
||||
|
||||
new_history.append({"role": roles[1], "message": input_query.strip()}) |
||||
new_history.append({"role": roles[2], "message": None}) |
||||
|
||||
for _, item in enumerate(new_history): |
||||
role = item.get("role") |
||||
message = item.get("message") |
||||
if role == roles[0]: |
||||
prompt += f"<s>{message}\n\n" |
||||
else: |
||||
if message: |
||||
prompt += f"{role}: <s>{message}</s>" |
||||
else: |
||||
prompt += f"{role}: <s>" |
||||
return prompt |
||||
|
||||
@torch.inference_mode() |
||||
def streaming_chat( |
||||
model: Any, |
||||
tokenizer: PreTrainedTokenizer, |
||||
input_query: str, |
||||
history: List[Dict] = None, |
||||
roles: list = ["", "Human", "Assistant"], |
||||
past_key_values: Tuple[Tuple[torch.FloatTensor, Any], Any] = None, |
||||
temperature: float = 0.8, |
||||
top_p: float = 0.95, |
||||
top_k: int = 50, |
||||
do_sample: bool = True, |
||||
length_penalty: float = 1.2, |
||||
max_new_tokens: int = 512, |
||||
logits_processor: LogitsProcessorList = None, |
||||
return_past_key_values: bool = False, |
||||
**kwargs, |
||||
): |
||||
""" |
||||
Streaming chat responses generation with a given model and tokenizer. |
||||
|
||||
Args: |
||||
model (Any): The language model to generate responses. |
||||
tokenizer (PreTrainedTokenizer): Tokenizer compatible with the model, used for encoding inputs and decoding responses. |
||||
input_query (str): The current user input to respond to. |
||||
history (List[Dict], optional): A list of past conversations, where each conversation is a dictionary with keys 'role' and 'message'. |
||||
roles (list): Roles involved in the conversation, defaults to ["", "Human", "Assistant"]. |
||||
past_key_values (Tuple[Tuple[torch.FloatTensor, Any], Any], optional): Past key values for incremental decoding. |
||||
temperature (float): The temperature value for token sampling, defaults to 0.8. |
||||
top_p (float): Nucleus sampling probability threshold, defaults to 0.95. |
||||
top_k (int): Top-K filtering threshold, defaults to 50. |
||||
do_sample (bool): Whether to sample responses, defaults to True. |
||||
length_penalty (float): Penalty for response length, defaults to 1.2. |
||||
max_new_tokens (int): Maximum number of new tokens to generate, defaults to 512. |
||||
logits_processor (LogitsProcessorList, optional): Custom logits processors, defaults to None. |
||||
return_past_key_values (bool): Whether to return past key values for further incremental decoding, defaults to False. |
||||
**kwargs: Additional keyword arguments for generation. |
||||
|
||||
Yields: |
||||
Tuple[str, List[Dict], Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]]: A tuple containing the generated response, updated history, and |
||||
optionally the updated past key values if `return_past_key_values` is True. |
||||
|
||||
Ensures padding is on the left side for the tokenizer. |
||||
""" |
||||
assert tokenizer.padding_side == "left", "Current generation only supports left padding." |
||||
if history is None: |
||||
history = [] |
||||
if logits_processor is None: |
||||
logits_processor = LogitsProcessorList() |
||||
|
||||
generation_kwargs = { |
||||
'temperature': temperature, |
||||
'top_p': top_p, |
||||
'top_k': top_k, |
||||
'do_sample': do_sample, |
||||
'max_new_tokens': max_new_tokens, |
||||
'length_penalty': length_penalty, |
||||
'use_cache': True, |
||||
**kwargs |
||||
} |
||||
|
||||
prompt_str = get_prompt_template(input_query, history=history, roles=roles) |
||||
|
||||
eos_token_id = [tokenizer.eos_token_id] |
||||
inputs = tokenizer(prompt_str, return_tensors="pt").to(model.device) |
||||
history.append({"role": roles[1], "message": input_query.strip()}) |
||||
history.append({"role": roles[2], "message": None}) |
||||
|
||||
for outputs in stream_generate(model, **inputs, past_key_values=past_key_values, |
||||
eos_token_id=eos_token_id, return_past_key_values=return_past_key_values, |
||||
**generation_kwargs): |
||||
if return_past_key_values: |
||||
outputs, past_key_values = outputs |
||||
|
||||
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] |
||||
response = tokenizer.decode(outputs) |
||||
|
||||
history[-1]["message"] = response.strip() |
||||
if return_past_key_values: |
||||
yield response, history, past_key_values |
||||
else: |
||||
yield response, history |
||||
|
||||
|
||||
@torch.inference_mode() |
||||
def stream_generate( |
||||
model: Any, |
||||
input_ids: torch.Tensor, |
||||
generation_config: Optional[GenerationConfig] = None, |
||||
logits_processor: Optional[LogitsProcessorList] = None, |
||||
stopping_criteria: Optional[StoppingCriteriaList] = None, |
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, |
||||
return_past_key_values: bool = False, |
||||
**kwargs, |
||||
): |
||||
""" |
||||
Generates sequences of token ids using the specified model and generation parameters. |
||||
Adapted from https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py |
||||
|
||||
Args: |
||||
model (Any): The model used for generating sequences of token ids. |
||||
input_ids (torch.Tensor): The sequence used as a prompt for the generation or as model inputs to the encoder. |
||||
generation_config (Optional[GenerationConfig]): The generation configuration to be used as base parametrization for the generation call. |
||||
logits_processor (Optional[LogitsProcessorList]): Custom logits processors that complement the default logits processors built from arguments |
||||
and generation config. |
||||
stopping_criteria (Optional[StoppingCriteriaList]): Custom stopping criteria that complement the default stopping criteria built from arguments |
||||
and a generation config. |
||||
prefix_allowed_tokens_fn (Optional[Callable[[int, torch.Tensor], List[int]]]): Function to constrain token generation. |
||||
return_past_key_values (bool): Whether to return past key values for further incremental decoding, defaults to False. |
||||
**kwargs: Additional parameters for model generation. |
||||
|
||||
Yields: |
||||
torch.Tensor: The generated token IDs, updated after each generation step. |
||||
Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]: The past key values, returned if `return_past_key_values` is True, defaults to False. |
||||
""" |
||||
input_ids_len = input_ids.size(1) |
||||
|
||||
if generation_config is None: |
||||
generation_config = model.generation_config |
||||
generation_config = deepcopy(generation_config) |
||||
model_kwargs = generation_config.update(**kwargs) |
||||
|
||||
eos_token_id = generation_config.eos_token_id |
||||
if isinstance(eos_token_id, int): |
||||
eos_token_id = [eos_token_id] |
||||
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None |
||||
|
||||
if generation_config.max_new_tokens is not None: |
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_len |
||||
|
||||
if input_ids_len >= generation_config.max_length: |
||||
input_ids_string = "decoder_input_ids" if model.config.is_encoder_decoder else "input_ids" |
||||
logger.warning( |
||||
f"Input length of {input_ids_string} is {input_ids_len}, but `max_length` is set to" |
||||
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" |
||||
" increasing `max_new_tokens`." |
||||
) |
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
||||
|
||||
# prepare distribution pre_processing samplers |
||||
logits_processor = model._get_logits_processor( |
||||
generation_config=generation_config, |
||||
input_ids_seq_length=input_ids_len, |
||||
encoder_input_ids=input_ids, |
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, |
||||
logits_processor=logits_processor, |
||||
) |
||||
|
||||
# prepare stopping criteria |
||||
stopping_criteria = model._get_stopping_criteria( |
||||
generation_config=generation_config, stopping_criteria=stopping_criteria |
||||
) |
||||
|
||||
logits_warper = model._get_logits_warper(generation_config) |
||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) |
||||
scores = None |
||||
|
||||
while True: |
||||
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs) |
||||
# forward pass to get next token |
||||
outputs = model( |
||||
**model_inputs, |
||||
return_dict=True, |
||||
output_attentions=False, |
||||
output_hidden_states=False, |
||||
) |
||||
|
||||
# NOTE: this is correct only in left padding mode |
||||
# pre-process distribution |
||||
next_token_logits = outputs.logits[:, -1, :] |
||||
next_token_scores = logits_processor(input_ids, next_token_logits) |
||||
next_token_scores = logits_warper(input_ids, next_token_scores) |
||||
|
||||
# sample |
||||
probs = nn.functional.softmax(next_token_scores, dim=-1) |
||||
if generation_config.do_sample: |
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) |
||||
else: |
||||
next_tokens = torch.argmax(probs, dim=-1) |
||||
|
||||
# update generated ids, model inputs, and length for next step |
||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) |
||||
model_kwargs = model._update_model_kwargs_for_generation( |
||||
outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder |
||||
) |
||||
unfinished_sequences = unfinished_sequences.mul( |
||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) |
||||
) |
||||
|
||||
if return_past_key_values: |
||||
yield input_ids, outputs.past_key_values |
||||
else: |
||||
yield input_ids |
||||
# stop when each sentence is finished, or if exceed the maximum length |
||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): |
||||
break |
@ -0,0 +1,55 @@
|
||||
import os |
||||
import argparse |
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM |
||||
from colossal_llama2.utils.stream_chat_patch import streaming_chat |
||||
|
||||
SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." |
||||
|
||||
def main(args): |
||||
model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda().eval() |
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) |
||||
|
||||
past_key_values, history = None, [] |
||||
roles = ["", "Human", "Assistant"] |
||||
|
||||
history = [] |
||||
history.append({"role": roles[0], "message": SYSTEM}) |
||||
|
||||
while True: |
||||
input_query = input(f"\n{roles[1]}: ") |
||||
if input_query.strip() == "exit": |
||||
break |
||||
if input_query.strip() == "clear": |
||||
past_key_values, history = None, [] |
||||
continue |
||||
|
||||
print(f"\n{roles[2]}: ", end="") |
||||
gen_len = 0 |
||||
for response, history, past_key_values in streaming_chat( |
||||
model, tokenizer, input_query, history=history, roles=roles, |
||||
temperature = args.temperature, |
||||
top_p = args.top_p, |
||||
top_k = args.top_k, |
||||
do_sample = args.do_sample, |
||||
length_penalty = args.length_penalty, |
||||
max_new_tokens = args.max_new_tokens, |
||||
past_key_values=past_key_values, |
||||
return_past_key_values=True): |
||||
|
||||
output = response[gen_len:] |
||||
print(output, end="", flush=True) |
||||
gen_len = len(response) |
||||
|
||||
if __name__ == "__main__": |
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument('--model_path', type=str, default=None, help="path to chat version model") |
||||
parser.add_argument('--tokenizer_path', type=str, default=None, help="path to chat version tokenizer") |
||||
parser.add_argument('--temperature', type=float, default=0.8, help="set temperature") |
||||
parser.add_argument('--top_p', type=float, default=0.95, help="set top p value") |
||||
parser.add_argument('--top_k', type=int, default=50, help="set top k value") |
||||
parser.add_argument('--do_sample', type=bool, default=True, help="whether turn on do_sample or not") |
||||
parser.add_argument('--length_penalty', type=float, default=1.2, help="set length penalty") |
||||
parser.add_argument('--max_new_tokens', type=int, default=512, help="set max new tokens") |
||||
args = parser.parse_args() |
||||
main(args) |
Loading…
Reference in new issue