mirror of https://github.com/InternLM/InternLM
update
parent bd4d32dba7, commit 644aa59dae
2 binary image files changed but are not shown (previously 248 KiB and 300 KiB).
@@ -81,4 +81,4 @@ streamlit run ./chat/web_demo.py
 
 It supports switching between different inference modes and comparing their responses.
@@ -79,4 +79,4 @@ streamlit run ./web_demo.py
 
 Supports switching between different inference modes and comparing their responses.
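The mode switching described in both README hunks is usually exposed as a small widget in the Streamlit UI. A minimal sketch, assuming the two modes are plain chat and the 'deepthink' mode that combine_history() takes as a flag further down in web_demo.py; the widget label and mode names are illustrative, not taken from this commit.

import streamlit as st

# Hypothetical mode selector; 'deepthink' mirrors the flag consumed by combine_history().
mode = st.sidebar.radio('Inference mode', ['chat', 'deepthink'])
deepthink = mode == 'deepthink'
# combine_history(prompt, deepthink=deepthink) then builds the matching prompt, so the
# responses produced under each mode can be generated and compared side by side.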
@@ -26,8 +26,7 @@ import streamlit as st
 import torch
 from torch import nn
 
-import transformers
-from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
+from transformers.generation.utils import LogitsProcessorList
 from transformers.utils import logging
 
 from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip
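For context, a minimal sketch of how the imports kept above are typically used to load the tokenizer and model that back the demo; the checkpoint name and dtype are assumptions, not taken from this diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'internlm/internlm2_5-7b-chat'  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             torch_dtype=torch.bfloat16,
                                             trust_remote_code=True).eval()
if torch.cuda.is_available():
    model = model.cuda()  # move to GPU when one is available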
@@ -53,7 +52,6 @@ def generate_interactive(
     prompt,
     generation_config: Optional[GenerationConfig] = None,
     logits_processor: Optional[LogitsProcessorList] = None,
-    stopping_criteria: Optional[StoppingCriteriaList] = None,
     prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
                                                 List[int]]] = None,
     additional_eos_token_id: Optional[int] = None,
@@ -72,10 +70,7 @@ def generate_interactive(
     model_kwargs = generation_config.update(**kwargs)
     if generation_config.temperature == 0.0:
         generation_config.do_sample = False
-    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
-        generation_config.bos_token_id,
-        generation_config.eos_token_id,
-    )
+    eos_token_id = generation_config.eos_token_id
     if isinstance(eos_token_id, int):
         eos_token_id = [eos_token_id]
     if additional_eos_token_id is not None:
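The replacement above drops the unused bos_token_id and keeps only the end-of-sequence handling: eos_token_id may arrive as an int or a list, and an optional extra id is appended so the loop also stops on the chat-template end marker. A short self-contained sketch with assumed token ids:

eos_token_id = 2                 # e.g. generation_config.eos_token_id
additional_eos_token_id = 92542  # assumed id of '<|im_end|>'

if isinstance(eos_token_id, int):
    eos_token_id = [eos_token_id]
if additional_eos_token_id is not None:
    eos_token_id.append(additional_eos_token_id)

print(eos_token_id)  # [2, 92542]; generation stops on either id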
@@ -94,7 +89,8 @@ def generate_interactive(
             UserWarning,
         )
     elif generation_config.max_new_tokens is not None:
-        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+        generation_config.max_length = (generation_config.max_new_tokens +
+                                        input_ids_seq_length)
         if not has_default_max_length:
             logger.warn(  # pylint: disable=W4902
                 f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
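The rewrapped line keeps the same arithmetic: max_length is the prompt length plus the requested number of new tokens. A worked example with assumed numbers:

input_ids_seq_length = 32  # assumed prompt length in tokens
max_new_tokens = 1024      # assumed generation budget
max_length = max_new_tokens + input_ids_seq_length
assert max_length == 1056  # hard stop for the decoding loop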
@@ -115,10 +111,8 @@ def generate_interactive(
                 " increasing 'max_new_tokens'.")
 
     # 2. Set generation parameters if not already defined
-    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList(
-    )
-    # stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
+    logits_processor = (logits_processor if logits_processor is not None else
+                        LogitsProcessorList())
     logits_processor = model._get_logits_processor(
         generation_config=generation_config,
         input_ids_seq_length=input_ids_seq_length,
@@ -126,18 +120,7 @@ def generate_interactive(
         prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
         logits_processor=logits_processor,
     )
-
-    # stopping_criteria = model._get_stopping_criteria(
-    #     generation_config=generation_config, stopping_criteria=stopping_criteria
-    # )
-
-    # if transformers.__version__ >= '4.42.0':
-    #     logits_warper = model._get_logits_warper(generation_config, device='cuda')
-    # else:
-    #     logits_warper = model._get_logits_warper(generation_config)
-
     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
-    scores = None
     while True:
         model_inputs = model.prepare_inputs_for_generation(
             input_ids, **model_kwargs)
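The comments deleted above referenced the private _get_stopping_criteria and _get_logits_warper helpers, whose signatures differ across transformers releases (hence the 4.42.0 check in the removed block). A hedged sketch of a version-independent alternative built from public classes; the sampling values are assumptions, not the demo's defaults:

from transformers import (LogitsProcessorList,
                          RepetitionPenaltyLogitsProcessor,
                          TemperatureLogitsWarper, TopPLogitsWarper)

# Public equivalents of the warping/penalty steps the private helpers would add.
logits_processor = LogitsProcessorList([
    RepetitionPenaltyLogitsProcessor(penalty=1.005),  # assumed penalty
    TemperatureLogitsWarper(temperature=0.8),         # assumed temperature
    TopPLogitsWarper(top_p=0.8),                      # assumed top_p
])
# Inside the loop: next_token_scores = logits_processor(input_ids, next_token_logits)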
@@ -153,7 +136,6 @@ def generate_interactive(
 
         # pre-process distribution
         next_token_scores = logits_processor(input_ids, next_token_logits)
-        # next_token_scores = logits_warper(input_ids, next_token_scores)
 
         # sample
         probs = nn.functional.softmax(next_token_scores, dim=-1)
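After the softmax shown above, the next token is drawn from the resulting distribution. A self-contained sketch of the usual branch on do_sample (set to False earlier when temperature == 0.0); the vocabulary size is a stand-in, and the exact branch in this file is not visible in the hunk:

import torch
from torch import nn

next_token_scores = torch.randn(1, 32000)  # stand-in for the processed logits
do_sample = True                           # i.e. generation_config.do_sample

probs = nn.functional.softmax(next_token_scores, dim=-1)
if do_sample:
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # stochastic sampling
else:
    next_tokens = torch.argmax(probs, dim=-1)                         # greedy decoding
print(next_tokens.shape)  # torch.Size([1]); one token per sequence in the batch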
@@ -164,7 +146,6 @@ def generate_interactive(
 
         # update generated ids, model inputs, and length for next step
         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-        # model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
         unfinished_sequences = unfinished_sequences.mul(
             (min(next_tokens != i for i in eos_token_id)).long())
 
@@ -178,8 +159,6 @@ def generate_interactive(
             yield response
         # stop when each sentence is finished
         # or if we exceed the maximum length
-        # if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
-        #     break
         if unfinished_sequences.max() == 0:
             break
 
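A hedged usage sketch of how the Streamlit side consumes this generator so partial responses stream into a single placeholder. It assumes the model, tokenizer, prompt, and generation_config objects defined elsewhere in web_demo.py; the model/tokenizer keyword names and the extra eos id are assumptions, since they do not appear in these hunks:

placeholder = st.empty()
response = ''
for response in generate_interactive(model=model,
                                      tokenizer=tokenizer,
                                      prompt=real_prompt,
                                      generation_config=generation_config,
                                      additional_eos_token_id=92542):  # assumed '<|im_end|>' id
    placeholder.markdown(response + '▌')  # cursor while streaming
placeholder.markdown(response)            # final text without the cursor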
@@ -259,7 +238,8 @@ def combine_history(prompt, deepthink=False, start=0, stop=None):
         'and harmless AI assistant developed by Shanghai '
         'AI Laboratory (上海人工智能实验室).')
     if deepthink:
-        meta_instruction += """You are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:
+        meta_instruction += (
+            """You are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:
 ## Deep Understanding
 Take time to fully comprehend the problem before attempting a solution. Consider:
 - What is the real question being asked?
@@ -303,7 +283,7 @@ When you're ready, present your complete solution with:
 - Key insights
 - Thorough verification
 Focus on clear, logical progression of ideas and thorough explanation of your mathematical reasoning. Provide answers in the same language as the user asking the question, repeat the final answer using a '\\boxed{}' without any units, you have [[8192]] tokens to complete the answer.
-"""
+""")  # noqa: E501
     total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
     for message in messages:
         cur_content = message['content']
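The loop visible at the end of this hunk wraps each prior turn in the same <|im_start|>/<|im_end|> markers as the system block. A self-contained sketch of the full template in the common InternLM2 chat convention; the 'robot' role name and the trailing assistant header are assumptions, since only the system part appears here:

meta_instruction = 'You are InternLM, a helpful assistant.'  # shortened stand-in
messages = [{'role': 'user', 'content': 'Hi'},
            {'role': 'robot', 'content': 'Hello! How can I help?'},
            {'role': 'user', 'content': 'What is 2 + 2?'}]

total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
for message in messages:
    cur_content = message['content']
    role = 'user' if message['role'] == 'user' else 'assistant'
    total_prompt += f'<|im_start|>{role}\n{cur_content}<|im_end|>\n'
total_prompt += '<|im_start|>assistant\n'  # the model continues from here
print(total_prompt)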