pull/820/head
braisedpork1964 2025-01-15 15:23:58 +00:00
parent bd4d32dba7
commit 644aa59dae
5 changed files with 11 additions and 31 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 248 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 300 KiB

View File

@ -81,4 +81,4 @@ streamlit run ./chat/web_demo.py
It supports switching between different inference modes and comparing their responses. It supports switching between different inference modes and comparing their responses.
![demo](../assets/web_demo.png) ![demo](https://github.com/user-attachments/assets/4953befa-343f-499d-b289-048d982439f3)

View File

@ -79,4 +79,4 @@ streamlit run ./web_demo.py
支持切换不同推理模式,并比较它们的回复 支持切换不同推理模式,并比较它们的回复
![demo](../assets/web_demo_zh_cn.png) ![demo](https://github.com/user-attachments/assets/952e250d-22a6-4544-b8e3-9c21c746d3c7)

View File

@ -26,8 +26,7 @@ import streamlit as st
import torch import torch
from torch import nn from torch import nn
import transformers from transformers.generation.utils import LogitsProcessorList
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from transformers.utils import logging from transformers.utils import logging
from transformers import AutoTokenizer, AutoModelForCausalLM # isort: skip from transformers import AutoTokenizer, AutoModelForCausalLM # isort: skip
@ -53,7 +52,6 @@ def generate_interactive(
prompt, prompt,
generation_config: Optional[GenerationConfig] = None, generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None, logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
List[int]]] = None, List[int]]] = None,
additional_eos_token_id: Optional[int] = None, additional_eos_token_id: Optional[int] = None,
@ -72,10 +70,7 @@ def generate_interactive(
model_kwargs = generation_config.update(**kwargs) model_kwargs = generation_config.update(**kwargs)
if generation_config.temperature == 0.0: if generation_config.temperature == 0.0:
generation_config.do_sample = False generation_config.do_sample = False
bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612 eos_token_id = generation_config.eos_token_id
generation_config.bos_token_id,
generation_config.eos_token_id,
)
if isinstance(eos_token_id, int): if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id] eos_token_id = [eos_token_id]
if additional_eos_token_id is not None: if additional_eos_token_id is not None:
@ -94,7 +89,8 @@ def generate_interactive(
UserWarning, UserWarning,
) )
elif generation_config.max_new_tokens is not None: elif generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length generation_config.max_length = (generation_config.max_new_tokens +
input_ids_seq_length)
if not has_default_max_length: if not has_default_max_length:
logger.warn( # pylint: disable=W4902 logger.warn( # pylint: disable=W4902
f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) " f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
@ -115,10 +111,8 @@ def generate_interactive(
" increasing 'max_new_tokens'.") " increasing 'max_new_tokens'.")
# 2. Set generation parameters if not already defined # 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList( logits_processor = (logits_processor if logits_processor is not None else
) LogitsProcessorList())
# stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
logits_processor = model._get_logits_processor( logits_processor = model._get_logits_processor(
generation_config=generation_config, generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length, input_ids_seq_length=input_ids_seq_length,
@ -126,18 +120,7 @@ def generate_interactive(
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor, logits_processor=logits_processor,
) )
# stopping_criteria = model._get_stopping_criteria(
# generation_config=generation_config, stopping_criteria=stopping_criteria
# )
# if transformers.__version__ >= '4.42.0':
# logits_warper = model._get_logits_warper(generation_config, device='cuda')
# else:
# logits_warper = model._get_logits_warper(generation_config)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
scores = None
while True: while True:
model_inputs = model.prepare_inputs_for_generation( model_inputs = model.prepare_inputs_for_generation(
input_ids, **model_kwargs) input_ids, **model_kwargs)
@ -153,7 +136,6 @@ def generate_interactive(
# pre-process distribution # pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits) next_token_scores = logits_processor(input_ids, next_token_logits)
# next_token_scores = logits_warper(input_ids, next_token_scores)
# sample # sample
probs = nn.functional.softmax(next_token_scores, dim=-1) probs = nn.functional.softmax(next_token_scores, dim=-1)
@ -164,7 +146,6 @@ def generate_interactive(
# update generated ids, model inputs, and length for next step # update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
# model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
unfinished_sequences = unfinished_sequences.mul( unfinished_sequences = unfinished_sequences.mul(
(min(next_tokens != i for i in eos_token_id)).long()) (min(next_tokens != i for i in eos_token_id)).long())
@ -178,8 +159,6 @@ def generate_interactive(
yield response yield response
# stop when each sentence is finished # stop when each sentence is finished
# or if we exceed the maximum length # or if we exceed the maximum length
# if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
# break
if unfinished_sequences.max() == 0: if unfinished_sequences.max() == 0:
break break
@ -259,7 +238,8 @@ def combine_history(prompt, deepthink=False, start=0, stop=None):
'and harmless AI assistant developed by Shanghai ' 'and harmless AI assistant developed by Shanghai '
'AI Laboratory (上海人工智能实验室).') 'AI Laboratory (上海人工智能实验室).')
if deepthink: if deepthink:
meta_instruction += """\nYou are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes: meta_instruction += (
"""You are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:
## Deep Understanding ## Deep Understanding
Take time to fully comprehend the problem before attempting a solution. Consider: Take time to fully comprehend the problem before attempting a solution. Consider:
- What is the real question being asked? - What is the real question being asked?
@ -303,7 +283,7 @@ When you're ready, present your complete solution with:
- Key insights - Key insights
- Thorough verification - Thorough verification
Focus on clear, logical progression of ideas and thorough explanation of your mathematical reasoning. Provide answers in the same language as the user asking the question, repeat the final answer using a '\\boxed{}' without any units, you have [[8192]] tokens to complete the answer. Focus on clear, logical progression of ideas and thorough explanation of your mathematical reasoning. Provide answers in the same language as the user asking the question, repeat the final answer using a '\\boxed{}' without any units, you have [[8192]] tokens to complete the answer.
""" """) # noqa: E501
total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n' total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
for message in messages: for message in messages:
cur_content = message['content'] cur_content = message['content']