mirror of https://github.com/InternLM/InternLM
				
				
				
			[Fix]: Update web demo to be self-contained (#624)
							parent
							
								
									519c7934c4
								
							
						
					
					
						commit
						f08a18b9b7
					
				
							
								
								
									
										149
									
								
								chat/web_demo.py
								
								
								
								
							
							
						
						
									
										149
									
								
								chat/web_demo.py
								
								
								
								
							|  | @ -7,17 +7,148 @@ Please refer to these links below for more information: | |||
|     3. transformers: https://github.com/huggingface/transformers | ||||
| """ | ||||
| 
 | ||||
| from dataclasses import asdict | ||||
| import copy | ||||
| import warnings | ||||
| from dataclasses import asdict, dataclass | ||||
| from typing import Callable, List, Optional | ||||
| 
 | ||||
| import streamlit as st | ||||
| import torch | ||||
| from tools.transformers.interface import GenerationConfig, generate_interactive | ||||
| from torch import nn | ||||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||||
| from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList | ||||
| from transformers.utils import logging | ||||
| 
 | ||||
| logger = logging.get_logger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| class GenerationConfig: | ||||
|     # this config is used for chat to provide more diversity | ||||
|     max_length: int = 32768 | ||||
|     top_p: float = 0.8 | ||||
|     temperature: float = 0.8 | ||||
|     do_sample: bool = True | ||||
|     repetition_penalty: float = 1.005 | ||||
| 
 | ||||
| 
 | ||||
| @torch.inference_mode() | ||||
| def generate_interactive( | ||||
|     model, | ||||
|     tokenizer, | ||||
|     prompt, | ||||
|     generation_config: Optional[GenerationConfig] = None, | ||||
|     logits_processor: Optional[LogitsProcessorList] = None, | ||||
|     stopping_criteria: Optional[StoppingCriteriaList] = None, | ||||
|     prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, | ||||
|     additional_eos_token_id: Optional[int] = None, | ||||
|     **kwargs, | ||||
| ): | ||||
|     inputs = tokenizer([prompt], padding=True, return_tensors="pt") | ||||
|     input_length = len(inputs["input_ids"][0]) | ||||
|     for k, v in inputs.items(): | ||||
|         inputs[k] = v.cuda() | ||||
|     input_ids = inputs["input_ids"] | ||||
|     batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]  # noqa: F841  # pylint: disable=W0612 | ||||
|     if generation_config is None: | ||||
|         generation_config = model.generation_config | ||||
|     generation_config = copy.deepcopy(generation_config) | ||||
|     model_kwargs = generation_config.update(**kwargs) | ||||
|     bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612 | ||||
|         generation_config.bos_token_id, | ||||
|         generation_config.eos_token_id, | ||||
|     ) | ||||
|     if isinstance(eos_token_id, int): | ||||
|         eos_token_id = [eos_token_id] | ||||
|     if additional_eos_token_id is not None: | ||||
|         eos_token_id.append(additional_eos_token_id) | ||||
|     has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None | ||||
|     if has_default_max_length and generation_config.max_new_tokens is None: | ||||
|         warnings.warn( | ||||
|             f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " | ||||
|             "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" | ||||
|             " recommend using `max_new_tokens` to control the maximum length of the generation.", | ||||
|             UserWarning, | ||||
|         ) | ||||
|     elif generation_config.max_new_tokens is not None: | ||||
|         generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length | ||||
|         if not has_default_max_length: | ||||
|             logger.warn(  # pylint: disable=W4902 | ||||
|                 f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" | ||||
|                 f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " | ||||
|                 "Please refer to the documentation for more information. " | ||||
|                 "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", | ||||
|                 UserWarning, | ||||
|             ) | ||||
| 
 | ||||
|     if input_ids_seq_length >= generation_config.max_length: | ||||
|         input_ids_string = "input_ids" | ||||
|         logger.warning( | ||||
|             f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" | ||||
|             f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" | ||||
|             " increasing `max_new_tokens`." | ||||
|         ) | ||||
| 
 | ||||
|     # 2. Set generation parameters if not already defined | ||||
|     logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() | ||||
|     stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() | ||||
| 
 | ||||
|     logits_processor = model._get_logits_processor( | ||||
|         generation_config=generation_config, | ||||
|         input_ids_seq_length=input_ids_seq_length, | ||||
|         encoder_input_ids=input_ids, | ||||
|         prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, | ||||
|         logits_processor=logits_processor, | ||||
|     ) | ||||
| 
 | ||||
|     stopping_criteria = model._get_stopping_criteria( | ||||
|         generation_config=generation_config, stopping_criteria=stopping_criteria | ||||
|     ) | ||||
|     logits_warper = model._get_logits_warper(generation_config) | ||||
| 
 | ||||
|     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) | ||||
|     scores = None | ||||
|     while True: | ||||
|         model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs) | ||||
|         # forward pass to get next token | ||||
|         outputs = model( | ||||
|             **model_inputs, | ||||
|             return_dict=True, | ||||
|             output_attentions=False, | ||||
|             output_hidden_states=False, | ||||
|         ) | ||||
| 
 | ||||
|         next_token_logits = outputs.logits[:, -1, :] | ||||
| 
 | ||||
|         # pre-process distribution | ||||
|         next_token_scores = logits_processor(input_ids, next_token_logits) | ||||
|         next_token_scores = logits_warper(input_ids, next_token_scores) | ||||
| 
 | ||||
|         # sample | ||||
|         probs = nn.functional.softmax(next_token_scores, dim=-1) | ||||
|         if generation_config.do_sample: | ||||
|             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) | ||||
|         else: | ||||
|             next_tokens = torch.argmax(probs, dim=-1) | ||||
| 
 | ||||
|         # update generated ids, model inputs, and length for next step | ||||
|         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) | ||||
|         model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False) | ||||
|         unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long()) | ||||
| 
 | ||||
|         output_token_ids = input_ids[0].cpu().tolist() | ||||
|         output_token_ids = output_token_ids[input_length:] | ||||
|         for each_eos_token_id in eos_token_id: | ||||
|             if output_token_ids[-1] == each_eos_token_id: | ||||
|                 output_token_ids = output_token_ids[:-1] | ||||
|         response = tokenizer.decode(output_token_ids) | ||||
| 
 | ||||
|         yield response | ||||
|         # stop when each sentence is finished, or if we exceed the maximum length | ||||
|         if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): | ||||
|             break | ||||
| 
 | ||||
| 
 | ||||
| def on_btn_click(): | ||||
|     del st.session_state.messages | ||||
| 
 | ||||
|  | @ -35,7 +166,7 @@ def load_model(): | |||
| 
 | ||||
| def prepare_generation_config(): | ||||
|     with st.sidebar: | ||||
|         max_length = st.slider("Max Length", min_value=32, max_value=2048, value=2048) | ||||
|         max_length = st.slider("Max Length", min_value=8, max_value=32768, value=32768) | ||||
|         top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01) | ||||
|         temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01) | ||||
|         st.button("Clear Chat History", on_click=on_btn_click) | ||||
|  | @ -52,17 +183,21 @@ cur_query_prompt = "[UNUSED_TOKEN_146]user\n{user}[UNUSED_TOKEN_145]\n[UNUSED_TO | |||
| 
 | ||||
| def combine_history(prompt): | ||||
|     messages = st.session_state.messages | ||||
|     total_prompt = "" | ||||
|     meta_instruction = ( | ||||
|         "You are InternLM (书生·浦语), a helpful, honest, and harmless AI assistant developed by Shanghai " | ||||
|         "AI Laboratory (上海人工智能实验室)." | ||||
|     ) | ||||
|     total_prompt = f"<s>[UNUSED_TOKEN_146]system\n{meta_instruction}[UNUSED_TOKEN_145]\n" | ||||
|     for message in messages: | ||||
|         cur_content = message["content"] | ||||
|         if message["role"] == "user": | ||||
|             cur_prompt = user_prompt.replace("{user}", cur_content) | ||||
|             cur_prompt = user_prompt.format(user=cur_content) | ||||
|         elif message["role"] == "robot": | ||||
|             cur_prompt = robot_prompt.replace("{robot}", cur_content) | ||||
|             cur_prompt = robot_prompt.format(robot=cur_content) | ||||
|         else: | ||||
|             raise RuntimeError | ||||
|         total_prompt += cur_prompt | ||||
|     total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt) | ||||
|     total_prompt = total_prompt + cur_query_prompt.format(user=prompt) | ||||
|     return total_prompt | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Wenwei Zhang
						Wenwei Zhang