[Feat] Add deep thinking demo (#820)

pull/824/head
BraisedPork 2025-01-16 00:02:33 +08:00 committed by GitHub
parent b49ebba597
commit 051011405f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 196 additions and 52 deletions

View File

@ -2,7 +2,7 @@
English | [简体中文](./README_zh-CN.md)
This document briefly shows how to use [Transformers](#import-from-transformers), [ModelScope](#import-from-modelscope), and [Web demos](#dialogue) to conduct inference with InternLM2.5-Chat.
This document briefly shows how to use [Transformers](#import-from-transformers), [ModelScope](#import-from-modelscope), and [Web demos](#dialogue) to conduct inference with InternLM3-Instruct.
You can also know more about the [chatml format](./chat_format.md) and how to use [LMDeploy for inference and model serving](./lmdeploy.md).
@ -79,6 +79,6 @@ pip install transformers>=4.48
streamlit run ./chat/web_demo.py
```
The effect is similar to below:
It supports switching between different inference modes and comparing their responses.
![demo](https://github.com/InternLM/InternLM/assets/9102141/11b60ee0-47e4-42c0-8278-3051b2f17fe4)
![demo](https://github.com/user-attachments/assets/4953befa-343f-499d-b289-048d982439f3)

View File

@ -3,9 +3,9 @@
[English](./README.md) | 简体中文
本文介绍采用 [Transformers](#import-from-transformers)、[ModelScope](#import-from-modelscope)、[Web demos](#dialogue)
对 InternLM2.5-Chat 进行推理。
对 InternLM3-Instruct 进行推理。
你还可以进一步了解 InternLM2.5-Chat 采用的[对话格式](./chat_format_zh-CN.md),以及如何[用 LMDeploy 进行推理或部署服务](./lmdeploy_zh-CN.md),或者尝试用 [OpenAOE](./openaoe.md) 与多个模型对话。
你还可以进一步了解 InternLM3-Instruct 采用的[对话格式](./chat_format_zh-CN.md),以及如何[用 LMDeploy 进行推理或部署服务](./lmdeploy_zh-CN.md),或者尝试用 [OpenAOE](./openaoe.md) 与多个模型对话。
## 通过 Transformers 加载
@ -39,7 +39,7 @@ response = tokenizer.batch_decode(generated_ids)[0]
### 通过 ModelScope 加载
通过以下的代码从 ModelScope 加载 InternLM2.5-Chat 模型 (可修改模型名称替换不同的模型)
通过以下的代码从 ModelScope 加载 InternLM3-Instruct 模型 (可修改模型名称替换不同的模型)
```python
import torch
@ -76,3 +76,7 @@ pip install streamlit
pip install transformers>=4.48
streamlit run ./web_demo.py
```
支持切换不同推理模式,并比较它们的回复
![demo](https://github.com/user-attachments/assets/952e250d-22a6-4544-b8e3-9c21c746d3c7)

View File

@ -14,8 +14,10 @@ Please run with the command `streamlit run path/to/web_demo.py
--server.address=0.0.0.0 --server.port 7860`.
Using `python path/to/web_demo.py` may cause unknown problems.
"""
# isort: skip_file
import copy
import re
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional
@ -24,14 +26,13 @@ import streamlit as st
import torch
from torch import nn
import transformers
from transformers.generation.utils import (LogitsProcessorList,
StoppingCriteriaList)
from transformers.generation.utils import LogitsProcessorList
from transformers.utils import logging
from transformers import AutoTokenizer, AutoModelForCausalLM # isort: skip
logger = logging.get_logger(__name__)
st.set_page_config(layout='wide')
@dataclass
@ -51,7 +52,6 @@ def generate_interactive(
prompt,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
List[int]]] = None,
additional_eos_token_id: Optional[int] = None,
@ -66,11 +66,11 @@ def generate_interactive(
if generation_config is None:
generation_config = model.generation_config
generation_config = copy.deepcopy(generation_config)
generation_config._eos_token_tensor = generation_config.eos_token_id
model_kwargs = generation_config.update(**kwargs)
bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612
generation_config.bos_token_id,
generation_config.eos_token_id,
)
if generation_config.temperature == 0.0:
generation_config.do_sample = False
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if additional_eos_token_id is not None:
@ -89,8 +89,8 @@ def generate_interactive(
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + \
input_ids_seq_length
generation_config.max_length = (generation_config.max_new_tokens +
input_ids_seq_length)
if not has_default_max_length:
logger.warn( # pylint: disable=W4902
f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
@ -111,11 +111,8 @@ def generate_interactive(
" increasing 'max_new_tokens'.")
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None \
else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None \
else StoppingCriteriaList()
logits_processor = (logits_processor if logits_processor is not None else
LogitsProcessorList())
logits_processor = model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
@ -123,19 +120,7 @@ def generate_interactive(
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)
stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria)
if transformers.__version__ >= '4.42.0':
logits_warper = model._get_logits_warper(generation_config,
device='cuda')
else:
logits_warper = model._get_logits_warper(generation_config)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
scores = None
while True:
model_inputs = model.prepare_inputs_for_generation(
input_ids, **model_kwargs)
@ -151,7 +136,6 @@ def generate_interactive(
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
@ -162,8 +146,6 @@ def generate_interactive(
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = model._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=False)
unfinished_sequences = unfinished_sequences.mul(
(min(next_tokens != i for i in eos_token_id)).long())
@ -177,21 +159,31 @@ def generate_interactive(
yield response
# stop when each sentence is finished
# or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(
input_ids, scores):
if unfinished_sequences.max() == 0:
break
def on_btn_click():
del st.session_state.messages
del st.session_state.deepthink_messages
def postprocess(text, add_prefix=True, deepthink=False):
text = re.sub(r'\\\(|\\\)', r'$', text)
text = re.sub(r'\\\[|\\\]', r'$$', text)
if add_prefix:
text = (':red[[Deep Thinking]]\n\n'
if deepthink else ':blue[[Normal Response]]\n\n') + text
return text
@st.cache_resource
def load_model():
model = (AutoModelForCausalLM.from_pretrained(
'internlm/internlm2_5-7b-chat',
trust_remote_code=True).to(torch.bfloat16).cuda())
tokenizer = AutoTokenizer.from_pretrained('internlm/internlm2_5-7b-chat',
model_path = 'internlm/internlm3-8b-instruct'
model = AutoModelForCausalLM.from_pretrained(model_path,
trust_remote_code=True).to(
torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_path,
trust_remote_code=True)
return model, tokenizer
@ -204,8 +196,12 @@ def prepare_generation_config():
value=32768)
top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
radio = st.radio('Inference Mode',
['Normal Response', 'Deep Thinking'],
key='mode')
st.button('Clear Chat History', on_click=on_btn_click)
st.session_state['inference_mode'] = radio
generation_config = GenerationConfig(max_length=max_length,
top_p=top_p,
temperature=temperature)
@ -219,11 +215,75 @@ cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\
<|im_start|>assistant\n'
def combine_history(prompt):
messages = st.session_state.messages
def combine_history(prompt, deepthink=False, start=0, stop=None):
if stop is None:
stop = len(st.session_state.messages)
elif stop < 0:
stop = len(st.session_state.messages) + stop
messages = []
for idx in range(start, stop):
message, deepthink_message = st.session_state.messages[
idx], st.session_state.deepthink_messages[idx]
if deepthink:
if deepthink_message['content'] is not None:
messages.append(deepthink_message)
else:
messages.append(message)
else:
if message['content'] is not None:
messages.append(message)
else:
messages.append(deepthink_message)
meta_instruction = ('You are InternLM (书生·浦语), a helpful, honest, '
'and harmless AI assistant developed by Shanghai '
'AI Laboratory (上海人工智能实验室).')
if deepthink:
meta_instruction += (
"""You are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:
## Deep Understanding
Take time to fully comprehend the problem before attempting a solution. Consider:
- What is the real question being asked?
- What are the given conditions and what do they tell us?
- Are there any special restrictions or assumptions?
- Which information is crucial and which is supplementary?
## Multi-angle Analysis
Before solving, conduct thorough analysis:
- What mathematical concepts and properties are involved?
- Can you recall similar classic problems or solution methods?
- Would diagrams or tables help visualize the problem?
- Are there special cases that need separate consideration?
## Systematic Thinking
Plan your solution path:
- Propose multiple possible approaches
- Analyze the feasibility and merits of each method
- Choose the most appropriate method and explain why
- Break complex problems into smaller, manageable steps
## Rigorous Proof
During the solution process:
- Provide solid justification for each step
- Include detailed proofs for key conclusions
- Pay attention to logical connections
- Be vigilant about potential oversights
## Repeated Verification
After completing your solution:
- Verify your results satisfy all conditions
- Check for overlooked special cases
- Consider if the solution can be optimized or simplified
- Review your reasoning process
Remember:
1. Take time to think thoroughly rather than rushing to an answer
2. Rigorously prove each key conclusion
3. Keep an open mind and try different approaches
4. Summarize valuable problem-solving methods
5. Maintain healthy skepticism and verify multiple times
Your response should reflect deep mathematical understanding and precise logical thinking, making your solution path and reasoning clear to others.
When you're ready, present your complete solution with:
- Clear problem understanding
- Detailed solution process
- Key insights
- Thorough verification
Focus on clear, logical progression of ideas and thorough explanation of your mathematical reasoning. Provide answers in the same language as the user asking the question, repeat the final answer using a '\\boxed{}' without any units, you have [[8192]] tokens to complete the answer.
""") # noqa: E501
total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
for message in messages:
cur_content = message['content']
@ -247,33 +307,98 @@ def main():
user_avator = 'assets/user.png'
robot_avator = 'assets/robot.png'
st.title('internlm2_5-7b-chat')
st.title('InternLM3-8B-Instruct')
generation_config = prepare_generation_config()
def render_message(msg, msg_idx, deepthink):
if msg['content'] is None:
real_prompt = combine_history(
st.session_state.messages[msg_idx - 1]['content'],
deepthink=deepthink,
stop=msg_idx - 1)
placeholder = st.empty()
for cur_response in generate_interactive(
model=model,
tokenizer=tokenizer,
prompt=real_prompt,
additional_eos_token_id=92542,
**asdict(generation_config),
):
placeholder.markdown(
postprocess(cur_response, deepthink=deepthink) + '')
placeholder.markdown(postprocess(cur_response,
deepthink=deepthink))
msg['content'] = cur_response
torch.cuda.empty_cache()
else:
st.markdown(postprocess(msg['content'], deepthink=deepthink))
# Initialize chat history
if 'messages' not in st.session_state:
st.session_state.messages = []
if 'deepthink_messages' not in st.session_state:
st.session_state.deepthink_messages = []
# Display chat messages from history on app rerun
for message in st.session_state.messages:
for idx, (message, deepthink_message) in enumerate(
zip(st.session_state.messages,
st.session_state.deepthink_messages)):
with st.chat_message(message['role'], avatar=message.get('avatar')):
st.markdown(message['content'])
if message['role'] == 'user':
st.markdown(postprocess(message['content'], add_prefix=False))
else:
if st.toggle('compare', key=f'compare_{idx}'):
cols = st.columns(2)
if st.session_state['inference_mode'] == 'Deep Thinking':
with cols[1]:
render_message(deepthink_message, idx, True)
with cols[0]:
render_message(message, idx, False)
else:
with cols[0]:
render_message(message, idx, False)
with cols[1]:
render_message(deepthink_message, idx, True)
else:
if st.session_state['inference_mode'] == 'Deep Thinking':
if deepthink_message['content'] is not None:
st.markdown(
postprocess(deepthink_message['content'],
deepthink=True))
else:
st.markdown(postprocess(message['content']))
else:
if message['content'] is not None:
st.markdown(postprocess(message['content']))
else:
st.markdown(
postprocess(deepthink_message['content'],
deepthink=True))
# Accept user input
if prompt := st.chat_input('What is up?'):
# Display user message in chat message container
with st.chat_message('user', avatar=user_avator):
st.markdown(prompt)
real_prompt = combine_history(prompt)
st.markdown(postprocess(prompt, add_prefix=False))
real_prompt = combine_history(
prompt,
deepthink=st.session_state['inference_mode'] == 'Deep Thinking')
# Add user message to chat history
st.session_state.messages.append({
'role': 'user',
'content': prompt,
'avatar': user_avator
})
st.session_state.deepthink_messages.append({
'role': 'user',
'content': prompt,
'avatar': user_avator
})
with st.chat_message('robot', avatar=robot_avator):
st.toggle('compare',
key=f'compare_{len(st.session_state.messages)}')
message_placeholder = st.empty()
for cur_response in generate_interactive(
model=model,
@ -283,12 +408,27 @@ def main():
**asdict(generation_config),
):
# Display robot response in chat message container
message_placeholder.markdown(cur_response + '')
message_placeholder.markdown(cur_response)
message_placeholder.markdown(
postprocess(cur_response,
deepthink=st.session_state['inference_mode'] ==
'Deep Thinking') + '')
message_placeholder.markdown(
postprocess(cur_response,
deepthink=st.session_state['inference_mode'] ==
'Deep Thinking'))
# Add robot response to chat history
response, deepthink_response = ((None, cur_response)
if st.session_state['inference_mode']
== 'Deep Thinking' else
(cur_response, None))
st.session_state.messages.append({
'role': 'robot',
'content': cur_response, # pylint: disable=undefined-loop-variable
'content': response, # pylint: disable=undefined-loop-variable
'avatar': robot_avator,
})
st.session_state.deepthink_messages.append({
'role': 'robot',
'content': deepthink_response,
'avatar': robot_avator,
})
torch.cuda.empty_cache()