[Fix]: Update web demo to be self-contained (#624)

pull/635/head
Wenwei Zhang 2024-01-19 11:24:22 +08:00 committed by GitHub
parent 519c7934c4
commit f08a18b9b7
1 changed file with 142 additions and 7 deletions


@@ -7,17 +7,148 @@ Please refer to these links below for more information:
 3. transformers: https://github.com/huggingface/transformers
 """
-from dataclasses import asdict
+import copy
+import warnings
+from dataclasses import asdict, dataclass
+from typing import Callable, List, Optional
 
 import streamlit as st
 import torch
-from tools.transformers.interface import GenerationConfig, generate_interactive
+from torch import nn
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
+
+
+@dataclass
+class GenerationConfig:
+    # this config is used for chat to provide more diversity
+    max_length: int = 32768
+    top_p: float = 0.8
+    temperature: float = 0.8
+    do_sample: bool = True
+    repetition_penalty: float = 1.005
+
+
+@torch.inference_mode()
+def generate_interactive(
+    model,
+    tokenizer,
+    prompt,
+    generation_config: Optional[GenerationConfig] = None,
+    logits_processor: Optional[LogitsProcessorList] = None,
+    stopping_criteria: Optional[StoppingCriteriaList] = None,
+    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+    additional_eos_token_id: Optional[int] = None,
+    **kwargs,
+):
+    inputs = tokenizer([prompt], padding=True, return_tensors="pt")
+    input_length = len(inputs["input_ids"][0])
+    for k, v in inputs.items():
+        inputs[k] = v.cuda()
+    input_ids = inputs["input_ids"]
+    batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]  # noqa: F841  # pylint: disable=W0612
+    if generation_config is None:
+        generation_config = model.generation_config
+    generation_config = copy.deepcopy(generation_config)
+    model_kwargs = generation_config.update(**kwargs)
+    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
+        generation_config.bos_token_id,
+        generation_config.eos_token_id,
+    )
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    if additional_eos_token_id is not None:
+        eos_token_id.append(additional_eos_token_id)
+    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+    if has_default_max_length and generation_config.max_new_tokens is None:
+        warnings.warn(
+            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+            "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+            " recommend using `max_new_tokens` to control the maximum length of the generation.",
+            UserWarning,
+        )
+    elif generation_config.max_new_tokens is not None:
+        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+        if not has_default_max_length:
+            logger.warn(  # pylint: disable=W4902
+                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                "Please refer to the documentation for more information. "
+                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+                UserWarning,
+            )
+
+    if input_ids_seq_length >= generation_config.max_length:
+        input_ids_string = "input_ids"
+        logger.warning(
+            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+            " increasing `max_new_tokens`."
+        )
+
+    # 2. Set generation parameters if not already defined
+    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+    logits_processor = model._get_logits_processor(
+        generation_config=generation_config,
+        input_ids_seq_length=input_ids_seq_length,
+        encoder_input_ids=input_ids,
+        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+        logits_processor=logits_processor,
+    )
+
+    stopping_criteria = model._get_stopping_criteria(
+        generation_config=generation_config, stopping_criteria=stopping_criteria
+    )
+    logits_warper = model._get_logits_warper(generation_config)
+
+    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+    scores = None
+    while True:
+        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
+        # forward pass to get next token
+        outputs = model(
+            **model_inputs,
+            return_dict=True,
+            output_attentions=False,
+            output_hidden_states=False,
+        )
+
+        next_token_logits = outputs.logits[:, -1, :]
+
+        # pre-process distribution
+        next_token_scores = logits_processor(input_ids, next_token_logits)
+        next_token_scores = logits_warper(input_ids, next_token_scores)
+
+        # sample
+        probs = nn.functional.softmax(next_token_scores, dim=-1)
+        if generation_config.do_sample:
+            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+        else:
+            next_tokens = torch.argmax(probs, dim=-1)
+
+        # update generated ids, model inputs, and length for next step
+        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+        model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
+        unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())
+        output_token_ids = input_ids[0].cpu().tolist()
+        output_token_ids = output_token_ids[input_length:]
+        for each_eos_token_id in eos_token_id:
+            if output_token_ids[-1] == each_eos_token_id:
+                output_token_ids = output_token_ids[:-1]
+        response = tokenizer.decode(output_token_ids)
+
+        yield response
+        # stop when each sentence is finished, or if we exceed the maximum length
+        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+            break
+
+
 def on_btn_click():
     del st.session_state.messages
@@ -35,7 +166,7 @@ def load_model():
 def prepare_generation_config():
     with st.sidebar:
-        max_length = st.slider("Max Length", min_value=32, max_value=2048, value=2048)
+        max_length = st.slider("Max Length", min_value=8, max_value=32768, value=32768)
         top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
         temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
         st.button("Clear Chat History", on_click=on_btn_click)
@@ -52,17 +183,21 @@ cur_query_prompt = "[UNUSED_TOKEN_146]user\n{user}[UNUSED_TOKEN_145]\n[UNUSED_TO
 def combine_history(prompt):
     messages = st.session_state.messages
-    total_prompt = ""
+    meta_instruction = (
+        "You are InternLM (书生·浦语), a helpful, honest, and harmless AI assistant developed by Shanghai "
+        "AI Laboratory (上海人工智能实验室)."
+    )
+    total_prompt = f"<s>[UNUSED_TOKEN_146]system\n{meta_instruction}[UNUSED_TOKEN_145]\n"
     for message in messages:
         cur_content = message["content"]
         if message["role"] == "user":
-            cur_prompt = user_prompt.replace("{user}", cur_content)
+            cur_prompt = user_prompt.format(user=cur_content)
         elif message["role"] == "robot":
-            cur_prompt = robot_prompt.replace("{robot}", cur_content)
+            cur_prompt = robot_prompt.format(robot=cur_content)
         else:
             raise RuntimeError
         total_prompt += cur_prompt
-    total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt)
+    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
     return total_prompt
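
For context, the inlined GenerationConfig, generate_interactive, and combine_history above are consumed by the demo's Streamlit chat loop, which is outside this diff. The following is only a rough sketch of that usage, assuming prepare_generation_config() returns the GenerationConfig dataclass, that messages are stored under the "user"/"robot" roles shown in combine_history(), and that [UNUSED_TOKEN_145] is registered in the tokenizer vocabulary; names such as `placeholder` are illustrative, not part of this commit.

def main():
    # load_model() and prepare_generation_config() are defined elsewhere in web_demo.py
    model, tokenizer = load_model()
    generation_config = prepare_generation_config()

    if "messages" not in st.session_state:
        st.session_state.messages = []

    if prompt := st.chat_input("What is up?"):
        # fold the chat history plus the new query into a single prompt string
        real_prompt = combine_history(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.chat_message("robot"):
            placeholder = st.empty()
            # generate_interactive yields the partial response after every sampled
            # token, so the placeholder can be refreshed incrementally
            for cur_response in generate_interactive(
                model=model,
                tokenizer=tokenizer,
                prompt=real_prompt,
                # assumption: the chat turn terminator doubles as an extra EOS token
                additional_eos_token_id=tokenizer.convert_tokens_to_ids("[UNUSED_TOKEN_145]"),
                **asdict(generation_config),
            ):
                placeholder.markdown(cur_response)
        st.session_state.messages.append({"role": "robot", "content": cur_response})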