refactor(tools): move interface.py and import it to web_demo (#195)

* move interface.py and import it to web_demo

* typo
pull/203/head
x54-729 2023-08-14 22:32:29 +08:00 committed by GitHub
parent ccb06a98e4
commit 0600b42c01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 151 deletions

View File

@ -27,7 +27,7 @@ import tqdm
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from internlm.utils.interface import GenerationConfig, generation_iterator
from tools.transformers.interface import GenerationConfig, generate_interactive
from internlm.utils.timeout import Timeout
@ -115,7 +115,7 @@ class GenericRuntime:
class PALInterface:
"""PAL interface wrap fun:`generation_iterator` to extract and execute
"""PAL interface wrap fun:`generate_interactive` to extract and execute
generated code.
Adapted from https://github.com/reasoning-machines/pal
@ -150,7 +150,7 @@ class PALInterface:
def generate(self, prompt):
# The api will generate response word by word
# we only need the last generation as the final results
for cur_gen in generation_iterator(
for cur_gen in generate_interactive(
model=self.model,
tokenizer=self.tokenizer,
prompt=prompt,

View File

@ -4,6 +4,7 @@ from dataclasses import dataclass
from typing import Callable, List, Optional
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from transformers.utils import logging
@ -21,10 +22,10 @@ class GenerationConfig:
@torch.inference_mode()
def generation_iterator(
model: AutoModel,
tokenizer: AutoTokenizer,
prompt: str,
def generate_interactive(
model,
tokenizer,
prompt,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
@ -37,12 +38,12 @@ def generation_iterator(
for k, v in inputs.items():
inputs[k] = v.cuda()
input_ids = inputs["input_ids"]
input_ids_seq_length = input_ids.shape[-1]
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
if generation_config is None:
generation_config = model.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
eos_token_id = generation_config.eos_token_id
bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if additional_eos_token_id is not None:
@ -58,24 +59,20 @@ def generation_iterator(
elif generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
if not has_default_max_length:
logger.warning(
"Both `max_new_tokens` (={%s}) and `max_length`(="
"{%s}) seem to have been set. `max_new_tokens` will take precedence. "
logger.warn(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
generation_config.max_new_tokens,
generation_config.max_length,
UserWarning,
)
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = "input_ids"
logger.warning(
"Input length of {%s} is {%s}, but `max_length` is set to"
" {%s}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`.",
input_ids_string,
input_ids_seq_length,
generation_config.max_length,
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
# 2. Set generation parameters if not already defined
@ -114,7 +111,7 @@ def generation_iterator(
next_token_scores = logits_warper(input_ids, next_token_scores)
# sample
probs = next_token_scores.softmax(dim=-1)
probs = nn.functional.softmax(next_token_scores, dim=-1)
if generation_config.do_sample:
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
@ -122,9 +119,11 @@ def generation_iterator(
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
model_kwargs = model._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=False
)
unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())
output_token_ids = input_ids[0].cpu().tolist()
output_token_ids = output_token_ids[input_length:]
for each_eos_token_id in eos_token_id:

View File

@ -8,7 +8,6 @@ Please refer to these links below for more information:
import streamlit as st
import torch
import torch.nn as nn
from dataclasses import dataclass, asdict
from typing import List, Optional, Callable, Optional
import copy
@ -16,140 +15,16 @@ import warnings
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import logging
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from tools.transformers.interface import generate_interactive, GenerationConfig
logger = logging.get_logger(__name__)
@torch.inference_mode()
def generate_interactive(
model,
tokenizer,
prompt,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
additional_eos_token_id: Optional[int] = None,
**kwargs,
):
inputs = tokenizer([prompt], padding=True, return_tensors="pt")
input_length = len(inputs["input_ids"][0])
for k, v in inputs.items():
inputs[k] = v.cuda()
input_ids = inputs["input_ids"]
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
if generation_config is None:
generation_config = model.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if additional_eos_token_id is not None:
eos_token_id.append(additional_eos_token_id)
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
" recommend using `max_new_tokens` to control the maximum length of the generation.",
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
if not has_default_max_length:
logger.warn(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
UserWarning,
)
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
logits_processor = model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)
stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)
logits_warper = model._get_logits_warper(generation_config)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
scores = None
while True:
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = model(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
if generation_config.do_sample:
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(probs, dim=-1)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = model._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=False
)
unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())
output_token_ids = input_ids[0].cpu().tolist()
output_token_ids = output_token_ids[input_length:]
for each_eos_token_id in eos_token_id:
if output_token_ids[-1] == each_eos_token_id:
output_token_ids = output_token_ids[:-1]
response = tokenizer.decode(output_token_ids)
yield response
# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
break
def on_btn_click():
del st.session_state.messages
@dataclass
class GenerationConfig:
max_length: Optional[int] = None
top_p: Optional[float] = None
temperature: Optional[float] = None
do_sample: Optional[bool] = True
repetition_penalty: Optional[float] = 1.0
@st.cache_resource
def load_model():
model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).to(torch.bfloat16).cuda()