pull/820/head
braisedpork1964 2025-01-15 15:23:58 +00:00
parent bd4d32dba7
commit 644aa59dae
5 changed files with 11 additions and 31 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 248 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 300 KiB

View File

@ -81,4 +81,4 @@ streamlit run ./chat/web_demo.py
It supports switching between different inference modes and comparing their responses.
![demo](../assets/web_demo.png)
![demo](https://github.com/user-attachments/assets/4953befa-343f-499d-b289-048d982439f3)

View File

@ -79,4 +79,4 @@ streamlit run ./web_demo.py
支持切换不同推理模式,并比较它们的回复
![demo](../assets/web_demo_zh_cn.png)
![demo](https://github.com/user-attachments/assets/952e250d-22a6-4544-b8e3-9c21c746d3c7)

View File

@ -26,8 +26,7 @@ import streamlit as st
import torch
from torch import nn
import transformers
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import LogitsProcessorList
from transformers.utils import logging
from transformers import AutoTokenizer, AutoModelForCausalLM # isort: skip
@ -53,7 +52,6 @@ def generate_interactive(
prompt,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
List[int]]] = None,
additional_eos_token_id: Optional[int] = None,
@ -72,10 +70,7 @@ def generate_interactive(
model_kwargs = generation_config.update(**kwargs)
if generation_config.temperature == 0.0:
generation_config.do_sample = False
bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612
generation_config.bos_token_id,
generation_config.eos_token_id,
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if additional_eos_token_id is not None:
@ -94,7 +89,8 @@ def generate_interactive(
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
generation_config.max_length = (generation_config.max_new_tokens +
input_ids_seq_length)
if not has_default_max_length:
logger.warn( # pylint: disable=W4902
f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
@ -115,10 +111,8 @@ def generate_interactive(
" increasing 'max_new_tokens'.")
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList(
)
# stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
logits_processor = (logits_processor if logits_processor is not None else
LogitsProcessorList())
logits_processor = model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
@ -126,18 +120,7 @@ def generate_interactive(
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)
# stopping_criteria = model._get_stopping_criteria(
# generation_config=generation_config, stopping_criteria=stopping_criteria
# )
# if transformers.__version__ >= '4.42.0':
# logits_warper = model._get_logits_warper(generation_config, device='cuda')
# else:
# logits_warper = model._get_logits_warper(generation_config)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
scores = None
while True:
model_inputs = model.prepare_inputs_for_generation(
input_ids, **model_kwargs)
@ -153,7 +136,6 @@ def generate_interactive(
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
# next_token_scores = logits_warper(input_ids, next_token_scores)
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
@ -164,7 +146,6 @@ def generate_interactive(
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
# model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
unfinished_sequences = unfinished_sequences.mul(
(min(next_tokens != i for i in eos_token_id)).long())
@ -178,8 +159,6 @@ def generate_interactive(
yield response
# stop when each sentence is finished
# or if we exceed the maximum length
# if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
# break
if unfinished_sequences.max() == 0:
break
@ -259,7 +238,8 @@ def combine_history(prompt, deepthink=False, start=0, stop=None):
'and harmless AI assistant developed by Shanghai '
'AI Laboratory (上海人工智能实验室).')
if deepthink:
meta_instruction += """\nYou are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:
meta_instruction += (
"""You are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:
## Deep Understanding
Take time to fully comprehend the problem before attempting a solution. Consider:
- What is the real question being asked?
@ -303,7 +283,7 @@ When you're ready, present your complete solution with:
- Key insights
- Thorough verification
Focus on clear, logical progression of ideas and thorough explanation of your mathematical reasoning. Provide answers in the same language as the user asking the question, repeat the final answer using a '\\boxed{}' without any units, you have [[8192]] tokens to complete the answer.
"""
""") # noqa: E501
total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
for message in messages:
cur_content = message['content']