Mirror of https://github.com/InternLM/InternLM

Commit 99e5cd62e0 ("update"), parent b9ce4fd8b0

[Binary image file added (300 KiB); content not shown.]
@@ -79,6 +79,6 @@ pip install transformers>=4.48
 streamlit run ./chat/web_demo.py
 ```

-The effect is similar to below:
+It supports switching between different inference modes and comparing their responses.

-![demo](…)
+![compare](…)
@@ -76,3 +76,7 @@ pip install streamlit
 pip install transformers>=4.48
 streamlit run ./web_demo.py
 ```
+
+Supports switching between different inference modes and comparing their responses.
+
+![compare](…)
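Both READMEs now pin `transformers>=4.48`; the demo below depends on generation internals that changed around that release (see the commented-out `_get_logits_warper` calls further down). A guard like this (not part of the commit; `packaging` is typically already installed alongside transformers) would fail fast on older installs:

```python
import transformers
from packaging import version

# Hedged sketch: fail early instead of hitting a missing private API later.
if version.parse(transformers.__version__) < version.parse('4.48'):
    raise RuntimeError('chat/web_demo.py expects transformers >= 4.48; '
                       f'found {transformers.__version__}')
```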
chat/web_demo.py (249 lines changed):
@@ -54,7 +54,8 @@ def generate_interactive(
     generation_config: Optional[GenerationConfig] = None,
     logits_processor: Optional[LogitsProcessorList] = None,
     stopping_criteria: Optional[StoppingCriteriaList] = None,
-    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
+                                                List[int]]] = None,
     additional_eos_token_id: Optional[int] = None,
     **kwargs,
 ):
@@ -67,7 +68,10 @@ def generate_interactive(
     if generation_config is None:
         generation_config = model.generation_config
     generation_config = copy.deepcopy(generation_config)
+    generation_config._eos_token_tensor = generation_config.eos_token_id
     model_kwargs = generation_config.update(**kwargs)
+    if generation_config.temperature == 0.0:
+        generation_config.do_sample = False
     bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
         generation_config.bos_token_id,
         generation_config.eos_token_id,
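The `temperature == 0.0` guard added above avoids a division by zero once sampling scales logits by `1/T`; greedy decoding is the natural limit. A minimal standalone sketch of the distinction (not from the commit):

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5])
temperature = 0.0

if temperature == 0.0:
    # T -> 0 concentrates all probability on the argmax,
    # so pick it directly instead of dividing by zero
    next_token = logits.argmax()
else:
    probs = torch.softmax(logits / temperature, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
```

The `_eos_token_tensor` assignment appears to pre-populate a private attribute that newer transformers releases expect on `GenerationConfig`; treat that reading as an inference from the diff, not documented API.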
@@ -76,7 +80,8 @@ def generate_interactive(
         eos_token_id = [eos_token_id]
     if additional_eos_token_id is not None:
         eos_token_id.append(additional_eos_token_id)
-    has_default_max_length = kwargs.get('max_length') is None and generation_config.max_length is not None
+    has_default_max_length = kwargs.get(
+        'max_length') is None and generation_config.max_length is not None
     if has_default_max_length and generation_config.max_new_tokens is None:
         warnings.warn(
             f"Using 'max_length''s default \
@@ -107,12 +112,12 @@ def generate_interactive(
             f'Input length of {input_ids_string} is {input_ids_seq_length}, '
             f"but 'max_length' is set to {generation_config.max_length}. "
             'This can lead to unexpected behavior. You should consider'
-            " increasing 'max_new_tokens'."
-        )
+            " increasing 'max_new_tokens'.")

     # 2. Set generation parameters if not already defined
-    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList(
+    )
+    # stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

     logits_processor = model._get_logits_processor(
         generation_config=generation_config,
@@ -122,19 +127,20 @@ def generate_interactive(
         logits_processor=logits_processor,
     )

-    stopping_criteria = model._get_stopping_criteria(
-        generation_config=generation_config, stopping_criteria=stopping_criteria
-    )
+    # stopping_criteria = model._get_stopping_criteria(
+    #     generation_config=generation_config, stopping_criteria=stopping_criteria
+    # )

-    if transformers.__version__ >= '4.42.0':
-        logits_warper = model._get_logits_warper(generation_config, device='cuda')
-    else:
-        logits_warper = model._get_logits_warper(generation_config)
+    # if transformers.__version__ >= '4.42.0':
+    #     logits_warper = model._get_logits_warper(generation_config, device='cuda')
+    # else:
+    #     logits_warper = model._get_logits_warper(generation_config)

     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
     scores = None
     while True:
-        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
+        model_inputs = model.prepare_inputs_for_generation(
+            input_ids, **model_kwargs)
         # forward pass to get next token
         outputs = model(
             **model_inputs,
@@ -147,7 +153,7 @@ def generate_interactive(

         # pre-process distribution
         next_token_scores = logits_processor(input_ids, next_token_logits)
-        next_token_scores = logits_warper(input_ids, next_token_scores)
+        # next_token_scores = logits_warper(input_ids, next_token_scores)

         # sample
         probs = nn.functional.softmax(next_token_scores, dim=-1)
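The commented-out `_get_logits_warper` calls in this and the previous hunk track a change in recent transformers releases: the sampling "warpers" (temperature, top-p, and friends) appear to have been folded into the list returned by `_get_logits_processor`, so a second warping pass would now either fail or apply twice. For reference, a self-contained sketch of what the removed warping step did, assuming standard temperature scaling plus nucleus filtering (this is illustrative, not the library's code):

```python
import torch

def warp_logits(logits: torch.Tensor, temperature: float = 0.7,
                top_p: float = 0.8) -> torch.Tensor:
    """Temperature scaling followed by nucleus (top-p) filtering."""
    logits = logits / temperature
    # sort ascending and mask the low-probability tail whose
    # cumulative mass stays below 1 - top_p
    sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    sorted_remove = cumulative_probs <= (1 - top_p)
    remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
    return logits.masked_fill(remove, float('-inf'))

scores = warp_logits(torch.randn(1, 32000))
probs = torch.softmax(scores, dim=-1)  # ready for torch.multinomial
```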
@@ -158,8 +164,9 @@ def generate_interactive(

         # update generated ids, model inputs, and length for next step
         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-        model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
-        unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())
+        # model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
+        unfinished_sequences = unfinished_sequences.mul(
+            (min(next_tokens != i for i in eos_token_id)).long())

         output_token_ids = input_ids[0].cpu().tolist()
         output_token_ids = output_token_ids[input_length:]
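The rewrapped `unfinished_sequences` update keeps the original multi-EOS stop test: `min(...)` over the per-id comparisons acts as a logical AND, staying truthy only while the sampled token differs from every EOS id. This relies on batch size 1, which the demo uses. A small illustration under that assumption:

```python
import torch

eos_token_id = [2, 92542]        # base EOS plus <|im_end|> (92542)
next_tokens = torch.tensor([92542])

# AND across ids: 1 only while the token matches no EOS id
alive = min(next_tokens != i for i in eos_token_id)
print(alive.long())  # tensor([0]): an EOS was hit, so the loop can stop
```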
@@ -171,38 +178,54 @@ def generate_interactive(
             yield response
             # stop when each sentence is finished
             # or if we exceed the maximum length
-            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+            # if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+            #     break
+            if unfinished_sequences.max() == 0:
                 break


 def on_btn_click():
     del st.session_state.messages
+    del st.session_state.deepthink_messages
+    del st.session_state.deep_mode


-def postprocess(text):
+def postprocess(text, add_prefix=True, deepthink=False):
     text = re.sub(r'\\\(|\\\)', r'$', text)
     text = re.sub(r'\\\[|\\\]', r'$$', text)
+    if add_prefix:
+        text = (':red[[Deep Thinking]]\n\n'
+                if deepthink else ':blue[[Normal Response]]\n\n') + text
     return text


 @st.cache_resource
 def load_model():
-    model_path = 'internlm/internlm2_5-7b-chat'
-    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(torch.bfloat16).cuda()
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    model_path = 'internlm/internlm3-8b-instruct'
+    model = AutoModelForCausalLM.from_pretrained(model_path,
+                                                 trust_remote_code=True).to(
+                                                     torch.bfloat16).cuda()
+    tokenizer = AutoTokenizer.from_pretrained(model_path,
+                                              trust_remote_code=True)
     return model, tokenizer


 def prepare_generation_config():
     with st.sidebar:
-        max_length = st.slider('Max Length', min_value=8, max_value=32768, value=32768)
+        max_length = st.slider('Max Length',
+                               min_value=8,
+                               max_value=32768,
+                               value=32768)
         top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
         temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
+        radio = st.radio('Inference Mode',
+                         ['Normal Response', 'Deep Thinking'],
+                         key='mode')
         st.button('Clear Chat History', on_click=on_btn_click)

-    generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)
+    st.session_state['inference_mode'] = radio
+    generation_config = GenerationConfig(max_length=max_length,
+                                         top_p=top_p,
+                                         temperature=temperature)

     return generation_config
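`postprocess` now tags each reply with its mode and still rewrites LaTeX delimiters into the `$`/`$$` forms that Streamlit's Markdown renders. Its effect, reproduced standalone from the definition above:

```python
import re

def postprocess(text, add_prefix=True, deepthink=False):
    text = re.sub(r'\\\(|\\\)', r'$', text)    # \( ... \)  ->  $ ... $
    text = re.sub(r'\\\[|\\\]', r'$$', text)   # \[ ... \]  ->  $$ ... $$
    if add_prefix:
        text = (':red[[Deep Thinking]]\n\n'
                if deepthink else ':blue[[Normal Response]]\n\n') + text
    return text

print(postprocess(r'Euler: \(e^{i\pi} + 1 = 0\)', deepthink=True))
# :red[[Deep Thinking]]
#
# Euler: $e^{i\pi} + 1 = 0$
```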
@@ -220,16 +243,21 @@ def combine_history(prompt, deepthink=False, start=0, stop=None):
     stop = len(st.session_state.messages) + stop
     messages = []
     for idx in range(start, stop):
-        message, deepthink_message = st.session_state.messages[idx], st.session_state.deepthink_messages[idx]
-        if deepthink and deepthink_message['content'] is not None:
-            messages.append(deepthink_message)
+        message, deepthink_message = st.session_state.messages[
+            idx], st.session_state.deepthink_messages[idx]
+        if deepthink:
+            if deepthink_message['content'] is not None:
+                messages.append(deepthink_message)
+            else:
+                messages.append(message)
         else:
-            messages.append(message)
-    meta_instruction = (
-        'You are InternLM (书生·浦语), a helpful, honest, '
-        'and harmless AI assistant developed by Shanghai '
-        'AI Laboratory (上海人工智能实验室).'
-    )
+            if message['content'] is not None:
+                messages.append(message)
+            else:
+                messages.append(deepthink_message)
+    meta_instruction = ('You are InternLM (书生·浦语), a helpful, honest, '
+                        'and harmless AI assistant developed by Shanghai '
+                        'AI Laboratory (上海人工智能实验室).')
     if deepthink:
         meta_instruction += """\nYou are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:
 ## Deep Understanding
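The rewritten loop makes the prompt history robust to one-sided turns: each robot turn lives in exactly one of the two parallel lists (see the main-loop hunk below), so whichever list the mode selects falls back to the other when its own entry is `None`. A toy illustration of the fallback, condensed from the deepthink branch above:

```python
# one robot turn answered in normal mode: only `message` has content
message = {'role': 'robot', 'content': 'a normal answer'}
deepthink_message = {'role': 'robot', 'content': None}

deepthink = True
picked = (deepthink_message
          if deepthink and deepthink_message['content'] is not None
          else message)
print(picked['content'])  # 'a normal answer': falls back, never None
```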
@@ -299,93 +327,130 @@ def main():
     user_avator = 'assets/user.png'
     robot_avator = 'assets/robot.png'

-    st.title('internlm3-8b-chat')
+    st.title('InternLM3-8B-Instruct')

     generation_config = prepare_generation_config()

+    def render_message(msg, msg_idx, deepthink):
+        if msg['content'] is None:
+            real_prompt = combine_history(
+                st.session_state.messages[msg_idx - 1]['content'],
+                deepthink=deepthink,
+                stop=msg_idx - 1)
+            placeholder = st.empty()
+            for cur_response in generate_interactive(
+                    model=model,
+                    tokenizer=tokenizer,
+                    prompt=real_prompt,
+                    additional_eos_token_id=92542,
+                    **asdict(generation_config),
+            ):
+                placeholder.markdown(
+                    postprocess(cur_response, deepthink=deepthink) + '▌')
+            placeholder.markdown(postprocess(cur_response,
+                                             deepthink=deepthink))
+            msg['content'] = cur_response
+            torch.cuda.empty_cache()
+        else:
+            st.markdown(postprocess(msg['content'], deepthink=deepthink))
+
     # Initialize chat history
     if 'messages' not in st.session_state:
         st.session_state.messages = []
     if 'deepthink_messages' not in st.session_state:
         st.session_state.deepthink_messages = []
+    if 'deep_mode' not in st.session_state:
+        st.session_state.deep_mode = {}

     # Display chat messages from history on app rerun
     for idx, (message, deepthink_message) in enumerate(
-            zip(st.session_state.messages, st.session_state.deepthink_messages)
-    ):
+            zip(st.session_state.messages,
+                st.session_state.deepthink_messages)):
         with st.chat_message(message['role'], avatar=message.get('avatar')):
             if message['role'] == 'user':
-                st.markdown(postprocess(message['content']))
+                st.markdown(postprocess(message['content'], add_prefix=False))
             else:
                 if st.button('深度思考', key=f'deep_mode_{idx}'):
                     st.session_state.deep_mode[idx] = not st.session_state.deep_mode.get(idx, False)
                 if st.session_state.deep_mode.get(idx, False):
                     if st.toggle('compare', key=f'compare_{idx}'):
                         cols = st.columns(2)
-                        with cols[0]:
-                            st.markdown(postprocess(message['content']))
-                        with cols[1]:
-                            if deepthink_message['content'] is None:
-                                real_prompt = combine_history(
-                                    st.session_state.deepthink_messages[idx - 1]['content'], deepthink=True, stop=idx - 1
-                                )
-                                message_placeholder = st.empty()
-                                for cur_response in generate_interactive(
-                                    model=model,
-                                    tokenizer=tokenizer,
-                                    prompt=real_prompt,
-                                    additional_eos_token_id=92542,
-                                    **asdict(generation_config),
-                                ):
-                                    # Display robot response in chat message container
-                                    message_placeholder.markdown(postprocess(cur_response) + '▌')
-                                message_placeholder.markdown(postprocess(cur_response))
-                                deepthink_message['content'] = cur_response
-                                torch.cuda.empty_cache()
-                            else:
-                                st.markdown(postprocess(deepthink_message['content']))
+                        if st.session_state['inference_mode'] == 'Deep Thinking':
+                            with cols[1]:
+                                render_message(deepthink_message, idx, True)
+                            with cols[0]:
+                                render_message(message, idx, False)
+                        else:
+                            with cols[0]:
+                                render_message(message, idx, False)
+                            with cols[1]:
+                                render_message(deepthink_message, idx, True)
                 else:
-                    st.markdown(postprocess(message['content']))
+                    if st.session_state['inference_mode'] == 'Deep Thinking':
+                        if deepthink_message['content'] is not None:
+                            st.markdown(
+                                postprocess(deepthink_message['content'],
+                                            deepthink=True))
+                        else:
+                            st.markdown(postprocess(message['content']))
+                    else:
+                        if message['content'] is not None:
+                            st.markdown(postprocess(message['content']))
+                        else:
+                            st.markdown(
+                                postprocess(deepthink_message['content'],
+                                            deepthink=True))

     # Accept user input
     if prompt := st.chat_input('What is up?'):
         # Display user message in chat message container
         with st.chat_message('user', avatar=user_avator):
-            st.markdown(postprocess(prompt))
-        real_prompt = combine_history(prompt)
+            st.markdown(postprocess(prompt, add_prefix=False))
+        real_prompt = combine_history(
+            prompt,
+            deepthink=st.session_state['inference_mode'] == 'Deep Thinking')
         # Add user message to chat history
-        st.session_state.messages.append({'role': 'user', 'content': prompt, 'avatar': user_avator})
-        st.session_state.deepthink_messages.append({'role': 'user', 'content': prompt, 'avatar': user_avator})
+        st.session_state.messages.append({
+            'role': 'user',
+            'content': prompt,
+            'avatar': user_avator
+        })
+        st.session_state.deepthink_messages.append({
+            'role': 'user',
+            'content': prompt,
+            'avatar': user_avator
+        })

         with st.chat_message('robot', avatar=robot_avator):
+            st.button('深度思考', key=f'deep_mode_{len(st.session_state.messages)}')
+            st.toggle('compare',
+                      key=f'compare_{len(st.session_state.messages)}')
             message_placeholder = st.empty()
             for cur_response in generate_interactive(
-                model=model,
-                tokenizer=tokenizer,
-                prompt=real_prompt,
-                additional_eos_token_id=92542,
-                **asdict(generation_config),
+                    model=model,
+                    tokenizer=tokenizer,
+                    prompt=real_prompt,
+                    additional_eos_token_id=92542,
+                    **asdict(generation_config),
             ):
                 # Display robot response in chat message container
-                message_placeholder.markdown(postprocess(cur_response) + '▌')
-            message_placeholder.markdown(postprocess(cur_response))
+                message_placeholder.markdown(
+                    postprocess(cur_response,
+                                deepthink=st.session_state['inference_mode'] ==
+                                'Deep Thinking') + '▌')
+            message_placeholder.markdown(
+                postprocess(cur_response,
+                            deepthink=st.session_state['inference_mode'] ==
+                            'Deep Thinking'))
             # Add robot response to chat history
-            st.session_state.messages.append(
-                {
-                    'role': 'robot',
-                    'content': cur_response,  # pylint: disable=undefined-loop-variable
-                    'avatar': robot_avator,
-                }
-            )
-            st.session_state.deepthink_messages.append(
-                {
-                    'role': 'robot',
-                    'content': None,
-                    'avatar': robot_avator,
-                }
-            )
+            response, deepthink_response = ((None, cur_response)
+                                            if st.session_state['inference_mode']
+                                            == 'Deep Thinking' else
+                                            (cur_response, None))
+            st.session_state.messages.append({
+                'role': 'robot',
+                'content': response,  # pylint: disable=undefined-loop-variable
+                'avatar': robot_avator,
+            })
+            st.session_state.deepthink_messages.append({
+                'role': 'robot',
+                'content': deepthink_response,
+                'avatar': robot_avator,
+            })
             torch.cuda.empty_cache()
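One design note on the bookkeeping at the end of this hunk: the two histories stay index-aligned because every turn appends to both lists, with the reply text stored under the mode that produced it and `None` under the other. A condensed sketch of that invariant:

```python
inference_mode = 'Deep Thinking'        # sidebar radio value
cur_response = '...generated text...'

response, deepthink_response = ((None, cur_response)
                                if inference_mode == 'Deep Thinking'
                                else (cur_response, None))

# exactly one side of the pair carries the text
assert (response is None) != (deepthink_response is None)
```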