mirror of https://github.com/InternLM/InternLM

commit 99e5cd62e0 ("update")
parent b9ce4fd8b0
Binary image file added (300 KiB; contents not shown).
@@ -79,6 +79,6 @@ pip install transformers>=4.48
 streamlit run ./chat/web_demo.py
 ```

-The effect is similar to below:
+It supports switching between different inference modes and comparing their responses.

-![demo](https://github.com/InternLM/InternLM/assets/9102141/11b60ee0-47e4-42c0-8278-3051b2f17fe4)
+![demo](./assets/internlm3_8b_instruct_web_demo.png)
@@ -76,3 +76,7 @@ pip install streamlit
 pip install transformers>=4.48
 streamlit run ./web_demo.py
 ```
+
+支持切换不同推理模式,并比较它们的回复
+
+![demo](./assets/internlm3_8b_instruct_web_demo.png)
chat/web_demo.py (219 lines changed)
@@ -54,7 +54,8 @@ def generate_interactive(
     generation_config: Optional[GenerationConfig] = None,
     logits_processor: Optional[LogitsProcessorList] = None,
     stopping_criteria: Optional[StoppingCriteriaList] = None,
-    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
+                                                List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
     **kwargs,
 ):
@@ -67,7 +68,10 @@ def generate_interactive(
     if generation_config is None:
         generation_config = model.generation_config
     generation_config = copy.deepcopy(generation_config)
+    generation_config._eos_token_tensor = generation_config.eos_token_id
     model_kwargs = generation_config.update(**kwargs)
+    if generation_config.temperature == 0.0:
+        generation_config.do_sample = False
     bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
         generation_config.bos_token_id,
         generation_config.eos_token_id,
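
Note: the two added lines adapt the hand-rolled decode loop to newer `transformers` releases. `_eos_token_tensor` is a private attribute the library now consults during generation, and a temperature of exactly 0.0 is mapped to greedy decoding instead of being handed to the sampler. A minimal sketch of the same normalization in isolation (the helper name is illustrative and the reliance on the private attribute mirrors the diff, so it may break in other library versions):

```python
from transformers import GenerationConfig


def normalize_generation_config(config: GenerationConfig) -> GenerationConfig:
    # Assumption: newer transformers expects this private attribute during
    # decoding; mirroring eos_token_id onto it is exactly what the hunk does.
    config._eos_token_tensor = config.eos_token_id
    # A temperature of 0.0 would otherwise be used as a divisor when scaling
    # logits, so fall back to greedy decoding instead of sampling.
    if config.temperature == 0.0:
        config.do_sample = False
    return config
```
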
@@ -76,7 +80,8 @@ def generate_interactive(
         eos_token_id = [eos_token_id]
     if additional_eos_token_id is not None:
         eos_token_id.append(additional_eos_token_id)
-    has_default_max_length = kwargs.get('max_length') is None and generation_config.max_length is not None
+    has_default_max_length = kwargs.get(
+        'max_length') is None and generation_config.max_length is not None
     if has_default_max_length and generation_config.max_new_tokens is None:
         warnings.warn(
             f"Using 'max_length''s default \
@@ -107,12 +112,12 @@ def generate_interactive(
             f'Input length of {input_ids_string} is {input_ids_seq_length}, '
             f"but 'max_length' is set to {generation_config.max_length}. "
             'This can lead to unexpected behavior. You should consider'
-            " increasing 'max_new_tokens'."
-        )
+            " increasing 'max_new_tokens'.")

     # 2. Set generation parameters if not already defined
-    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList(
+    )
+    # stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

     logits_processor = model._get_logits_processor(
         generation_config=generation_config,
@@ -122,19 +127,20 @@ def generate_interactive(
         logits_processor=logits_processor,
     )

-    stopping_criteria = model._get_stopping_criteria(
-        generation_config=generation_config, stopping_criteria=stopping_criteria
-    )
+    # stopping_criteria = model._get_stopping_criteria(
+    #     generation_config=generation_config, stopping_criteria=stopping_criteria
+    # )

-    if transformers.__version__ >= '4.42.0':
-        logits_warper = model._get_logits_warper(generation_config, device='cuda')
-    else:
-        logits_warper = model._get_logits_warper(generation_config)
+    # if transformers.__version__ >= '4.42.0':
+    #     logits_warper = model._get_logits_warper(generation_config, device='cuda')
+    # else:
+    #     logits_warper = model._get_logits_warper(generation_config)

     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
     scores = None
     while True:
-        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
+        model_inputs = model.prepare_inputs_for_generation(
+            input_ids, **model_kwargs)
         # forward pass to get next token
         outputs = model(
             **model_inputs,
@@ -147,7 +153,7 @@ def generate_interactive(

         # pre-process distribution
         next_token_scores = logits_processor(input_ids, next_token_logits)
-        next_token_scores = logits_warper(input_ids, next_token_scores)
+        # next_token_scores = logits_warper(input_ids, next_token_scores)

         # sample
         probs = nn.functional.softmax(next_token_scores, dim=-1)
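
Note: the commit disables the `_get_stopping_criteria` / `_get_logits_warper` calls and the later `logits_warper(...)` application rather than replacing them, presumably because these private helpers changed or were removed in recent `transformers` versions; with them gone, the loop relies only on the EOS check further down. If an explicit warping step were still wanted, the public warper classes could be assembled by hand. This is a hedged sketch, not part of the commit:

```python
# Sketch only: builds temperature / top-p warpers from public transformers
# classes, which could stand in for the removed _get_logits_warper call.
from transformers import (LogitsProcessorList, TemperatureLogitsWarper,
                          TopPLogitsWarper)


def build_sampling_warpers(temperature: float, top_p: float) -> LogitsProcessorList:
    warpers = LogitsProcessorList()
    if temperature > 0.0 and temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(temperature))
    if top_p < 1.0:
        warpers.append(TopPLogitsWarper(top_p))
    return warpers
```
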
@@ -158,8 +164,9 @@ def generate_interactive(

         # update generated ids, model inputs, and length for next step
         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-        model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
-        unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())
+        # model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
+        unfinished_sequences = unfinished_sequences.mul(
+            (min(next_tokens != i for i in eos_token_id)).long())

         output_token_ids = input_ids[0].cpu().tolist()
         output_token_ids = output_token_ids[input_length:]
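
Note: the reflowed `unfinished_sequences` update keeps a batch element alive only while its newest token differs from every EOS id; the `min(...)` over boolean tensors acts as an elementwise AND. A more explicit equivalent, shown only as a sketch with an illustrative helper name:

```python
import torch


def still_unfinished(next_tokens: torch.Tensor, eos_token_ids: list) -> torch.Tensor:
    # 1 for sequences whose latest token is not any EOS id, 0 otherwise;
    # equivalent in spirit to min(next_tokens != i for i in eos_token_ids).
    eos = torch.tensor(eos_token_ids, device=next_tokens.device)
    return (~torch.isin(next_tokens, eos)).long()
```

Used in place of the generator expression it would read `unfinished_sequences = unfinished_sequences.mul(still_unfinished(next_tokens, eos_token_id))`.
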
@@ -171,38 +178,54 @@ def generate_interactive(
         yield response
         # stop when each sentence is finished
         # or if we exceed the maximum length
-        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+        # if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+        #     break
+        if unfinished_sequences.max() == 0:
             break


 def on_btn_click():
     del st.session_state.messages
     del st.session_state.deepthink_messages
-    del st.session_state.deep_mode


-def postprocess(text):
+def postprocess(text, add_prefix=True, deepthink=False):
     text = re.sub(r'\\\(|\\\)', r'$', text)
     text = re.sub(r'\\\[|\\\]', r'$$', text)
+    if add_prefix:
+        text = (':red[[Deep Thinking]]\n\n'
+                if deepthink else ':blue[[Normal Response]]\n\n') + text
     return text


 @st.cache_resource
 def load_model():
-    model_path = 'internlm/internlm2_5-7b-chat'
-    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(torch.bfloat16).cuda()
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    model_path = 'internlm/internlm3-8b-instruct'
+    model = AutoModelForCausalLM.from_pretrained(model_path,
+                                                 trust_remote_code=True).to(
+                                                     torch.bfloat16).cuda()
+    tokenizer = AutoTokenizer.from_pretrained(model_path,
+                                              trust_remote_code=True)
     return model, tokenizer


 def prepare_generation_config():
     with st.sidebar:
-        max_length = st.slider('Max Length', min_value=8, max_value=32768, value=32768)
+        max_length = st.slider('Max Length',
+                               min_value=8,
+                               max_value=32768,
+                               value=32768)
         top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
         temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
+        radio = st.radio('Inference Mode',
+                         ['Normal Response', 'Deep Thinking'],
+                         key='mode')
         st.button('Clear Chat History', on_click=on_btn_click)

-    generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)
+    st.session_state['inference_mode'] = radio
+    generation_config = GenerationConfig(max_length=max_length,
+                                         top_p=top_p,
+                                         temperature=temperature)

     return generation_config

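
Note: the sidebar now carries an 'Inference Mode' radio whose value is mirrored into `st.session_state['inference_mode']`, the default checkpoint moves to `internlm/internlm3-8b-instruct`, and `postprocess` prepends a colored `:red[[Deep Thinking]]` / `:blue[[Normal Response]]` label before fixing LaTeX delimiters. A self-contained sketch of the widget pattern (file name and display strings are illustrative, not from the commit):

```python
# Minimal Streamlit sketch of the mode selector introduced above; run with
# `streamlit run mode_selector_sketch.py` (hypothetical file name).
import streamlit as st

with st.sidebar:
    mode = st.radio('Inference Mode',
                    ['Normal Response', 'Deep Thinking'],
                    key='mode')

# Mirror the widget value so the rest of the script can branch on it,
# as prepare_generation_config() does in the diff.
st.session_state['inference_mode'] = mode

if st.session_state['inference_mode'] == 'Deep Thinking':
    st.markdown(':red[[Deep Thinking]]\n\nThe math-competition system prompt would be appended here.')
else:
    st.markdown(':blue[[Normal Response]]\n\nThe default system prompt would be used here.')
```
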
@@ -220,16 +243,21 @@ def combine_history(prompt, deepthink=False, start=0, stop=None):
         stop = len(st.session_state.messages) + stop
     messages = []
     for idx in range(start, stop):
-        message, deepthink_message = st.session_state.messages[idx], st.session_state.deepthink_messages[idx]
-        if deepthink and deepthink_message['content'] is not None:
-            messages.append(deepthink_message)
-        else:
-            messages.append(message)
-    meta_instruction = (
-        'You are InternLM (书生·浦语), a helpful, honest, '
-        'and harmless AI assistant developed by Shanghai '
-        'AI Laboratory (上海人工智能实验室).'
-    )
+        message, deepthink_message = st.session_state.messages[
+            idx], st.session_state.deepthink_messages[idx]
+        if deepthink:
+            if deepthink_message['content'] is not None:
+                messages.append(deepthink_message)
+            else:
+                messages.append(message)
+        else:
+            if message['content'] is not None:
+                messages.append(message)
+            else:
+                messages.append(deepthink_message)
+    meta_instruction = ('You are InternLM (书生·浦语), a helpful, honest, '
+                        'and harmless AI assistant developed by Shanghai '
+                        'AI Laboratory (上海人工智能实验室).')
     if deepthink:
         meta_instruction += """\nYou are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:
 ## Deep Understanding
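
Note: `combine_history` now keeps two parallel histories and, per turn, prefers the channel that matches the requested mode but falls back to the other one when that turn was never generated (its `content` is `None`). A plain-Python sketch of that selection rule (helper name illustrative):

```python
def pick_turn(message: dict, deepthink_message: dict, deepthink: bool) -> dict:
    # Prefer the requested channel; fall back to the other one if that side of
    # the turn was never generated (content is None), mirroring the hunk above.
    primary, fallback = ((deepthink_message, message) if deepthink else
                         (message, deepthink_message))
    return primary if primary['content'] is not None else fallback


# Example: this turn was answered in normal mode only, so `message` wins even
# when deep-thinking history is requested.
message = {'role': 'robot', 'content': 'normal answer'}
deepthink_message = {'role': 'robot', 'content': None}
assert pick_turn(message, deepthink_message, deepthink=True) is message
```
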
@@ -299,67 +327,98 @@ def main():
     user_avator = 'assets/user.png'
     robot_avator = 'assets/robot.png'

-    st.title('internlm3-8b-chat')
+    st.title('InternLM3-8B-Instruct')

     generation_config = prepare_generation_config()

+    def render_message(msg, msg_idx, deepthink):
+        if msg['content'] is None:
+            real_prompt = combine_history(
+                st.session_state.messages[msg_idx - 1]['content'],
+                deepthink=deepthink,
+                stop=msg_idx - 1)
+            placeholder = st.empty()
+            for cur_response in generate_interactive(
+                    model=model,
+                    tokenizer=tokenizer,
+                    prompt=real_prompt,
+                    additional_eos_token_id=92542,
+                    **asdict(generation_config),
+            ):
+                placeholder.markdown(
+                    postprocess(cur_response, deepthink=deepthink) + '▌')
+            placeholder.markdown(postprocess(cur_response,
+                                             deepthink=deepthink))
+            msg['content'] = cur_response
+            torch.cuda.empty_cache()
+        else:
+            st.markdown(postprocess(msg['content'], deepthink=deepthink))
+
     # Initialize chat history
     if 'messages' not in st.session_state:
         st.session_state.messages = []
     if 'deepthink_messages' not in st.session_state:
         st.session_state.deepthink_messages = []
-    if 'deep_mode' not in st.session_state:
-        st.session_state.deep_mode = {}

     # Display chat messages from history on app rerun
     for idx, (message, deepthink_message) in enumerate(
-            zip(st.session_state.messages, st.session_state.deepthink_messages)
-    ):
+            zip(st.session_state.messages,
+                st.session_state.deepthink_messages)):
         with st.chat_message(message['role'], avatar=message.get('avatar')):
             if message['role'] == 'user':
-                st.markdown(postprocess(message['content']))
+                st.markdown(postprocess(message['content'], add_prefix=False))
             else:
-                if st.button('深度思考', key=f'deep_mode_{idx}'):
-                    st.session_state.deep_mode[idx] = not st.session_state.deep_mode.get(idx, False)
-                if st.session_state.deep_mode.get(idx, False):
+                if st.toggle('compare', key=f'compare_{idx}'):
                     cols = st.columns(2)
-                    with cols[0]:
-                        st.markdown(postprocess(message['content']))
-                    with cols[1]:
-                        if deepthink_message['content'] is None:
-                            real_prompt = combine_history(
-                                st.session_state.deepthink_messages[idx - 1]['content'], deepthink=True, stop=idx - 1
-                            )
-                            message_placeholder = st.empty()
-                            for cur_response in generate_interactive(
-                                model=model,
-                                tokenizer=tokenizer,
-                                prompt=real_prompt,
-                                additional_eos_token_id=92542,
-                                **asdict(generation_config),
-                            ):
-                                # Display robot response in chat message container
-                                message_placeholder.markdown(postprocess(cur_response) + '▌')
-                            message_placeholder.markdown(postprocess(cur_response))
-                            deepthink_message['content'] = cur_response
-                            torch.cuda.empty_cache()
-                        else:
-                            st.markdown(postprocess(deepthink_message['content']))
+                    if st.session_state['inference_mode'] == 'Deep Thinking':
+                        with cols[1]:
+                            render_message(deepthink_message, idx, True)
+                        with cols[0]:
+                            render_message(message, idx, False)
+                    else:
+                        with cols[0]:
+                            render_message(message, idx, False)
+                        with cols[1]:
+                            render_message(deepthink_message, idx, True)
                 else:
-                    st.markdown(postprocess(message['content']))
+                    if st.session_state['inference_mode'] == 'Deep Thinking':
+                        if deepthink_message['content'] is not None:
+                            st.markdown(
+                                postprocess(deepthink_message['content'],
+                                            deepthink=True))
+                        else:
+                            st.markdown(postprocess(message['content']))
+                    else:
+                        if message['content'] is not None:
+                            st.markdown(postprocess(message['content']))
+                        else:
+                            st.markdown(
+                                postprocess(deepthink_message['content'],
+                                            deepthink=True))

     # Accept user input
     if prompt := st.chat_input('What is up?'):
         # Display user message in chat message container
         with st.chat_message('user', avatar=user_avator):
-            st.markdown(postprocess(prompt))
-        real_prompt = combine_history(prompt)
+            st.markdown(postprocess(prompt, add_prefix=False))
+        real_prompt = combine_history(
+            prompt,
+            deepthink=st.session_state['inference_mode'] == 'Deep Thinking')
         # Add user message to chat history
-        st.session_state.messages.append({'role': 'user', 'content': prompt, 'avatar': user_avator})
-        st.session_state.deepthink_messages.append({'role': 'user', 'content': prompt, 'avatar': user_avator})
+        st.session_state.messages.append({
+            'role': 'user',
+            'content': prompt,
+            'avatar': user_avator
+        })
+        st.session_state.deepthink_messages.append({
+            'role': 'user',
+            'content': prompt,
+            'avatar': user_avator
+        })

         with st.chat_message('robot', avatar=robot_avator):
-            st.button('深度思考', key=f'deep_mode_{len(st.session_state.messages)}')
+            st.toggle('compare',
+                      key=f'compare_{len(st.session_state.messages)}')
             message_placeholder = st.empty()
             for cur_response in generate_interactive(
                     model=model,
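
Note: `main()` now funnels assistant rendering through a local `render_message` helper: if a channel's stored content is `None` it regenerates that channel on the fly, streaming into an `st.empty()` placeholder and caching the result back into history, otherwise it just renders the cached text. A per-turn `st.toggle('compare', ...)` switches between single output and a two-column side-by-side view; deep-thinking output appears to stay in the right-hand column, with the currently selected mode rendered (and, if needed, generated) first. A stripped-down layout sketch with placeholder strings instead of real generations:

```python
# Layout sketch only: how the compare toggle above arranges the two channels.
import streamlit as st

normal_text = ':blue[[Normal Response]]\n\nplaceholder normal answer'
deep_text = ':red[[Deep Thinking]]\n\nplaceholder deep-thinking answer'

if st.toggle('compare', key='compare_demo'):
    cols = st.columns(2)
    with cols[0]:
        st.markdown(normal_text)
    with cols[1]:
        st.markdown(deep_text)
else:
    # Without the toggle, only one channel is shown (here: the normal one).
    st.markdown(normal_text)
```
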
@@ -369,23 +428,29 @@ def main():
                 **asdict(generation_config),
             ):
                 # Display robot response in chat message container
-                message_placeholder.markdown(postprocess(cur_response) + '▌')
-            message_placeholder.markdown(postprocess(cur_response))
+                message_placeholder.markdown(
+                    postprocess(cur_response,
+                                deepthink=st.session_state['inference_mode'] ==
+                                'Deep Thinking') + '▌')
+            message_placeholder.markdown(
+                postprocess(cur_response,
+                            deepthink=st.session_state['inference_mode'] ==
+                            'Deep Thinking'))
             # Add robot response to chat history
-            st.session_state.messages.append(
-                {
-                    'role': 'robot',
-                    'content': cur_response,  # pylint: disable=undefined-loop-variable
-                    'avatar': robot_avator,
-                }
-            )
-            st.session_state.deepthink_messages.append(
-                {
-                    'role': 'robot',
-                    'content': None,
-                    'avatar': robot_avator,
-                }
-            )
+            response, deepthink_response = ((None, cur_response)
+                                            if st.session_state['inference_mode']
+                                            == 'Deep Thinking' else
+                                            (cur_response, None))
+            st.session_state.messages.append({
+                'role': 'robot',
+                'content': response,  # pylint: disable=undefined-loop-variable
+                'avatar': robot_avator,
+            })
+            st.session_state.deepthink_messages.append({
+                'role': 'robot',
+                'content': deepthink_response,
+                'avatar': robot_avator,
+            })
             torch.cuda.empty_cache()

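
Note: after streaming finishes, only the channel that was actually generated is stored; the other history entry gets `None` so the compare toggle can regenerate it lazily later. A sketch of that bookkeeping with illustrative values:

```python
# Illustrative values; in the app these come from the sidebar radio and the
# generator loop above.
inference_mode = 'Deep Thinking'
cur_response = 'final streamed answer'

response, deepthink_response = ((None, cur_response)
                                if inference_mode == 'Deep Thinking' else
                                (cur_response, None))
assert response is None and deepthink_response == cur_response
```
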