mirror of https://github.com/hpcaitech/ColossalAI
* add stream chat for chat version
* remove os.system clear
* modify function name

pull/5430/head

Camille Zhong, 9 months ago (committed by GitHub)
2 changed files with 302 additions and 0 deletions
@@ -0,0 +1,247 @@
from copy import deepcopy
from typing import Optional, List, Dict, Tuple, Callable, Any

import torch
from torch import nn

from transformers import PreTrainedTokenizer
from transformers.utils import logging
from transformers.generation.utils import GenerationConfig, LogitsProcessorList, StoppingCriteriaList


logger = logging.get_logger(__name__)


def get_prompt_template(
    input_query: str,
    history: List[Dict] = None,
    roles: list = ["", "Human", "Assistant"],
) -> str:
    """
    Generates a prompt template for chat models based on the input query and conversation history.

    Args:
        input_query (str): The user's current input query.
        history (List[Dict], optional): List of past conversation turns, each a dict with keys 'role' and 'message'.
        roles (list): The roles in the conversation, defaults to ["", "Human", "Assistant"].

    Returns:
        str: A formatted prompt that includes the conversation history and the input query.
    """
    prompt = ""
    if history is None:
        new_history = []
    else:
        new_history = deepcopy(history)

    new_history.append({"role": roles[1], "message": input_query.strip()})
    new_history.append({"role": roles[2], "message": None})

    for item in new_history:
        role = item.get("role")
        message = item.get("message")
        if role == roles[0]:
            # System turn: prepend the message without a role tag.
            prompt += f"<s>{message}\n\n"
        elif message:
            prompt += f"{role}: <s>{message}</s>"
        else:
            # Open the assistant turn that the model is expected to complete.
            prompt += f"{role}: <s>"
    return prompt

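# Illustrative usage (not part of the patch): with the default roles, a system turn stored under
# roles[0] plus one user query produce a prompt of the following shape (the history entry here
# is made up for demonstration):
#
#   prompt = get_prompt_template(
#       "How are you?",
#       history=[{"role": "", "message": "You are a helpful assistant."}],
#   )
#   # prompt == "<s>You are a helpful assistant.\n\nHuman: <s>How are you?</s>Assistant: <s>"
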

@torch.inference_mode()
def streaming_chat(
    model: Any,
    tokenizer: PreTrainedTokenizer,
    input_query: str,
    history: List[Dict] = None,
    roles: list = ["", "Human", "Assistant"],
    past_key_values: Tuple[Tuple[torch.FloatTensor, Any], Any] = None,
    temperature: float = 0.8,
    top_p: float = 0.95,
    top_k: int = 50,
    do_sample: bool = True,
    length_penalty: float = 1.2,
    max_new_tokens: int = 512,
    logits_processor: LogitsProcessorList = None,
    return_past_key_values: bool = False,
    **kwargs,
):
    """
    Streaming chat response generation with a given model and tokenizer.

    Args:
        model (Any): The language model used to generate responses.
        tokenizer (PreTrainedTokenizer): Tokenizer compatible with the model, used to encode inputs and decode responses.
        input_query (str): The current user input to respond to.
        history (List[Dict], optional): A list of past conversation turns, each a dict with keys 'role' and 'message'.
        roles (list): Roles involved in the conversation, defaults to ["", "Human", "Assistant"].
        past_key_values (Tuple[Tuple[torch.FloatTensor, Any], Any], optional): Past key values for incremental decoding.
        temperature (float): Temperature for token sampling, defaults to 0.8.
        top_p (float): Nucleus sampling probability threshold, defaults to 0.95.
        top_k (int): Top-k filtering threshold, defaults to 50.
        do_sample (bool): Whether to sample responses, defaults to True.
        length_penalty (float): Penalty applied to response length, defaults to 1.2.
        max_new_tokens (int): Maximum number of new tokens to generate, defaults to 512.
        logits_processor (LogitsProcessorList, optional): Custom logits processors, defaults to None.
        return_past_key_values (bool): Whether to return past key values for further incremental decoding, defaults to False.
        **kwargs: Additional keyword arguments for generation.

    Yields:
        Tuple[str, List[Dict], Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]]: A tuple containing the generated
        response, the updated history, and, if `return_past_key_values` is True, the updated past key values.

    Note:
        The tokenizer must use left padding; this generation path does not support right padding.
    """
    assert tokenizer.padding_side == "left", "Current generation only supports left padding."
    if history is None:
        history = []
    if logits_processor is None:
        logits_processor = LogitsProcessorList()

    generation_kwargs = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "do_sample": do_sample,
        "max_new_tokens": max_new_tokens,
        "length_penalty": length_penalty,
        "use_cache": True,
        **kwargs,
    }

    prompt_str = get_prompt_template(input_query, history=history, roles=roles)

    eos_token_id = [tokenizer.eos_token_id]
    inputs = tokenizer(prompt_str, return_tensors="pt").to(model.device)
    history.append({"role": roles[1], "message": input_query.strip()})
    history.append({"role": roles[2], "message": None})

    for outputs in stream_generate(
        model,
        **inputs,
        past_key_values=past_key_values,
        eos_token_id=eos_token_id,
        return_past_key_values=return_past_key_values,
        **generation_kwargs,
    ):
        if return_past_key_values:
            outputs, past_key_values = outputs

        # Drop the prompt tokens and the most recent token, then decode the partial response.
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
        response = tokenizer.decode(outputs)

        history[-1]["message"] = response.strip()
        if return_past_key_values:
            yield response, history, past_key_values
        else:
            yield response, history

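# Illustrative usage (not part of the patch): streaming the reply for a single query.
# `model` and `tokenizer` are placeholders loaded elsewhere; see the example script further
# below for the full chat loop with history and cached key values.
#
#   for response, history in streaming_chat(model, tokenizer, "Hello!"):
#       print(response)  # prints the partial response generated so far
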

@torch.inference_mode()
def stream_generate(
    model: Any,
    input_ids: torch.Tensor,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    return_past_key_values: bool = False,
    **kwargs,
):
    """
    Generates sequences of token ids with the given model and generation parameters.
    Adapted from https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py

    Args:
        model (Any): The model used to generate sequences of token ids.
        input_ids (torch.Tensor): The sequence used as a prompt for the generation or as model inputs to the encoder.
        generation_config (Optional[GenerationConfig]): The generation configuration used as the base parametrization for the generation call.
        logits_processor (Optional[LogitsProcessorList]): Custom logits processors that complement the default processors built from the arguments and the generation config.
        stopping_criteria (Optional[StoppingCriteriaList]): Custom stopping criteria that complement the default criteria built from the arguments and the generation config.
        prefix_allowed_tokens_fn (Optional[Callable[[int, torch.Tensor], List[int]]]): Function used to constrain token generation.
        return_past_key_values (bool): Whether to return past key values for further incremental decoding, defaults to False.
        **kwargs: Additional parameters for model generation.

    Yields:
        torch.Tensor: The generated token ids, updated after each generation step.
        Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]: The past key values, returned only if `return_past_key_values` is True.
    """
    input_ids_len = input_ids.size(1)

    if generation_config is None:
        generation_config = model.generation_config
    generation_config = deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)

    eos_token_id = generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

    if generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + input_ids_len

    if input_ids_len >= generation_config.max_length:
        input_ids_string = "decoder_input_ids" if model.config.is_encoder_decoder else "input_ids"
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids_len}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_new_tokens`."
        )
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    # Prepare distribution pre-processing samplers.
    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_len,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )

    # Prepare stopping criteria.
    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )

    logits_warper = model._get_logits_warper(generation_config)
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None

    while True:
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # Forward pass to get the next-token logits.
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        # NOTE: this is correct only in left padding mode.
        # Pre-process the distribution.
        next_token_logits = outputs.logits[:, -1, :]
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # Sample the next token.
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        # Update generated ids, model inputs, and length for the next step.
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
        )
        # Mark a sequence as finished once it emits any of the eos tokens.
        unfinished_sequences = unfinished_sequences.mul(
            next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
        )

        if return_past_key_values:
            yield input_ids, outputs.past_key_values
        else:
            yield input_ids
        # Stop when every sequence is finished or the stopping criteria are met.
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break

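# Illustrative usage (not part of the patch): driving stream_generate directly and decoding only
# the newly produced tokens. `model` and `tokenizer` are placeholders loaded elsewhere.
#
#   inputs = tokenizer("Human: <s>Hi</s>Assistant: <s>", return_tensors="pt").to(model.device)
#   prompt_len = inputs["input_ids"].shape[1]
#   for output_ids in stream_generate(model, **inputs, max_new_tokens=64):
#       # Each yielded tensor holds the prompt plus all tokens generated so far.
#       print(tokenizer.decode(output_ids[0][prompt_len:]))
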
@@ -0,0 +1,55 @@
import argparse

from colossal_llama2.utils.stream_chat_patch import streaming_chat
from transformers import AutoTokenizer, AutoModelForCausalLM

SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."


def main(args):
    model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda().eval()
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    past_key_values, history = None, []
    roles = ["", "Human", "Assistant"]

    history.append({"role": roles[0], "message": SYSTEM})

    while True:
        input_query = input(f"\n{roles[1]}: ")
        if input_query.strip() == "exit":
            break
        if input_query.strip() == "clear":
            # Reset the cached key values and the conversation history.
            past_key_values, history = None, []
            continue

        print(f"\n{roles[2]}: ", end="")
        gen_len = 0
        for response, history, past_key_values in streaming_chat(
            model,
            tokenizer,
            input_query,
            history=history,
            roles=roles,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            do_sample=args.do_sample,
            length_penalty=args.length_penalty,
            max_new_tokens=args.max_new_tokens,
            past_key_values=past_key_values,
            return_past_key_values=True,
        ):
            # Print only the newly generated part of the streamed response.
            output = response[gen_len:]
            print(output, end="", flush=True)
            gen_len = len(response)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default=None, help="path to the chat-version model")
    parser.add_argument("--tokenizer_path", type=str, default=None, help="path to the chat-version tokenizer")
    parser.add_argument("--temperature", type=float, default=0.8, help="sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.95, help="top-p (nucleus sampling) value")
    parser.add_argument("--top_k", type=int, default=50, help="top-k value")
    parser.add_argument("--do_sample", type=bool, default=True, help="whether to enable sampling")
    parser.add_argument("--length_penalty", type=float, default=1.2, help="length penalty")
    parser.add_argument("--max_new_tokens", type=int, default=512, help="maximum number of new tokens to generate")
    args = parser.parse_args()
    main(args)

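# Illustrative invocation (not part of the patch); the script name and paths are placeholders:
#
#   python stream_chat_example.py \
#       --model_path /path/to/chat/model \
#       --tokenizer_path /path/to/chat/tokenizer \
#       --temperature 0.8 --top_p 0.95 --top_k 50 --max_new_tokens 512
#
# Type "clear" at the prompt to reset the conversation and "exit" to quit.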