diff --git a/assets/web_demo.png b/assets/web_demo.png deleted file mode 100644 index f8c0e6f..0000000 Binary files a/assets/web_demo.png and /dev/null differ diff --git a/assets/web_demo_zh_cn.png b/assets/web_demo_zh_cn.png deleted file mode 100644 index f35cdcd..0000000 Binary files a/assets/web_demo_zh_cn.png and /dev/null differ diff --git a/chat/README.md b/chat/README.md index b21b22e..55d69ee 100644 --- a/chat/README.md +++ b/chat/README.md @@ -81,4 +81,4 @@ streamlit run ./chat/web_demo.py It supports switching between different inference modes and comparing their responses. -![demo](../assets/web_demo.png) +![demo](https://github.com/user-attachments/assets/4953befa-343f-499d-b289-048d982439f3) diff --git a/chat/README_zh-CN.md b/chat/README_zh-CN.md index 8fd4bb2..0408399 100644 --- a/chat/README_zh-CN.md +++ b/chat/README_zh-CN.md @@ -79,4 +79,4 @@ streamlit run ./web_demo.py 支持切换不同推理模式,并比较它们的回复 -![demo](../assets/web_demo_zh_cn.png) +![demo](https://github.com/user-attachments/assets/952e250d-22a6-4544-b8e3-9c21c746d3c7) diff --git a/chat/web_demo.py b/chat/web_demo.py index 0470d40..1fa66b3 100644 --- a/chat/web_demo.py +++ b/chat/web_demo.py @@ -26,8 +26,7 @@ import streamlit as st import torch from torch import nn -import transformers -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList +from transformers.generation.utils import LogitsProcessorList from transformers.utils import logging from transformers import AutoTokenizer, AutoModelForCausalLM # isort: skip @@ -53,7 +52,6 @@ def generate_interactive( prompt, generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, additional_eos_token_id: Optional[int] = None, @@ -72,10 +70,7 @@ def generate_interactive( model_kwargs = generation_config.update(**kwargs) if generation_config.temperature == 0.0: generation_config.do_sample = False - bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612 - generation_config.bos_token_id, - generation_config.eos_token_id, - ) + eos_token_id = generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] if additional_eos_token_id is not None: @@ -94,7 +89,8 @@ def generate_interactive( UserWarning, ) elif generation_config.max_new_tokens is not None: - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + generation_config.max_length = (generation_config.max_new_tokens + + input_ids_seq_length) if not has_default_max_length: logger.warn( # pylint: disable=W4902 f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) " @@ -115,10 +111,8 @@ def generate_interactive( " increasing 'max_new_tokens'.") # 2. Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList( - ) - # stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - + logits_processor = (logits_processor if logits_processor is not None else + LogitsProcessorList()) logits_processor = model._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_seq_length, @@ -126,18 +120,7 @@ def generate_interactive( prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, logits_processor=logits_processor, ) - - # stopping_criteria = model._get_stopping_criteria( - # generation_config=generation_config, stopping_criteria=stopping_criteria - # ) - - # if transformers.__version__ >= '4.42.0': - # logits_warper = model._get_logits_warper(generation_config, device='cuda') - # else: - # logits_warper = model._get_logits_warper(generation_config) - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - scores = None while True: model_inputs = model.prepare_inputs_for_generation( input_ids, **model_kwargs) @@ -153,7 +136,6 @@ def generate_interactive( # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) - # next_token_scores = logits_warper(input_ids, next_token_scores) # sample probs = nn.functional.softmax(next_token_scores, dim=-1) @@ -164,7 +146,6 @@ def generate_interactive( # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - # model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False) unfinished_sequences = unfinished_sequences.mul( (min(next_tokens != i for i in eos_token_id)).long()) @@ -178,8 +159,6 @@ def generate_interactive( yield response # stop when each sentence is finished # or if we exceed the maximum length - # if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - # break if unfinished_sequences.max() == 0: break @@ -259,7 +238,8 @@ def combine_history(prompt, deepthink=False, start=0, stop=None): 'and harmless AI assistant developed by Shanghai ' 'AI Laboratory (上海人工智能实验室).') if deepthink: - meta_instruction += """\nYou are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes: + meta_instruction += ( + """You are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes: ## Deep Understanding Take time to fully comprehend the problem before attempting a solution. Consider: - What is the real question being asked? @@ -303,7 +283,7 @@ When you're ready, present your complete solution with: - Key insights - Thorough verification Focus on clear, logical progression of ideas and thorough explanation of your mathematical reasoning. Provide answers in the same language as the user asking the question, repeat the final answer using a '\\boxed{}' without any units, you have [[8192]] tokens to complete the answer. -""" +""") # noqa: E501 total_prompt = f'<|im_start|>system\n{meta_instruction}<|im_end|>\n' for message in messages: cur_content = message['content']