diff --git a/assets/web_demo.png b/assets/web_demo.png new file mode 100644 index 0000000..f35cdcd Binary files /dev/null and b/assets/web_demo.png differ diff --git a/chat/README.md b/chat/README.md index 41c6754..c27a06a 100644 --- a/chat/README.md +++ b/chat/README.md @@ -79,6 +79,6 @@ pip install transformers>=4.48 streamlit run ./chat/web_demo.py ``` -The effect is similar to below: +It supports switching between different inference modes and comparing their responses. -![demo](https://github.com/InternLM/InternLM/assets/9102141/11b60ee0-47e4-42c0-8278-3051b2f17fe4) +![demo](../assets/web_demo.png) diff --git a/chat/README_zh-CN.md b/chat/README_zh-CN.md index 56d468c..a515500 100644 --- a/chat/README_zh-CN.md +++ b/chat/README_zh-CN.md @@ -76,3 +76,7 @@ pip install streamlit pip install transformers>=4.48 streamlit run ./web_demo.py ``` + +支持切换不同推理模式,并比较它们的回复 + +![demo](../assets/web_demo.png) diff --git a/chat/web_demo.py b/chat/web_demo.py index 6ee1aef..0470d40 100644 --- a/chat/web_demo.py +++ b/chat/web_demo.py @@ -54,7 +54,8 @@ def generate_interactive( generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], + List[int]]] = None, additional_eos_token_id: Optional[int] = None, **kwargs, ): @@ -67,7 +68,10 @@ def generate_interactive( if generation_config is None: generation_config = model.generation_config generation_config = copy.deepcopy(generation_config) + generation_config._eos_token_tensor = generation_config.eos_token_id model_kwargs = generation_config.update(**kwargs) + if generation_config.temperature == 0.0: + generation_config.do_sample = False bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612 generation_config.bos_token_id, generation_config.eos_token_id, @@ -76,7 +80,8 @@ def generate_interactive( eos_token_id = [eos_token_id] if additional_eos_token_id is not None: eos_token_id.append(additional_eos_token_id) - has_default_max_length = kwargs.get('max_length') is None and generation_config.max_length is not None + has_default_max_length = kwargs.get( + 'max_length') is None and generation_config.max_length is not None if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( f"Using 'max_length''s default \ @@ -107,12 +112,12 @@ def generate_interactive( f'Input length of {input_ids_string} is {input_ids_seq_length}, ' f"but 'max_length' is set to {generation_config.max_length}. " 'This can lead to unexpected behavior. You should consider' - " increasing 'max_new_tokens'." - ) + " increasing 'max_new_tokens'.") # 2. 
Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList( + ) + # stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() logits_processor = model._get_logits_processor( generation_config=generation_config, @@ -122,19 +127,20 @@ def generate_interactive( logits_processor=logits_processor, ) - stopping_criteria = model._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) + # stopping_criteria = model._get_stopping_criteria( + # generation_config=generation_config, stopping_criteria=stopping_criteria + # ) - if transformers.__version__ >= '4.42.0': - logits_warper = model._get_logits_warper(generation_config, device='cuda') - else: - logits_warper = model._get_logits_warper(generation_config) + # if transformers.__version__ >= '4.42.0': + # logits_warper = model._get_logits_warper(generation_config, device='cuda') + # else: + # logits_warper = model._get_logits_warper(generation_config) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) scores = None while True: - model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs) + model_inputs = model.prepare_inputs_for_generation( + input_ids, **model_kwargs) # forward pass to get next token outputs = model( **model_inputs, @@ -147,7 +153,7 @@ def generate_interactive( # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) + # next_token_scores = logits_warper(input_ids, next_token_scores) # sample probs = nn.functional.softmax(next_token_scores, dim=-1) @@ -158,8 +164,9 @@ def generate_interactive( # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False) - unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long()) + # model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False) + unfinished_sequences = unfinished_sequences.mul( + (min(next_tokens != i for i in eos_token_id)).long()) output_token_ids = input_ids[0].cpu().tolist() output_token_ids = output_token_ids[input_length:] @@ -171,38 +178,54 @@ def generate_interactive( yield response # stop when each sentence is finished # or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + # if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + # break + if unfinished_sequences.max() == 0: break def on_btn_click(): del st.session_state.messages del st.session_state.deepthink_messages - del st.session_state.deep_mode -def postprocess(text): +def postprocess(text, add_prefix=True, deepthink=False): text = re.sub(r'\\\(|\\\)', r'$', text) text = re.sub(r'\\\[|\\\]', r'$$', text) + if add_prefix: + text = (':red[[Deep Thinking]]\n\n' + if deepthink else ':blue[[Normal Response]]\n\n') + text return text @st.cache_resource def load_model(): - model_path = 'internlm/internlm2_5-7b-chat' - model = AutoModelForCausalLM.from_pretrained(model_path, 
trust_remote_code=True).to(torch.bfloat16).cuda() - tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model_path = 'internlm/internlm3-8b-instruct' + model = AutoModelForCausalLM.from_pretrained(model_path, + trust_remote_code=True).to( + torch.bfloat16).cuda() + tokenizer = AutoTokenizer.from_pretrained(model_path, + trust_remote_code=True) return model, tokenizer def prepare_generation_config(): with st.sidebar: - max_length = st.slider('Max Length', min_value=8, max_value=32768, value=32768) + max_length = st.slider('Max Length', + min_value=8, + max_value=32768, + value=32768) top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01) temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01) + radio = st.radio('Inference Mode', + ['Normal Response', 'Deep Thinking'], + key='mode') st.button('Clear Chat History', on_click=on_btn_click) - generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature) + st.session_state['inference_mode'] = radio + generation_config = GenerationConfig(max_length=max_length, + top_p=top_p, + temperature=temperature) return generation_config @@ -220,16 +243,21 @@ def combine_history(prompt, deepthink=False, start=0, stop=None): stop = len(st.session_state.messages) + stop messages = [] for idx in range(start, stop): - message, deepthink_message = st.session_state.messages[idx], st.session_state.deepthink_messages[idx] - if deepthink and deepthink_message['content'] is not None: - messages.append(deepthink_message) + message, deepthink_message = st.session_state.messages[ + idx], st.session_state.deepthink_messages[idx] + if deepthink: + if deepthink_message['content'] is not None: + messages.append(deepthink_message) + else: + messages.append(message) else: - messages.append(message) - meta_instruction = ( - 'You are InternLM (书生·浦语), a helpful, honest, ' - 'and harmless AI assistant developed by Shanghai ' - 'AI Laboratory (上海人工智能实验室).' - ) + if message['content'] is not None: + messages.append(message) + else: + messages.append(deepthink_message) + meta_instruction = ('You are InternLM (书生·浦语), a helpful, honest, ' + 'and harmless AI assistant developed by Shanghai ' + 'AI Laboratory (上海人工智能实验室).') if deepthink: meta_instruction += """\nYou are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. 
When solving problems, follow these thought processes: ## Deep Understanding @@ -299,93 +327,130 @@ def main(): user_avator = 'assets/user.png' robot_avator = 'assets/robot.png' - st.title('internlm3-8b-chat') + st.title('InternLM3-8B-Instruct') generation_config = prepare_generation_config() + def render_message(msg, msg_idx, deepthink): + if msg['content'] is None: + real_prompt = combine_history( + st.session_state.messages[msg_idx - 1]['content'], + deepthink=deepthink, + stop=msg_idx - 1) + placeholder = st.empty() + for cur_response in generate_interactive( + model=model, + tokenizer=tokenizer, + prompt=real_prompt, + additional_eos_token_id=92542, + **asdict(generation_config), + ): + placeholder.markdown( + postprocess(cur_response, deepthink=deepthink) + '▌') + placeholder.markdown(postprocess(cur_response, + deepthink=deepthink)) + msg['content'] = cur_response + torch.cuda.empty_cache() + else: + st.markdown(postprocess(msg['content'], deepthink=deepthink)) + # Initialize chat history if 'messages' not in st.session_state: st.session_state.messages = [] if 'deepthink_messages' not in st.session_state: st.session_state.deepthink_messages = [] - if 'deep_mode' not in st.session_state: - st.session_state.deep_mode = {} # Display chat messages from history on app rerun for idx, (message, deepthink_message) in enumerate( - zip(st.session_state.messages, st.session_state.deepthink_messages) - ): + zip(st.session_state.messages, + st.session_state.deepthink_messages)): with st.chat_message(message['role'], avatar=message.get('avatar')): if message['role'] == 'user': - st.markdown(postprocess(message['content'])) + st.markdown(postprocess(message['content'], add_prefix=False)) else: - if st.button('深度思考', key=f'deep_mode_{idx}'): - st.session_state.deep_mode[idx] = not st.session_state.deep_mode.get(idx, False) - if st.session_state.deep_mode.get(idx, False): + if st.toggle('compare', key=f'compare_{idx}'): cols = st.columns(2) - with cols[0]: - st.markdown(postprocess(message['content'])) - with cols[1]: - if deepthink_message['content'] is None: - real_prompt = combine_history( - st.session_state.deepthink_messages[idx - 1]['content'], deepthink=True, stop=idx - 1 - ) - message_placeholder = st.empty() - for cur_response in generate_interactive( - model=model, - tokenizer=tokenizer, - prompt=real_prompt, - additional_eos_token_id=92542, - **asdict(generation_config), - ): - # Display robot response in chat message container - message_placeholder.markdown(postprocess(cur_response) + '▌') - message_placeholder.markdown(postprocess(cur_response)) - deepthink_message['content'] = cur_response - torch.cuda.empty_cache() - else: - st.markdown(postprocess(deepthink_message['content'])) + if st.session_state['inference_mode'] == 'Deep Thinking': + with cols[1]: + render_message(deepthink_message, idx, True) + with cols[0]: + render_message(message, idx, False) + else: + with cols[0]: + render_message(message, idx, False) + with cols[1]: + render_message(deepthink_message, idx, True) else: - st.markdown(postprocess(message['content'])) + if st.session_state['inference_mode'] == 'Deep Thinking': + if deepthink_message['content'] is not None: + st.markdown( + postprocess(deepthink_message['content'], + deepthink=True)) + else: + st.markdown(postprocess(message['content'])) + else: + if message['content'] is not None: + st.markdown(postprocess(message['content'])) + else: + st.markdown( + postprocess(deepthink_message['content'], + deepthink=True)) # Accept user input if prompt := 
st.chat_input('What is up?'): # Display user message in chat message container with st.chat_message('user', avatar=user_avator): - st.markdown(postprocess(prompt)) - real_prompt = combine_history(prompt) + st.markdown(postprocess(prompt, add_prefix=False)) + real_prompt = combine_history( + prompt, + deepthink=st.session_state['inference_mode'] == 'Deep Thinking') # Add user message to chat history - st.session_state.messages.append({'role': 'user', 'content': prompt, 'avatar': user_avator}) - st.session_state.deepthink_messages.append({'role': 'user', 'content': prompt, 'avatar': user_avator}) + st.session_state.messages.append({ + 'role': 'user', + 'content': prompt, + 'avatar': user_avator + }) + st.session_state.deepthink_messages.append({ + 'role': 'user', + 'content': prompt, + 'avatar': user_avator + }) with st.chat_message('robot', avatar=robot_avator): - st.button('深度思考', key=f'deep_mode_{len(st.session_state.messages)}') + st.toggle('compare', + key=f'compare_{len(st.session_state.messages)}') message_placeholder = st.empty() for cur_response in generate_interactive( - model=model, - tokenizer=tokenizer, - prompt=real_prompt, - additional_eos_token_id=92542, - **asdict(generation_config), + model=model, + tokenizer=tokenizer, + prompt=real_prompt, + additional_eos_token_id=92542, + **asdict(generation_config), ): # Display robot response in chat message container - message_placeholder.markdown(postprocess(cur_response) + '▌') - message_placeholder.markdown(postprocess(cur_response)) + message_placeholder.markdown( + postprocess(cur_response, + deepthink=st.session_state['inference_mode'] == + 'Deep Thinking') + '▌') + message_placeholder.markdown( + postprocess(cur_response, + deepthink=st.session_state['inference_mode'] == + 'Deep Thinking')) # Add robot response to chat history - st.session_state.messages.append( - { - 'role': 'robot', - 'content': cur_response, # pylint: disable=undefined-loop-variable - 'avatar': robot_avator, - } - ) - st.session_state.deepthink_messages.append( - { - 'role': 'robot', - 'content': None, - 'avatar': robot_avator, - } - ) + response, deepthink_response = ((None, cur_response) + if st.session_state['inference_mode'] + == 'Deep Thinking' else + (cur_response, None)) + st.session_state.messages.append({ + 'role': 'robot', + 'content': response, # pylint: disable=undefined-loop-variable + 'avatar': robot_avator, + }) + st.session_state.deepthink_messages.append({ + 'role': 'robot', + 'content': deepthink_response, + 'avatar': robot_avator, + }) torch.cuda.empty_cache()
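
The core bookkeeping this patch introduces in `web_demo.py` can be illustrated without Streamlit or torch. Below is a minimal sketch; `record_turn` and `combine_history` here are simplified stand-ins for the session-state logic in the patch, not the demo's actual API.

```python
# Simplified illustration of the dual-history bookkeeping in web_demo.py:
# the two lists stay index-aligned, the mode that generated a turn stores
# the text, the other side stores None, and history assembly falls back
# to whichever side is populated.

messages = []            # 'Normal Response' history
deepthink_messages = []  # 'Deep Thinking' history


def record_turn(prompt, reply, deep_thinking):
    # User turns are duplicated into both lists so indices stay in sync.
    messages.append({'role': 'user', 'content': prompt})
    deepthink_messages.append({'role': 'user', 'content': prompt})
    # The assistant turn only fills the list matching the active mode.
    messages.append({'role': 'robot',
                     'content': None if deep_thinking else reply})
    deepthink_messages.append({'role': 'robot',
                               'content': reply if deep_thinking else None})


def combine_history(deepthink):
    # Prefer the requested mode's reply; fall back to the other mode's.
    primary, fallback = ((deepthink_messages, messages) if deepthink
                         else (messages, deepthink_messages))
    return [p if p['content'] is not None else f
            for p, f in zip(primary, fallback)]


record_turn('hi', 'hello (normal mode)', deep_thinking=False)
record_turn('prove it', 'step-by-step proof (deep thinking)', deep_thinking=True)
print([m['content'] for m in combine_history(deepthink=True)])
# ['hi', 'hello (normal mode)', 'prove it', 'step-by-step proof (deep thinking)']
```

This alignment is also why `on_btn_click` in the patch clears `st.session_state.messages` and `st.session_state.deepthink_messages` together: dropping only one list would break the index pairing that `combine_history` and the compare view depend on.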