pull/820/head
braisedpork1964 2025-01-15 06:28:52 +00:00
parent b9ce4fd8b0
commit 99e5cd62e0
4 changed files with 163 additions and 94 deletions

BIN
assets/web_demo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 300 KiB

View File

@ -79,6 +79,6 @@ pip install transformers>=4.48
streamlit run ./chat/web_demo.py
```
The effect is similar to below:
It supports switching between different inference modes and comparing their responses.
![demo](https://github.com/InternLM/InternLM/assets/9102141/11b60ee0-47e4-42c0-8278-3051b2f17fe4)
![demo](../assets/web_demo.png)

View File

@ -76,3 +76,7 @@ pip install streamlit
pip install transformers>=4.48
streamlit run ./web_demo.py
```
支持切换不同推理模式,并比较它们的回复
![demo](../assets/web_demo.png)

View File

@ -54,7 +54,8 @@ def generate_interactive(
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
List[int]]] = None,
additional_eos_token_id: Optional[int] = None,
**kwargs,
):
@ -67,7 +68,10 @@ def generate_interactive(
if generation_config is None:
generation_config = model.generation_config
generation_config = copy.deepcopy(generation_config)
generation_config._eos_token_tensor = generation_config.eos_token_id
model_kwargs = generation_config.update(**kwargs)
if generation_config.temperature == 0.0:
generation_config.do_sample = False
bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612
generation_config.bos_token_id,
generation_config.eos_token_id,
@ -76,7 +80,8 @@ def generate_interactive(
eos_token_id = [eos_token_id]
if additional_eos_token_id is not None:
eos_token_id.append(additional_eos_token_id)
has_default_max_length = kwargs.get('max_length') is None and generation_config.max_length is not None
has_default_max_length = kwargs.get(
'max_length') is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
f"Using 'max_length''s default \
@ -107,12 +112,12 @@ def generate_interactive(
f'Input length of {input_ids_string} is {input_ids_seq_length}, '
f"but 'max_length' is set to {generation_config.max_length}. "
'This can lead to unexpected behavior. You should consider'
" increasing 'max_new_tokens'."
)
" increasing 'max_new_tokens'.")
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList(
)
# stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
logits_processor = model._get_logits_processor(
generation_config=generation_config,
@ -122,19 +127,20 @@ def generate_interactive(
logits_processor=logits_processor,
)
stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)
# stopping_criteria = model._get_stopping_criteria(
# generation_config=generation_config, stopping_criteria=stopping_criteria
# )
if transformers.__version__ >= '4.42.0':
logits_warper = model._get_logits_warper(generation_config, device='cuda')
else:
logits_warper = model._get_logits_warper(generation_config)
# if transformers.__version__ >= '4.42.0':
# logits_warper = model._get_logits_warper(generation_config, device='cuda')
# else:
# logits_warper = model._get_logits_warper(generation_config)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
scores = None
while True:
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
model_inputs = model.prepare_inputs_for_generation(
input_ids, **model_kwargs)
# forward pass to get next token
outputs = model(
**model_inputs,
@ -147,7 +153,7 @@ def generate_interactive(
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# next_token_scores = logits_warper(input_ids, next_token_scores)
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
@ -158,8 +164,9 @@ def generate_interactive(
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())
# model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
unfinished_sequences = unfinished_sequences.mul(
(min(next_tokens != i for i in eos_token_id)).long())
output_token_ids = input_ids[0].cpu().tolist()
output_token_ids = output_token_ids[input_length:]
@ -171,38 +178,54 @@ def generate_interactive(
yield response
# stop when each sentence is finished
# or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
# if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
# break
if unfinished_sequences.max() == 0:
break
def on_btn_click():
del st.session_state.messages
del st.session_state.deepthink_messages
del st.session_state.deep_mode
def postprocess(text):
def postprocess(text, add_prefix=True, deepthink=False):
text = re.sub(r'\\\(|\\\)', r'$', text)
text = re.sub(r'\\\[|\\\]', r'$$', text)
if add_prefix:
text = (':red[[Deep Thinking]]\n\n'
if deepthink else ':blue[[Normal Response]]\n\n') + text
return text
@st.cache_resource
def load_model():
model_path = 'internlm/internlm2_5-7b-chat'
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model_path = 'internlm/internlm3-8b-instruct'
model = AutoModelForCausalLM.from_pretrained(model_path,
trust_remote_code=True).to(
torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_path,
trust_remote_code=True)
return model, tokenizer
def prepare_generation_config():
with st.sidebar:
max_length = st.slider('Max Length', min_value=8, max_value=32768, value=32768)
max_length = st.slider('Max Length',
min_value=8,
max_value=32768,
value=32768)
top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
radio = st.radio('Inference Mode',
['Normal Response', 'Deep Thinking'],
key='mode')
st.button('Clear Chat History', on_click=on_btn_click)
generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)
st.session_state['inference_mode'] = radio
generation_config = GenerationConfig(max_length=max_length,
top_p=top_p,
temperature=temperature)
return generation_config
@ -220,16 +243,21 @@ def combine_history(prompt, deepthink=False, start=0, stop=None):
stop = len(st.session_state.messages) + stop
messages = []
for idx in range(start, stop):
message, deepthink_message = st.session_state.messages[idx], st.session_state.deepthink_messages[idx]
if deepthink and deepthink_message['content'] is not None:
messages.append(deepthink_message)
message, deepthink_message = st.session_state.messages[
idx], st.session_state.deepthink_messages[idx]
if deepthink:
if deepthink_message['content'] is not None:
messages.append(deepthink_message)
else:
messages.append(message)
else:
messages.append(message)
meta_instruction = (
'You are InternLM (书生·浦语), a helpful, honest, '
'and harmless AI assistant developed by Shanghai '
'AI Laboratory (上海人工智能实验室).'
)
if message['content'] is not None:
messages.append(message)
else:
messages.append(deepthink_message)
meta_instruction = ('You are InternLM (书生·浦语), a helpful, honest, '
'and harmless AI assistant developed by Shanghai '
'AI Laboratory (上海人工智能实验室).')
if deepthink:
meta_instruction += """\nYou are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:
## Deep Understanding
@ -299,93 +327,130 @@ def main():
user_avator = 'assets/user.png'
robot_avator = 'assets/robot.png'
st.title('internlm3-8b-chat')
st.title('InternLM3-8B-Instruct')
generation_config = prepare_generation_config()
def render_message(msg, msg_idx, deepthink):
if msg['content'] is None:
real_prompt = combine_history(
st.session_state.messages[msg_idx - 1]['content'],
deepthink=deepthink,
stop=msg_idx - 1)
placeholder = st.empty()
for cur_response in generate_interactive(
model=model,
tokenizer=tokenizer,
prompt=real_prompt,
additional_eos_token_id=92542,
**asdict(generation_config),
):
placeholder.markdown(
postprocess(cur_response, deepthink=deepthink) + '')
placeholder.markdown(postprocess(cur_response,
deepthink=deepthink))
msg['content'] = cur_response
torch.cuda.empty_cache()
else:
st.markdown(postprocess(msg['content'], deepthink=deepthink))
# Initialize chat history
if 'messages' not in st.session_state:
st.session_state.messages = []
if 'deepthink_messages' not in st.session_state:
st.session_state.deepthink_messages = []
if 'deep_mode' not in st.session_state:
st.session_state.deep_mode = {}
# Display chat messages from history on app rerun
for idx, (message, deepthink_message) in enumerate(
zip(st.session_state.messages, st.session_state.deepthink_messages)
):
zip(st.session_state.messages,
st.session_state.deepthink_messages)):
with st.chat_message(message['role'], avatar=message.get('avatar')):
if message['role'] == 'user':
st.markdown(postprocess(message['content']))
st.markdown(postprocess(message['content'], add_prefix=False))
else:
if st.button('深度思考', key=f'deep_mode_{idx}'):
st.session_state.deep_mode[idx] = not st.session_state.deep_mode.get(idx, False)
if st.session_state.deep_mode.get(idx, False):
if st.toggle('compare', key=f'compare_{idx}'):
cols = st.columns(2)
with cols[0]:
st.markdown(postprocess(message['content']))
with cols[1]:
if deepthink_message['content'] is None:
real_prompt = combine_history(
st.session_state.deepthink_messages[idx - 1]['content'], deepthink=True, stop=idx - 1
)
message_placeholder = st.empty()
for cur_response in generate_interactive(
model=model,
tokenizer=tokenizer,
prompt=real_prompt,
additional_eos_token_id=92542,
**asdict(generation_config),
):
# Display robot response in chat message container
message_placeholder.markdown(postprocess(cur_response) + '')
message_placeholder.markdown(postprocess(cur_response))
deepthink_message['content'] = cur_response
torch.cuda.empty_cache()
else:
st.markdown(postprocess(deepthink_message['content']))
if st.session_state['inference_mode'] == 'Deep Thinking':
with cols[1]:
render_message(deepthink_message, idx, True)
with cols[0]:
render_message(message, idx, False)
else:
with cols[0]:
render_message(message, idx, False)
with cols[1]:
render_message(deepthink_message, idx, True)
else:
st.markdown(postprocess(message['content']))
if st.session_state['inference_mode'] == 'Deep Thinking':
if deepthink_message['content'] is not None:
st.markdown(
postprocess(deepthink_message['content'],
deepthink=True))
else:
st.markdown(postprocess(message['content']))
else:
if message['content'] is not None:
st.markdown(postprocess(message['content']))
else:
st.markdown(
postprocess(deepthink_message['content'],
deepthink=True))
# Accept user input
if prompt := st.chat_input('What is up?'):
# Display user message in chat message container
with st.chat_message('user', avatar=user_avator):
st.markdown(postprocess(prompt))
real_prompt = combine_history(prompt)
st.markdown(postprocess(prompt, add_prefix=False))
real_prompt = combine_history(
prompt,
deepthink=st.session_state['inference_mode'] == 'Deep Thinking')
# Add user message to chat history
st.session_state.messages.append({'role': 'user', 'content': prompt, 'avatar': user_avator})
st.session_state.deepthink_messages.append({'role': 'user', 'content': prompt, 'avatar': user_avator})
st.session_state.messages.append({
'role': 'user',
'content': prompt,
'avatar': user_avator
})
st.session_state.deepthink_messages.append({
'role': 'user',
'content': prompt,
'avatar': user_avator
})
with st.chat_message('robot', avatar=robot_avator):
st.button('深度思考', key=f'deep_mode_{len(st.session_state.messages)}')
st.toggle('compare',
key=f'compare_{len(st.session_state.messages)}')
message_placeholder = st.empty()
for cur_response in generate_interactive(
model=model,
tokenizer=tokenizer,
prompt=real_prompt,
additional_eos_token_id=92542,
**asdict(generation_config),
model=model,
tokenizer=tokenizer,
prompt=real_prompt,
additional_eos_token_id=92542,
**asdict(generation_config),
):
# Display robot response in chat message container
message_placeholder.markdown(postprocess(cur_response) + '')
message_placeholder.markdown(postprocess(cur_response))
message_placeholder.markdown(
postprocess(cur_response,
deepthink=st.session_state['inference_mode'] ==
'Deep Thinking') + '')
message_placeholder.markdown(
postprocess(cur_response,
deepthink=st.session_state['inference_mode'] ==
'Deep Thinking'))
# Add robot response to chat history
st.session_state.messages.append(
{
'role': 'robot',
'content': cur_response, # pylint: disable=undefined-loop-variable
'avatar': robot_avator,
}
)
st.session_state.deepthink_messages.append(
{
'role': 'robot',
'content': None,
'avatar': robot_avator,
}
)
response, deepthink_response = ((None, cur_response)
if st.session_state['inference_mode']
== 'Deep Thinking' else
(cur_response, None))
st.session_state.messages.append({
'role': 'robot',
'content': response, # pylint: disable=undefined-loop-variable
'avatar': robot_avator,
})
st.session_state.deepthink_messages.append({
'role': 'robot',
'content': deepthink_response,
'avatar': robot_avator,
})
torch.cuda.empty_cache()