mirror of https://github.com/InternLM/InternLM
update demo
parent b49ebba597
commit c23c80a810
224
chat/web_demo.py
@@ -14,8 +14,10 @@ Please run with the command `streamlit run path/to/web_demo.py
--server.address=0.0.0.0 --server.port 7860`.
Using `python path/to/web_demo.py` may cause unknown problems.
"""

# isort: skip_file
import copy
import re
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional
@@ -25,13 +27,13 @@ import torch
from torch import nn

import transformers
from transformers.generation.utils import (LogitsProcessorList,
                                           StoppingCriteriaList)
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from transformers.utils import logging

from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip

logger = logging.get_logger(__name__)
st.set_page_config(layout='wide')


@dataclass
@@ -52,8 +54,7 @@ def generate_interactive(
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
                                                 List[int]]] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
):
@@ -75,8 +76,7 @@ def generate_interactive(
        eos_token_id = [eos_token_id]
    if additional_eos_token_id is not None:
        eos_token_id.append(additional_eos_token_id)
    has_default_max_length = kwargs.get(
        'max_length') is None and generation_config.max_length is not None
    has_default_max_length = kwargs.get('max_length') is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using 'max_length''s default \
@@ -89,8 +89,7 @@ def generate_interactive(
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + \
            input_ids_seq_length
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
        if not has_default_max_length:
            logger.warn(  # pylint: disable=W4902
                f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
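
The two hunks above only reflow the max-length bookkeeping in generate_interactive; the behaviour is unchanged: when the caller sets max_new_tokens, max_length is re-derived as max_new_tokens plus the prompt length. A minimal sketch of that rule, with standalone names that are not part of the commit:

# Sketch only (assumed helper, not in web_demo.py): how max_length is resolved.
def resolve_max_length(max_new_tokens, input_ids_seq_length, max_length=None):
    if max_new_tokens is not None:
        return max_new_tokens + input_ids_seq_length
    return max_length

assert resolve_max_length(512, input_ids_seq_length=100) == 612
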
@@ -108,13 +107,12 @@ def generate_interactive(
            f'Input length of {input_ids_string} is {input_ids_seq_length}, '
            f"but 'max_length' is set to {generation_config.max_length}. "
            'This can lead to unexpected behavior. You should consider'
            " increasing 'max_new_tokens'.")
            " increasing 'max_new_tokens'."
        )

    # 2. Set generation parameters if not already defined
    logits_processor = logits_processor if logits_processor is not None \
        else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None \
        else StoppingCriteriaList()
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
@@ -125,20 +123,18 @@ def generate_interactive(
    )

    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config,
        stopping_criteria=stopping_criteria)
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )

    if transformers.__version__ >= '4.42.0':
        logits_warper = model._get_logits_warper(generation_config,
                                                 device='cuda')
        logits_warper = model._get_logits_warper(generation_config, device='cuda')
    else:
        logits_warper = model._get_logits_warper(generation_config)

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
    while True:
        model_inputs = model.prepare_inputs_for_generation(
            input_ids, **model_kwargs)
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # forward pass to get next token
        outputs = model(
            **model_inputs,
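
The reflowed branch above keeps the existing string comparison on transformers.__version__ to decide whether _get_logits_warper is called with a device argument. For illustration only, a hedged sketch of the same gate written with packaging.version (the commit itself does not use packaging; treat this as an assumption, not the project's code):

# Illustrative sketch, not part of the diff: version-gated _get_logits_warper call.
from packaging import version

import transformers

def get_logits_warper(model, generation_config):
    if version.parse(transformers.__version__) >= version.parse('4.42.0'):
        return model._get_logits_warper(generation_config, device='cuda')
    return model._get_logits_warper(generation_config)
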
@@ -162,10 +158,8 @@ def generate_interactive(

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False)
        unfinished_sequences = unfinished_sequences.mul(
            (min(next_tokens != i for i in eos_token_id)).long())
        model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
        unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())

        output_token_ids = input_ids[0].cpu().tolist()
        output_token_ids = output_token_ids[input_length:]
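
The unfinished_sequences update above is likewise only reflowed; its intent is that a sequence stays "unfinished" until the sampled token equals one of the EOS ids (including the extra id 92542 the demo passes in). A small self-contained sketch of that check, with an assumed helper name:

import torch

# Sketch only: returns 1 while the latest token differs from every eos id, else 0.
def still_unfinished(next_tokens: torch.Tensor, eos_token_ids) -> torch.Tensor:
    mask = torch.ones_like(next_tokens, dtype=torch.bool)
    for eos in eos_token_ids:
        mask &= next_tokens != eos
    return mask.long()

print(still_unfinished(torch.tensor([92542, 7]), [2, 92542]))  # tensor([0, 1])
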
@@ -177,38 +171,38 @@ def generate_interactive(
        yield response
        # stop when each sentence is finished
        # or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(
                input_ids, scores):
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break


def on_btn_click():
    del st.session_state.messages
    del st.session_state.deepthink_messages
    del st.session_state.deep_mode


def postprocess(text):
    text = re.sub(r'\\\(|\\\)', r'$', text)
    text = re.sub(r'\\\[|\\\]', r'$$', text)
    return text


@st.cache_resource
def load_model():
    model = (AutoModelForCausalLM.from_pretrained(
        'internlm/internlm2_5-7b-chat',
        trust_remote_code=True).to(torch.bfloat16).cuda())
    tokenizer = AutoTokenizer.from_pretrained('internlm/internlm2_5-7b-chat',
                                              trust_remote_code=True)
    model_path = 'internlm/internlm2_5-7b-chat'
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(torch.bfloat16).cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer


def prepare_generation_config():
    with st.sidebar:
        max_length = st.slider('Max Length',
                               min_value=8,
                               max_value=32768,
                               value=32768)
        max_length = st.slider('Max Length', min_value=8, max_value=32768, value=32768)
        top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
        temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
        st.button('Clear Chat History', on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length,
                                         top_p=top_p,
                                         temperature=temperature)
    generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)

    return generation_config
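
The new postprocess helper above rewrites LaTeX delimiters so Streamlit's markdown renderer picks them up: \( ... \) becomes $ ... $ and \[ ... \] becomes $$ ... $$. A quick usage example (the input string is made up):

import re

def postprocess(text):
    text = re.sub(r'\\\(|\\\)', r'$', text)
    text = re.sub(r'\\\[|\\\]', r'$$', text)
    return text

print(postprocess(r'The roots are \(x_{1,2}\): \[x = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a}\]'))
# -> The roots are $x_{1,2}$: $$x = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a}$$
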
@@ -219,11 +213,73 @@ cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\
<|im_start|>assistant\n'


def combine_history(prompt):
    messages = st.session_state.messages
    meta_instruction = ('You are InternLM (书生·浦语), a helpful, honest, '
                        'and harmless AI assistant developed by Shanghai '
                        'AI Laboratory (上海人工智能实验室).')
def combine_history(prompt, deepthink=False, start=0, stop=None):
    if stop is None:
        stop = len(st.session_state.messages)
    elif stop < 0:
        stop = len(st.session_state.messages) + stop
    messages = []
    for idx in range(start, stop):
        message, deepthink_message = st.session_state.messages[idx], st.session_state.deepthink_messages[idx]
        if deepthink and deepthink_message['content'] is not None:
            messages.append(deepthink_message)
        else:
            messages.append(message)
    meta_instruction = (
        (
            """You are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:
## Deep Understanding
Take time to fully comprehend the problem before attempting a solution. Consider:
- What is the real question being asked?
- What are the given conditions and what do they tell us?
- Are there any special restrictions or assumptions?
- Which information is crucial and which is supplementary?
## Multi-angle Analysis
Before solving, conduct thorough analysis:
- What mathematical concepts and properties are involved?
- Can you recall similar classic problems or solution methods?
- Would diagrams or tables help visualize the problem?
- Are there special cases that need separate consideration?
## Systematic Thinking
Plan your solution path:
- Propose multiple possible approaches
- Analyze the feasibility and merits of each method
- Choose the most appropriate method and explain why
- Break complex problems into smaller, manageable steps
## Rigorous Proof
During the solution process:
- Provide solid justification for each step
- Include detailed proofs for key conclusions
- Pay attention to logical connections
- Be vigilant about potential oversights
## Repeated Verification
After completing your solution:
- Verify your results satisfy all conditions
- Check for overlooked special cases
- Consider if the solution can be optimized or simplified
- Review your reasoning process
Remember:
1. Take time to think thoroughly rather than rushing to an answer
2. Rigorously prove each key conclusion
3. Keep an open mind and try different approaches
4. Summarize valuable problem-solving methods
5. Maintain healthy skepticism and verify multiple times
Your response should reflect deep mathematical understanding and precise logical thinking, making your solution path and reasoning clear to others.
When you're ready, present your complete solution with:
- Clear problem understanding
- Detailed solution process
- Key insights
- Thorough verification
Focus on clear, logical progression of ideas and thorough explanation of your mathematical reasoning. Provide answers in the same language as the user asking the question, repeat the final answer using a '\\boxed{}' without any units, you have [[8192]] tokens to complete the answer.
"""
        )
        if deepthink
        else (
            'You are InternLM (书生·浦语), a helpful, honest, '
            'and harmless AI assistant developed by Shanghai '
            'AI Laboratory (上海人工智能实验室).'
        )
    )
    total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
    for message in messages:
        cur_content = message['content']
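
combine_history now builds the ChatML-style prompt from a slice [start, stop) of the stored history and swaps in the deep-thinking system prompt when deepthink=True. A reduced sketch of the prompt assembly with the history passed in explicitly (hypothetical standalone function; the real combine_history reads st.session_state and formats each stored message with the repo's own role templates):

# Sketch only: how the ChatML prompt is stitched together from history plus the new turn.
def build_prompt(history, prompt, meta_instruction):
    total = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
    for message in history:
        total += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
    total += f'<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n'
    return total

print(build_prompt([{'role': 'user', 'content': 'hi'},
                    {'role': 'assistant', 'content': 'Hello!'}],
                   'What is 2 + 2?',
                   'You are InternLM.'))
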
@@ -247,50 +303,92 @@ def main():
    user_avator = 'assets/user.png'
    robot_avator = 'assets/robot.png'

    st.title('internlm2_5-7b-chat')
    st.title('internlm3-8b-chat')

    generation_config = prepare_generation_config()

    # Initialize chat history
    if 'messages' not in st.session_state:
        st.session_state.messages = []
    if 'deepthink_messages' not in st.session_state:
        st.session_state.deepthink_messages = []
    if 'deep_mode' not in st.session_state:
        st.session_state.deep_mode = {}

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
    for idx, (message, deepthink_message) in enumerate(
        zip(st.session_state.messages, st.session_state.deepthink_messages)
    ):
        with st.chat_message(message['role'], avatar=message.get('avatar')):
            st.markdown(message['content'])
            if message['role'] == 'user':
                st.markdown(postprocess(message['content']))
            else:
                if st.button('深度思考', key=f'deep_mode_{idx}'):
                    st.session_state.deep_mode[idx] = not st.session_state.deep_mode.get(idx, False)
                if st.session_state.deep_mode.get(idx, False):
                    cols = st.columns(2)
                    with cols[0]:
                        st.markdown(postprocess(message['content']))
                    with cols[1]:
                        if deepthink_message['content'] is None:
                            real_prompt = combine_history(
                                st.session_state.deepthink_messages[idx - 1]['content'], deepthink=True, stop=idx - 1
                            )
                            message_placeholder = st.empty()
                            for cur_response in generate_interactive(
                                model=model,
                                tokenizer=tokenizer,
                                prompt=real_prompt,
                                additional_eos_token_id=92542,
                                **asdict(generation_config),
                            ):
                                # Display robot response in chat message container
                                message_placeholder.markdown(postprocess(cur_response) + '▌')
                            message_placeholder.markdown(postprocess(cur_response))
                            deepthink_message['content'] = cur_response
                        else:
                            st.markdown(postprocess(deepthink_message['content']))
                else:
                    st.markdown(postprocess(message['content']))

    # Accept user input
    if prompt := st.chat_input('What is up?'):
        # Display user message in chat message container
        with st.chat_message('user', avatar=user_avator):
            st.markdown(prompt)
            st.markdown(postprocess(prompt))
        real_prompt = combine_history(prompt)
        # Add user message to chat history
        st.session_state.messages.append({
            'role': 'user',
            'content': prompt,
            'avatar': user_avator
        })
        st.session_state.messages.append({'role': 'user', 'content': prompt, 'avatar': user_avator})
        st.session_state.deepthink_messages.append({'role': 'user', 'content': prompt, 'avatar': user_avator})

        with st.chat_message('robot', avatar=robot_avator):
            st.button('深度思考', key=f'deep_mode_{len(st.session_state.messages)}')
            message_placeholder = st.empty()
            for cur_response in generate_interactive(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=real_prompt,
                    additional_eos_token_id=92542,
                    **asdict(generation_config),
                model=model,
                tokenizer=tokenizer,
                prompt=real_prompt,
                additional_eos_token_id=92542,
                **asdict(generation_config),
            ):
                # Display robot response in chat message container
                message_placeholder.markdown(cur_response + '▌')
            message_placeholder.markdown(cur_response)
                message_placeholder.markdown(postprocess(cur_response) + '▌')
            message_placeholder.markdown(postprocess(cur_response))
        # Add robot response to chat history
        st.session_state.messages.append({
            'role': 'robot',
            'content': cur_response,  # pylint: disable=undefined-loop-variable
            'avatar': robot_avator,
        })
        st.session_state.messages.append(
            {
                'role': 'robot',
                'content': cur_response,  # pylint: disable=undefined-loop-variable
                'avatar': robot_avator,
            }
        )
        st.session_state.deepthink_messages.append(
            {
                'role': 'robot',
                'content': None,
                'avatar': robot_avator,
            }
        )
        torch.cuda.empty_cache()
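
The reworked main() keeps three parallel pieces of session state: st.session_state.messages (the normal replies), st.session_state.deepthink_messages (same indices, content left as None until 深度思考 is run for that reply), and st.session_state.deep_mode (a per-index toggle). A minimal sketch of that bookkeeping outside Streamlit, using plain lists and dicts with assumed names:

# Sketch only (not part of the commit): the parallel-history bookkeeping main() relies on.
messages = []            # normal chat history
deepthink_messages = []  # same length; robot 'content' stays None until deep mode runs
deep_mode = {}           # message index -> whether the deep-thinking view is toggled on

def add_turn(prompt, reply):
    messages.append({'role': 'user', 'content': prompt})
    deepthink_messages.append({'role': 'user', 'content': prompt})
    messages.append({'role': 'robot', 'content': reply})
    deepthink_messages.append({'role': 'robot', 'content': None})

add_turn('Solve x^2 = 4', 'x = 2 or x = -2')
idx = 1                                         # index of the robot reply
deep_mode[idx] = not deep_mode.get(idx, False)  # user presses the 深度思考 button
needs_generation = deep_mode[idx] and deepthink_messages[idx]['content'] is None
print(needs_generation)  # True -> the app would stream a deep-thinking answer here
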