diff --git a/tools/pal_inference.py b/tools/pal_inference.py index 703d169..a3c0cc2 100644 --- a/tools/pal_inference.py +++ b/tools/pal_inference.py @@ -27,7 +27,7 @@ import tqdm from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer -from internlm.utils.interface import GenerationConfig, generation_iterator +from tools.transformers.interface import GenerationConfig, generate_interactive from internlm.utils.timeout import Timeout @@ -115,7 +115,7 @@ class GenericRuntime: class PALInterface: - """PAL interface wrap fun:`generation_iterator` to extract and execute + """PAL interface wrap fun:`generate_interactive` to extract and execute generated code. Adapted from https://github.com/reasoning-machines/pal @@ -150,7 +150,7 @@ class PALInterface: def generate(self, prompt): # The api will generate response word by word # we only need the last generation as the final results - for cur_gen in generation_iterator( + for cur_gen in generate_interactive( model=self.model, tokenizer=self.tokenizer, prompt=prompt, diff --git a/internlm/utils/interface.py b/tools/transformers/interface.py similarity index 84% rename from internlm/utils/interface.py rename to tools/transformers/interface.py index 22b3743..1a8a69f 100644 --- a/internlm/utils/interface.py +++ b/tools/transformers/interface.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Callable, List, Optional import torch +from torch import nn from transformers import AutoModel, AutoTokenizer from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList from transformers.utils import logging @@ -21,10 +22,10 @@ class GenerationConfig: @torch.inference_mode() -def generation_iterator( - model: AutoModel, - tokenizer: AutoTokenizer, - prompt: str, +def generate_interactive( + model, + tokenizer, + prompt, generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, @@ -37,12 +38,12 @@ def generation_iterator( for k, v in inputs.items(): inputs[k] = v.cuda() input_ids = inputs["input_ids"] - input_ids_seq_length = input_ids.shape[-1] + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] if generation_config is None: generation_config = model.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) - eos_token_id = generation_config.eos_token_id + bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] if additional_eos_token_id is not None: @@ -58,24 +59,20 @@ def generation_iterator( elif generation_config.max_new_tokens is not None: generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length if not has_default_max_length: - logger.warning( - "Both `max_new_tokens` (={%s}) and `max_length`(=" - "{%s}) seem to have been set. `max_new_tokens` will take precedence. " + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " "Please refer to the documentation for more information. " "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - generation_config.max_new_tokens, - generation_config.max_length, + UserWarning, ) if input_ids_seq_length >= generation_config.max_length: input_ids_string = "input_ids" logger.warning( - "Input length of {%s} is {%s}, but `max_length` is set to" - " {%s}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`.", - input_ids_string, - input_ids_seq_length, - generation_config.max_length, + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." ) # 2. Set generation parameters if not already defined @@ -114,7 +111,7 @@ def generation_iterator( next_token_scores = logits_warper(input_ids, next_token_scores) # sample - probs = next_token_scores.softmax(dim=-1) + probs = nn.functional.softmax(next_token_scores, dim=-1) if generation_config.do_sample: next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) else: @@ -122,9 +119,11 @@ def generation_iterator( # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False) + model_kwargs = model._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=False + ) unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long()) - + output_token_ids = input_ids[0].cpu().tolist() output_token_ids = output_token_ids[input_length:] for each_eos_token_id in eos_token_id: diff --git a/web_demo.py b/web_demo.py index f07bfdd..e9334cd 100644 --- a/web_demo.py +++ b/web_demo.py @@ -8,7 +8,6 @@ Please refer to these links below for more information: import streamlit as st import torch -import torch.nn as nn from dataclasses import dataclass, asdict from typing import List, Optional, Callable, Optional import copy @@ -16,140 +15,16 @@ import warnings import logging from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.utils import logging -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList + +from tools.transformers.interface import generate_interactive, GenerationConfig logger = logging.get_logger(__name__) -@torch.inference_mode() -def generate_interactive( - model, - tokenizer, - prompt, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - additional_eos_token_id: Optional[int] = None, - **kwargs, -): - inputs = tokenizer([prompt], padding=True, return_tensors="pt") - input_length = len(inputs["input_ids"][0]) - for k, v in inputs.items(): - inputs[k] = v.cuda() - input_ids = inputs["input_ids"] - batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] - if generation_config is None: - generation_config = model.generation_config - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - if additional_eos_token_id is not None: - eos_token_id.append(additional_eos_token_id) - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - warnings.warn( - f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " - "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" - " recommend using `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - if not has_default_max_length: - logger.warn( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - UserWarning, - ) - - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) - - # 2. Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - logits_processor = model._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - ) - - stopping_criteria = model._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - logits_warper = model._get_logits_warper(generation_config) - - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - scores = None - while True: - model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs) - # forward pass to get next token - outputs = model( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - if generation_config.do_sample: - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(probs, dim=-1) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = model._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=False - ) - unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long()) - - output_token_ids = input_ids[0].cpu().tolist() - output_token_ids = output_token_ids[input_length:] - for each_eos_token_id in eos_token_id: - if output_token_ids[-1] == each_eos_token_id: - output_token_ids = output_token_ids[:-1] - response = tokenizer.decode(output_token_ids) - - yield response - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - break - - def on_btn_click(): del st.session_state.messages - -@dataclass -class GenerationConfig: - max_length: Optional[int] = None - top_p: Optional[float] = None - temperature: Optional[float] = None - do_sample: Optional[bool] = True - repetition_penalty: Optional[float] = 1.0 - - @st.cache_resource def load_model(): model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).to(torch.bfloat16).cuda()