mirror of https://github.com/InternLM/InternLM
update
parent bd4d32dba7
commit 644aa59dae
Binary image file changed; content not shown (before: 248 KiB).
Binary image file changed; content not shown (before: 300 KiB).
@@ -81,4 +81,4 @@ streamlit run ./chat/web_demo.py
It supports switching between different inference modes and comparing their responses.
@@ -79,4 +79,4 @@ streamlit run ./web_demo.py
支持切换不同推理模式,并比较它们的回复 (It supports switching between different inference modes and comparing their responses.)
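The mode-switching UI itself is not shown in this diff. Purely as a hypothetical sketch of what "switching between inference modes and comparing their responses" can look like in Streamlit (the mode names and the `chat` stub below are invented for illustration, not the demo's actual code):

```python
# Hypothetical sketch only; not the repository's chat/web_demo.py implementation.
import streamlit as st

def chat(prompt: str, mode: str) -> str:
    # Stand-in for the real model call.
    return f'[{mode}] response to: {prompt}'

modes = st.multiselect('Inference modes to compare', ['normal', 'deepthink'],
                       default=['normal'])
prompt = st.text_input('Prompt')

if prompt and modes:
    # One column per selected mode, so the responses can be read side by side.
    for col, mode in zip(st.columns(len(modes)), modes):
        with col:
            st.caption(mode)
            st.write(chat(prompt, mode))
```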
@@ -26,8 +26,7 @@ import streamlit as st
import torch
from torch import nn
import transformers
-from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
+from transformers.generation.utils import LogitsProcessorList
from transformers.utils import logging

from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip
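These imports are typically paired with loading the tokenizer and model once for the demo; a minimal sketch, assuming an InternLM chat checkpoint (the model name below is an assumption, and `trust_remote_code=True` is needed because InternLM ships custom modeling code):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = 'internlm/internlm2_5-7b-chat'  # assumed checkpoint name, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16,
    trust_remote_code=True).cuda().eval()
```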
@@ -53,7 +52,6 @@ def generate_interactive(
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
-    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
                                                List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
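`generate_interactive` is a generator: it decodes one token per iteration and yields the partial response, which is how the Streamlit UI streams text. A usage sketch (argument values are illustrative; the extra EOS id is resolved from the tokenizer rather than hard-coded):

```python
# Illustrative usage; assumes `model` and `tokenizer` are already loaded as above.
im_end_id = tokenizer.convert_tokens_to_ids('<|im_end|>')  # chat-format end-of-turn token

for response in generate_interactive(model,
                                     tokenizer,
                                     prompt='Hi, who are you?',
                                     max_new_tokens=256,
                                     top_p=0.9,
                                     temperature=0.8,
                                     additional_eos_token_id=im_end_id):
    print(response)  # each yield is the response decoded so far
```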
@@ -72,10 +70,7 @@ def generate_interactive(
    model_kwargs = generation_config.update(**kwargs)
    if generation_config.temperature == 0.0:
        generation_config.do_sample = False
-    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
-        generation_config.bos_token_id,
-        generation_config.eos_token_id,
-    )
+    eos_token_id = generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if additional_eos_token_id is not None:
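The lines above normalize `eos_token_id` (an int or a list in GenerationConfig) into a list and then append `additional_eos_token_id`, so the decoding loop can stop on either the model's own EOS or the chat format's end-of-turn token. A standalone sketch of that normalization (token ids are illustrative):

```python
from typing import List, Optional, Union

def collect_eos_ids(eos_token_id: Union[int, List[int]],
                    additional_eos_token_id: Optional[int] = None) -> List[int]:
    # GenerationConfig stores eos_token_id as either an int or a list of ints.
    eos_ids = [eos_token_id] if isinstance(eos_token_id, int) else list(eos_token_id)
    if additional_eos_token_id is not None:
        eos_ids.append(additional_eos_token_id)
    return eos_ids

print(collect_eos_ids(2))          # [2]
print(collect_eos_ids(2, 92542))   # [2, 92542] - second id is an illustrative <|im_end|>
```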
@@ -94,7 +89,8 @@ def generate_interactive(
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
-        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+        generation_config.max_length = (generation_config.max_new_tokens +
+                                        input_ids_seq_length)
        if not has_default_max_length:
            logger.warn(  # pylint: disable=W4902
                f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
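This follows the usual Hugging Face convention that `max_length` bounds the whole sequence (prompt plus generation), so it is derived from `max_new_tokens` and the prompt length. A tiny worked example with illustrative numbers:

```python
input_ids_seq_length = 32   # tokens in the prompt
max_new_tokens = 1024       # tokens the user allows the model to generate

# max_length counts prompt + generated tokens.
max_length = max_new_tokens + input_ids_seq_length
print(max_length)  # 1056
```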
@@ -115,10 +111,8 @@ def generate_interactive(
            " increasing 'max_new_tokens'.")

    # 2. Set generation parameters if not already defined
-    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList(
-    )
+    logits_processor = (logits_processor if logits_processor is not None else
+                        LogitsProcessorList())
    # stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
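`model._get_logits_processor` is a private transformers helper that assembles the processor list from the GenerationConfig (depending on the transformers version, temperature/top-p warping may be folded into this list or kept in a separate warper list). A sketch of the equivalent public-API construction, with illustrative values rather than the demo's defaults:

```python
import torch
from transformers import (LogitsProcessorList, RepetitionPenaltyLogitsProcessor,
                          TemperatureLogitsWarper, TopPLogitsWarper)

# Illustrative stack: repetition penalty, then temperature, then nucleus (top-p) filtering.
logits_processor = LogitsProcessorList([
    RepetitionPenaltyLogitsProcessor(penalty=1.005),
    TemperatureLogitsWarper(temperature=0.8),
    TopPLogitsWarper(top_p=0.9),
])

input_ids = torch.tensor([[1, 5, 7]])            # fake prefix (batch=1)
scores = torch.randn(1, 100)                     # fake next-token logits over a 100-token vocab
processed = logits_processor(input_ids, scores)  # processors are applied in order
print(processed.shape)                           # torch.Size([1, 100])
```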
@@ -126,18 +120,7 @@ def generate_interactive(
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )

    # stopping_criteria = model._get_stopping_criteria(
    #     generation_config=generation_config, stopping_criteria=stopping_criteria
    # )
    # if transformers.__version__ >= '4.42.0':
    #     logits_warper = model._get_logits_warper(generation_config, device='cuda')
    # else:
    #     logits_warper = model._get_logits_warper(generation_config)

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
    while True:
        model_inputs = model.prepare_inputs_for_generation(
            input_ids, **model_kwargs)
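The surrounding code is a hand-rolled decoding loop rather than a call to `model.generate`. Stripped to its skeleton (greedy instead of sampled, no KV-cache handling, no streaming), the shape of such a loop is roughly:

```python
import torch

@torch.no_grad()
def decode_loop_sketch(model, input_ids: torch.LongTensor, eos_ids, max_new_tokens: int = 64):
    """Minimal sketch: forward pass, pick the next token, append it, stop on EOS."""
    for _ in range(max_new_tokens):
        logits = model(input_ids=input_ids).logits      # (batch, seq_len, vocab)
        next_tokens = logits[:, -1, :].argmax(dim=-1)    # greedy here; the demo samples
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if int(next_tokens[0]) in eos_ids:               # batch size 1 assumed
            break
    return input_ids
```

The real loop additionally routes inputs through `prepare_inputs_for_generation`, applies the logits processors, samples, and yields the decoded text after every step.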
@@ -153,7 +136,6 @@ def generate_interactive(

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        # next_token_scores = logits_warper(input_ids, next_token_scores)

        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
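After the processors run, the scores become a probability distribution and one token per sequence is drawn from it. A self-contained sketch of that step (the demo's actual sampling line falls outside this hunk; multinomial sampling is the standard continuation):

```python
import torch
from torch import nn

torch.manual_seed(0)
next_token_scores = torch.randn(1, 100)      # fake processed logits: batch=1, vocab=100

probs = nn.functional.softmax(next_token_scores, dim=-1)          # logits -> probabilities
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # one sampled id per sequence
print(next_tokens.shape)  # torch.Size([1])
```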
@@ -164,7 +146,6 @@ def generate_interactive(

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        # model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
        unfinished_sequences = unfinished_sequences.mul(
            (min(next_tokens != i for i in eos_token_id)).long())
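`unfinished_sequences` drops to 0 for a sequence as soon as it emits any of the EOS ids, and that is what ends the loop below. An equivalent vectorized sketch using `torch.isin`, which also reads cleanly for batches larger than one (values are illustrative):

```python
import torch

eos_token_id = [2, 92542]                        # illustrative stop-token ids
unfinished_sequences = torch.ones(2, dtype=torch.long)
next_tokens = torch.tensor([5, 92542])           # sequence 1 just produced a stop token

# A sequence stays unfinished only while its newest token is not an EOS id.
still_going = ~torch.isin(next_tokens, torch.tensor(eos_token_id))
unfinished_sequences = unfinished_sequences.mul(still_going.long())
print(unfinished_sequences)                      # tensor([1, 0])
print(bool(unfinished_sequences.max() == 0))     # False: sequence 0 keeps generating
```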
@@ -178,8 +159,6 @@ def generate_interactive(
        yield response
        # stop when each sentence is finished
        # or if we exceed the maximum length
        # if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
        #     break
        if unfinished_sequences.max() == 0:
            break
@@ -259,7 +238,8 @@ def combine_history(prompt, deepthink=False, start=0, stop=None):
                        'and harmless AI assistant developed by Shanghai '
                        'AI Laboratory (上海人工智能实验室).')
    if deepthink:
-        meta_instruction += """\nYou are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:
+        meta_instruction += (
+            """You are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:
## Deep Understanding
Take time to fully comprehend the problem before attempting a solution. Consider:
- What is the real question being asked?
@@ -303,7 +283,7 @@ When you're ready, present your complete solution with:
- Key insights
- Thorough verification
Focus on clear, logical progression of ideas and thorough explanation of your mathematical reasoning. Provide answers in the same language as the user asking the question, repeat the final answer using a '\\boxed{}' without any units, you have [[8192]] tokens to complete the answer.
-"""
+""")  # noqa: E501
    total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
    for message in messages:
        cur_content = message['content']
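`combine_history` assembles a ChatML-style prompt: a system block built from `meta_instruction`, then one `<|im_start|>role ... <|im_end|>` block per turn, ending with an open assistant block. A simplified stand-in that mirrors the f-string shown above (the helper name and message list are illustrative, not the repository's exact code):

```python
def build_prompt(messages, meta_instruction):
    # Same <|im_start|>role\n...<|im_end|>\n layout as the demo's format strings.
    total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
    for message in messages:
        total_prompt += (f"<|im_start|>{message['role']}\n"
                         f"{message['content']}<|im_end|>\n")
    total_prompt += '<|im_start|>assistant\n'
    return total_prompt

print(build_prompt([{'role': 'user', 'content': 'Hi!'}], 'You are InternLM.'))
```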