diff --git a/chat/web_demo.py b/chat/web_demo.py
index e2ce03a..55af03f 100644
--- a/chat/web_demo.py
+++ b/chat/web_demo.py
@@ -7,17 +7,148 @@ Please refer to these links below for more information:
     3. transformers: https://github.com/huggingface/transformers
 """
-from dataclasses import asdict
+import copy
+import warnings
+from dataclasses import asdict, dataclass
+from typing import Callable, List, Optional
 
 import streamlit as st
 import torch
-from tools.transformers.interface import GenerationConfig, generate_interactive
+from torch import nn
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
 
 
+@dataclass
+class GenerationConfig:
+    # this config is used for chat to provide more diversity
+    max_length: int = 32768
+    top_p: float = 0.8
+    temperature: float = 0.8
+    do_sample: bool = True
+    repetition_penalty: float = 1.005
+
+
+@torch.inference_mode()
+def generate_interactive(
+    model,
+    tokenizer,
+    prompt,
+    generation_config: Optional[GenerationConfig] = None,
+    logits_processor: Optional[LogitsProcessorList] = None,
+    stopping_criteria: Optional[StoppingCriteriaList] = None,
+    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+    additional_eos_token_id: Optional[int] = None,
+    **kwargs,
+):
+    inputs = tokenizer([prompt], padding=True, return_tensors="pt")
+    input_length = len(inputs["input_ids"][0])
+    for k, v in inputs.items():
+        inputs[k] = v.cuda()
+    input_ids = inputs["input_ids"]
+    batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]  # noqa: F841  # pylint: disable=W0612
+    if generation_config is None:
+        generation_config = model.generation_config
+    generation_config = copy.deepcopy(generation_config)
+    model_kwargs = generation_config.update(**kwargs)
+    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
+        generation_config.bos_token_id,
+        generation_config.eos_token_id,
+    )
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    if additional_eos_token_id is not None:
+        eos_token_id.append(additional_eos_token_id)
+    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+    if has_default_max_length and generation_config.max_new_tokens is None:
+        warnings.warn(
+            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+            "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+            " recommend using `max_new_tokens` to control the maximum length of the generation.",
+            UserWarning,
+        )
+    elif generation_config.max_new_tokens is not None:
+        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+        if not has_default_max_length:
+            logger.warn(  # pylint: disable=W4902
+                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                "Please refer to the documentation for more information. "
+                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+                UserWarning,
+            )
+
+    if input_ids_seq_length >= generation_config.max_length:
+        input_ids_string = "input_ids"
+        logger.warning(
+            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+            " increasing `max_new_tokens`."
+        )
+
+    # 2. Set generation parameters if not already defined
+    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+    logits_processor = model._get_logits_processor(
+        generation_config=generation_config,
+        input_ids_seq_length=input_ids_seq_length,
+        encoder_input_ids=input_ids,
+        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+        logits_processor=logits_processor,
+    )
+
+    stopping_criteria = model._get_stopping_criteria(
+        generation_config=generation_config, stopping_criteria=stopping_criteria
+    )
+    logits_warper = model._get_logits_warper(generation_config)
+
+    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+    scores = None
+    while True:
+        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
+        # forward pass to get next token
+        outputs = model(
+            **model_inputs,
+            return_dict=True,
+            output_attentions=False,
+            output_hidden_states=False,
+        )
+
+        next_token_logits = outputs.logits[:, -1, :]
+
+        # pre-process distribution
+        next_token_scores = logits_processor(input_ids, next_token_logits)
+        next_token_scores = logits_warper(input_ids, next_token_scores)
+
+        # sample
+        probs = nn.functional.softmax(next_token_scores, dim=-1)
+        if generation_config.do_sample:
+            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+        else:
+            next_tokens = torch.argmax(probs, dim=-1)
+
+        # update generated ids, model inputs, and length for next step
+        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+        model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
+        unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())
+
+        output_token_ids = input_ids[0].cpu().tolist()
+        output_token_ids = output_token_ids[input_length:]
+        for each_eos_token_id in eos_token_id:
+            if output_token_ids[-1] == each_eos_token_id:
+                output_token_ids = output_token_ids[:-1]
+        response = tokenizer.decode(output_token_ids)
+
+        yield response
+        # stop when each sentence is finished, or if we exceed the maximum length
+        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+            break
+
+
 def on_btn_click():
     del st.session_state.messages
 
 
@@ -35,7 +166,7 @@ def load_model():
 
 def prepare_generation_config():
     with st.sidebar:
-        max_length = st.slider("Max Length", min_value=32, max_value=2048, value=2048)
+        max_length = st.slider("Max Length", min_value=8, max_value=32768, value=32768)
         top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
         temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
         st.button("Clear Chat History", on_click=on_btn_click)
@@ -52,17 +183,21 @@ cur_query_prompt = "[UNUSED_TOKEN_146]user\n{user}[UNUSED_TOKEN_145]\n[UNUSED_TO
 
 
 def combine_history(prompt):
     messages = st.session_state.messages
-    total_prompt = ""
+    meta_instruction = (
+        "You are InternLM (书生·浦语), a helpful, honest, and harmless AI assistant developed by Shanghai "
+        "AI Laboratory (上海人工智能实验室)."
+    )
+    total_prompt = f"[UNUSED_TOKEN_146]system\n{meta_instruction}[UNUSED_TOKEN_145]\n"
     for message in messages:
         cur_content = message["content"]
         if message["role"] == "user":
-            cur_prompt = user_prompt.replace("{user}", cur_content)
+            cur_prompt = user_prompt.format(user=cur_content)
         elif message["role"] == "robot":
-            cur_prompt = robot_prompt.replace("{robot}", cur_content)
+            cur_prompt = robot_prompt.format(robot=cur_content)
         else:
             raise RuntimeError
         total_prompt += cur_prompt
-    total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt)
+    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
     return total_prompt
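Usage note (not part of the patch itself): `generate_interactive` is a generator that yields the partially decoded reply after every sampled token, so the Streamlit side only needs to re-render a placeholder on each yield. Below is a minimal sketch of how the pieces added in this diff could be wired together inside `chat/web_demo.py`; the helper name `stream_reply`, the inline construction of `GenerationConfig`, and looking up `[UNUSED_TOKEN_145]` as an additional EOS id via `convert_tokens_to_ids` are illustrative assumptions, not code taken from this diff.

# Illustrative sketch only -- assumes it lives in chat/web_demo.py next to the
# functions added above (combine_history, GenerationConfig, generate_interactive).
from dataclasses import asdict

import streamlit as st


def stream_reply(model, tokenizer, user_input: str) -> str:
    """Stream one assistant turn into a Streamlit placeholder (hypothetical helper)."""
    # Build the full prompt: system meta instruction + chat history + current query.
    real_prompt = combine_history(user_input)
    # Inline config for the sketch; the demo's sidebar sliders would normally supply these.
    gen_cfg = GenerationConfig(max_length=32768, top_p=0.8, temperature=0.7)
    placeholder = st.empty()
    response = ""
    for response in generate_interactive(
        model=model,
        tokenizer=tokenizer,
        prompt=real_prompt,
        # Assumption: [UNUSED_TOKEN_145] closes an assistant turn, so treat it as an
        # extra end-of-sequence id in addition to the model's own eos_token_id.
        additional_eos_token_id=tokenizer.convert_tokens_to_ids("[UNUSED_TOKEN_145]"),
        **asdict(gen_cfg),
    ):
        placeholder.markdown(response + "▌")  # re-render on every yielded partial reply
    placeholder.markdown(response)
    return response

The extra keyword arguments passed via `asdict(gen_cfg)` are forwarded by `generate_interactive` to `generation_config.update(**kwargs)`, so the sliders exposed in `prepare_generation_config` map directly onto the sampling parameters used in the decoding loop.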