refactor(tools): move interface.py and import it to web_demo (#195)

* move interface.py and import it to web_demo

* typo
x54-729 2023-08-14 22:32:29 +08:00 committed by GitHub
parent ccb06a98e4
commit 0600b42c01
3 changed files with 25 additions and 151 deletions

View File

@@ -27,7 +27,7 @@ import tqdm
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from internlm.utils.interface import GenerationConfig, generation_iterator
+from tools.transformers.interface import GenerationConfig, generate_interactive
 from internlm.utils.timeout import Timeout
@@ -115,7 +115,7 @@ class GenericRuntime:
 class PALInterface:
-    """PAL interface wrap fun:`generation_iterator` to extract and execute
+    """PAL interface wrap fun:`generate_interactive` to extract and execute
     generated code.
     Adapted from https://github.com/reasoning-machines/pal
@@ -150,7 +150,7 @@ class PALInterface:
     def generate(self, prompt):
         # The api will generate response word by word
         # we only need the last generation as the final results
-        for cur_gen in generation_iterator(
+        for cur_gen in generate_interactive(
             model=self.model,
             tokenizer=self.tokenizer,
             prompt=prompt,
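For context on the renamed call: `generate_interactive` is a generator that yields a progressively longer decoded response, so a caller like `PALInterface.generate` keeps only the last yield. A minimal sketch of that pattern follows; the return handling is an assumption, since the rest of `generate` is not shown in this hunk.

    def generate(self, prompt):
        # generate_interactive streams the response word by word; each yield
        # replaces the previous, shorter one, so only the final value is kept.
        final_response = None
        for cur_gen in generate_interactive(
            model=self.model,
            tokenizer=self.tokenizer,
            prompt=prompt,
        ):
            final_response = cur_gen
        return final_response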

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass
 from typing import Callable, List, Optional
 import torch
+from torch import nn
 from transformers import AutoModel, AutoTokenizer
 from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from transformers.utils import logging
@@ -21,10 +22,10 @@ class GenerationConfig:
 @torch.inference_mode()
-def generation_iterator(
-    model: AutoModel,
-    tokenizer: AutoTokenizer,
-    prompt: str,
+def generate_interactive(
+    model,
+    tokenizer,
+    prompt,
     generation_config: Optional[GenerationConfig] = None,
     logits_processor: Optional[LogitsProcessorList] = None,
     stopping_criteria: Optional[StoppingCriteriaList] = None,
@@ -37,12 +38,12 @@ def generation_iterator(
     for k, v in inputs.items():
         inputs[k] = v.cuda()
     input_ids = inputs["input_ids"]
-    input_ids_seq_length = input_ids.shape[-1]
+    batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
     if generation_config is None:
         generation_config = model.generation_config
     generation_config = copy.deepcopy(generation_config)
     model_kwargs = generation_config.update(**kwargs)
-    eos_token_id = generation_config.eos_token_id
+    bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
     if isinstance(eos_token_id, int):
         eos_token_id = [eos_token_id]
     if additional_eos_token_id is not None:
@@ -58,24 +59,20 @@
     elif generation_config.max_new_tokens is not None:
         generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
         if not has_default_max_length:
-            logger.warning(
-                "Both `max_new_tokens` (={%s}) and `max_length`(="
-                "{%s}) seem to have been set. `max_new_tokens` will take precedence. "
+            logger.warn(
+                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                 "Please refer to the documentation for more information. "
                 "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
-                generation_config.max_new_tokens,
-                generation_config.max_length,
+                UserWarning,
             )
     if input_ids_seq_length >= generation_config.max_length:
         input_ids_string = "input_ids"
         logger.warning(
-            "Input length of {%s} is {%s}, but `max_length` is set to"
-            " {%s}. This can lead to unexpected behavior. You should consider"
-            " increasing `max_new_tokens`.",
-            input_ids_string,
-            input_ids_seq_length,
-            generation_config.max_length,
+            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+            " increasing `max_new_tokens`."
         )
     # 2. Set generation parameters if not already defined
@@ -114,7 +111,7 @@ def generation_iterator(
         next_token_scores = logits_warper(input_ids, next_token_scores)
         # sample
-        probs = next_token_scores.softmax(dim=-1)
+        probs = nn.functional.softmax(next_token_scores, dim=-1)
         if generation_config.do_sample:
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
         else:
@@ -122,7 +119,9 @@
         # update generated ids, model inputs, and length for next step
         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-        model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
+        model_kwargs = model._update_model_kwargs_for_generation(
+            outputs, model_kwargs, is_encoder_decoder=False
+        )
         unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())
         output_token_ids = input_ids[0].cpu().tolist()
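Since `generate_interactive` now lives in `tools.transformers.interface`, other tools can import it directly. A usage sketch under stated assumptions: the checkpoint name follows the web demo below, the prompt and sampling values are illustrative, the moved `GenerationConfig` is assumed to keep the same fields as the dataclass removed from web_demo, and its fields are passed through `**kwargs` via `asdict`, mirroring the demo's imports.

    import torch
    from dataclasses import asdict
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from tools.transformers.interface import GenerationConfig, generate_interactive

    tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        "internlm/internlm-chat-7b", trust_remote_code=True
    ).to(torch.bfloat16).cuda()

    # Illustrative sampling settings; the dataclass fields become generation kwargs.
    config = GenerationConfig(max_length=2048, top_p=0.8, temperature=0.8)
    for partial_response in generate_interactive(
        model=model,
        tokenizer=tokenizer,
        prompt="Hello, who are you?",
        **asdict(config),
    ):
        # Each yield is the decoded response generated so far; the last one is complete.
        print(partial_response)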

View File

@@ -8,7 +8,6 @@ Please refer to these links below for more information:
 import streamlit as st
 import torch
-import torch.nn as nn
 from dataclasses import dataclass, asdict
 from typing import List, Optional, Callable, Optional
 import copy
@@ -16,140 +15,16 @@ import warnings
 import logging
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.utils import logging
-from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
+from tools.transformers.interface import generate_interactive, GenerationConfig
 logger = logging.get_logger(__name__)
-@torch.inference_mode()
-def generate_interactive(
-    model,
-    tokenizer,
-    prompt,
-    generation_config: Optional[GenerationConfig] = None,
-    logits_processor: Optional[LogitsProcessorList] = None,
-    stopping_criteria: Optional[StoppingCriteriaList] = None,
-    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
-    additional_eos_token_id: Optional[int] = None,
-    **kwargs,
-):
-    inputs = tokenizer([prompt], padding=True, return_tensors="pt")
-    input_length = len(inputs["input_ids"][0])
-    for k, v in inputs.items():
-        inputs[k] = v.cuda()
-    input_ids = inputs["input_ids"]
-    batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
-    if generation_config is None:
-        generation_config = model.generation_config
-    generation_config = copy.deepcopy(generation_config)
-    model_kwargs = generation_config.update(**kwargs)
-    bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
-    if isinstance(eos_token_id, int):
-        eos_token_id = [eos_token_id]
-    if additional_eos_token_id is not None:
-        eos_token_id.append(additional_eos_token_id)
-    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-    if has_default_max_length and generation_config.max_new_tokens is None:
-        warnings.warn(
-            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
-            "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
-            " recommend using `max_new_tokens` to control the maximum length of the generation.",
-            UserWarning,
-        )
-    elif generation_config.max_new_tokens is not None:
-        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
-        if not has_default_max_length:
-            logger.warn(
-                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
-                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
-                "Please refer to the documentation for more information. "
-                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
-                UserWarning,
-            )
-    if input_ids_seq_length >= generation_config.max_length:
-        input_ids_string = "input_ids"
-        logger.warning(
-            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
-            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
-            " increasing `max_new_tokens`."
-        )
-    # 2. Set generation parameters if not already defined
-    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-    logits_processor = model._get_logits_processor(
-        generation_config=generation_config,
-        input_ids_seq_length=input_ids_seq_length,
-        encoder_input_ids=input_ids,
-        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-        logits_processor=logits_processor,
-    )
-    stopping_criteria = model._get_stopping_criteria(
-        generation_config=generation_config, stopping_criteria=stopping_criteria
-    )
-    logits_warper = model._get_logits_warper(generation_config)
-    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
-    scores = None
-    while True:
-        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
-        # forward pass to get next token
-        outputs = model(
-            **model_inputs,
-            return_dict=True,
-            output_attentions=False,
-            output_hidden_states=False,
-        )
-        next_token_logits = outputs.logits[:, -1, :]
-        # pre-process distribution
-        next_token_scores = logits_processor(input_ids, next_token_logits)
-        next_token_scores = logits_warper(input_ids, next_token_scores)
-        # sample
-        probs = nn.functional.softmax(next_token_scores, dim=-1)
-        if generation_config.do_sample:
-            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-        else:
-            next_tokens = torch.argmax(probs, dim=-1)
-        # update generated ids, model inputs, and length for next step
-        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-        model_kwargs = model._update_model_kwargs_for_generation(
-            outputs, model_kwargs, is_encoder_decoder=False
-        )
-        unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())
-        output_token_ids = input_ids[0].cpu().tolist()
-        output_token_ids = output_token_ids[input_length:]
-        for each_eos_token_id in eos_token_id:
-            if output_token_ids[-1] == each_eos_token_id:
-                output_token_ids = output_token_ids[:-1]
-        response = tokenizer.decode(output_token_ids)
-        yield response
-        # stop when each sentence is finished, or if we exceed the maximum length
-        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
-            break
 def on_btn_click():
     del st.session_state.messages
-@dataclass
-class GenerationConfig:
-    max_length: Optional[int] = None
-    top_p: Optional[float] = None
-    temperature: Optional[float] = None
-    do_sample: Optional[bool] = True
-    repetition_penalty: Optional[float] = 1.0
 @st.cache_resource
 def load_model():
     model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).to(torch.bfloat16).cuda()
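Downstream of this `load_model` helper, the demo streams each partial response from the imported generator into the page. A rough sketch of that wiring, relying on the imports shown in the hunk above; the widget handling and prompt are assumptions for illustration, not code from this commit.

    # Hypothetical Streamlit wiring: stream each partial response into one placeholder.
    def render_reply(model, tokenizer, generation_config, prompt):
        placeholder = st.empty()
        for partial_response in generate_interactive(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            **asdict(generation_config),
        ):
            # Overwrite the placeholder with the longer prefix on every yield.
            placeholder.markdown(partial_response)
        return partial_response  # the final, complete response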