From 6908fa5e16fa4ae348c31b04886dc263f8116549 Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Wed, 3 Jul 2024 20:22:19 +0800
Subject: [PATCH] fix lint

---
 agent/pal_inference.py        |  10 +--
 chat/web_demo.py              |  13 +--
 long_context/README.md        |   3 +-
 long_context/README_zh-CN.md  |   3 +-
 long_context/doc_chat_demo.py | 145 +++++++++++++++++++++-------------
 model_cards/internlm2.5_7b.md |  17 ++--
 model_cards/internlm2_1.8b.md |  33 ++++----
 tests/test_hf_model.py        |  14 ++--
 tools/convert2llama.py        |   8 +-
 9 files changed, 140 insertions(+), 106 deletions(-)

diff --git a/agent/pal_inference.py b/agent/pal_inference.py
index ed55390..ceac40f 100644
--- a/agent/pal_inference.py
+++ b/agent/pal_inference.py
@@ -189,8 +189,8 @@ def generate_interactive(
         generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
         if not has_default_max_length:
             logger.warn(  # pylint: disable=W4902
-                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
-                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                f'Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(='
+                f'{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. '
                 'Please refer to the documentation for more information. '
                 '(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)',
                 UserWarning,
@@ -199,8 +199,8 @@
     if input_ids_seq_length >= generation_config.max_length:
         input_ids_string = 'input_ids'
         logger.warning(
-            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
-            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+            f'Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to'
+            f' {generation_config.max_length}. This can lead to unexpected behavior. You should consider'
             ' increasing `max_new_tokens`.')

     # 2. Set generation parameters if not already defined
@@ -510,7 +510,7 @@ def main():
             interface.clear_history()
             f.flush()

-    print(f"{args.model}: Accuracy - {sum(scores) / len(scores)}")
+    print(f'{args.model}: Accuracy - {sum(scores) / len(scores)}')
     torch.cuda.empty_cache()

diff --git a/chat/web_demo.py b/chat/web_demo.py
index 4a67ebd..cc5f07c 100644
--- a/chat/web_demo.py
+++ b/chat/web_demo.py
@@ -77,7 +77,8 @@ def generate_interactive(
         'max_length') is None and generation_config.max_length is not None
     if has_default_max_length and generation_config.max_new_tokens is None:
         warnings.warn(
-            f"Using 'max_length''s default ({repr(generation_config.max_length)}) \
+            f"Using 'max_length''s default \
+                ({repr(generation_config.max_length)}) \
                 to control the generation length. "
             'This behaviour is deprecated and will be removed from the \
                 config in v5 of Transformers -- we'
@@ -102,7 +103,7 @@
     if input_ids_seq_length >= generation_config.max_length:
         input_ids_string = 'input_ids'
         logger.warning(
-            f"Input length of {input_ids_string} is {input_ids_seq_length}, "
+            f'Input length of {input_ids_string} is {input_ids_seq_length}, '
             f"but 'max_length' is set to {generation_config.max_length}. "
             'This can lead to unexpected behavior. You should consider'
             " increasing 'max_new_tokens'.")
@@ -180,9 +181,9 @@ def on_btn_click():

 @st.cache_resource
 def load_model():
-    model = (AutoModelForCausalLM.from_pretrained('internlm/internlm2_5-7b-chat',
-                                                  trust_remote_code=True).to(
-                                                      torch.bfloat16).cuda())
+    model = (AutoModelForCausalLM.from_pretrained(
+        'internlm/internlm2_5-7b-chat',
+        trust_remote_code=True).to(torch.bfloat16).cuda())
     tokenizer = AutoTokenizer.from_pretrained('internlm/internlm2_5-7b-chat',
                                               trust_remote_code=True)
     return model, tokenizer
@@ -216,7 +217,7 @@ def combine_history(prompt):
     meta_instruction = ('You are InternLM (书生·浦语), a helpful, honest, '
                         'and harmless AI assistant developed by Shanghai '
                         'AI Laboratory (上海人工智能实验室).')
-    total_prompt = f"<|im_start|>system\n{meta_instruction}<|im_end|>\n"
+    total_prompt = f'<|im_start|>system\n{meta_instruction}<|im_end|>\n'
     for message in messages:
         cur_content = message['content']
         if message['role'] == 'user':
diff --git a/long_context/README.md b/long_context/README.md
index bda1d16..e0fd13f 100644
--- a/long_context/README.md
+++ b/long_context/README.md
@@ -35,6 +35,7 @@ Currently, we support PDF, TXT, and Markdown files, with more file types to be s
 ### Installation

 To get started, install the required packages:
+
 ```bash
 pip install "fairy-doc[cpu]"
 pip install streamlit
@@ -77,8 +78,6 @@ The effect is demonstrated in the video below.

 https://github.com/libowen2121/InternLM/assets/19970308/1d7f9b87-d458-4f24-9f7a-437a4da3fa6e

-
 ## 🔜 Stay Tuned for More

 We are continuously enhancing our models to better understand and reason with extensive long inputs. Expect new features, improved performance, and expanded capabilities in upcoming updates!
-
diff --git a/long_context/README_zh-CN.md b/long_context/README_zh-CN.md
index 3647c6d..cd5e38c 100644
--- a/long_context/README_zh-CN.md
+++ b/long_context/README_zh-CN.md
@@ -32,6 +32,7 @@
 ### 安装

 开始前,请安装所需的依赖:
+
 ```bash
 pip install "fairy-doc[cpu]"
 pip install streamlit
@@ -76,4 +77,4 @@ https://github.com/libowen2121/InternLM/assets/19970308/1d7f9b87-d458-4f24-9f7a-

 ## 🔜 敬请期待更多

-我们将不断优化和更新长文本模型,以提升其在长文本上的理解和分析能力。敬请关注!
\ No newline at end of file
+我们将不断优化和更新长文本模型,以提升其在长文本上的理解和分析能力。敬请关注!
diff --git a/long_context/doc_chat_demo.py b/long_context/doc_chat_demo.py
index 9a1d95b..659c850 100644
--- a/long_context/doc_chat_demo.py
+++ b/long_context/doc_chat_demo.py
@@ -1,12 +1,15 @@
+import argparse
 import logging
 from dataclasses import dataclass
+
 import streamlit as st
-from openai import OpenAI
 from magic_doc.docconv import DocConverter
-import argparse
+from openai import OpenAI

 # Set up logging
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logging.basicConfig(level=logging.INFO,
+                    format='%(asctime)s - %(levelname)s - %(message)s')
+

 @dataclass
 class GenerationConfig:
@@ -16,13 +19,14 @@ class GenerationConfig:
     temperature: float = 0.1
     repetition_penalty: float = 1.005

+
 def generate(
     client,
     messages,
     generation_config,
 ):
     stream = client.chat.completions.create(
-        model=st.session_state["model_name"],
+        model=st.session_state['model_name'],
         messages=messages,
         stream=True,
         temperature=generation_config.temperature,
@@ -32,113 +36,140 @@ def generate(
     )
     return stream

+
 def prepare_generation_config():
     with st.sidebar:
-        max_tokens = st.number_input("Max Tokens", min_value=100, max_value=4096, value=1024)
-        top_p = st.number_input("Top P", 0.0, 1.0, 1.0, step=0.01)
-        temperature = st.number_input("Temperature", 0.0, 1.0, 0.05, step=0.01)
-        repetition_penalty = st.number_input("Repetition Penalty", 0.8, 1.2, 1.02, step=0.001, format="%0.3f")
-        st.button("Clear Chat History", on_click=on_btn_click)
+        max_tokens = st.number_input('Max Tokens',
+                                     min_value=100,
+                                     max_value=4096,
+                                     value=1024)
+        top_p = st.number_input('Top P', 0.0, 1.0, 1.0, step=0.01)
+        temperature = st.number_input('Temperature', 0.0, 1.0, 0.05, step=0.01)
+        repetition_penalty = st.number_input('Repetition Penalty',
+                                             0.8,
+                                             1.2,
+                                             1.02,
+                                             step=0.001,
+                                             format='%0.3f')
+        st.button('Clear Chat History', on_click=on_btn_click)

-    generation_config = GenerationConfig(
-        max_tokens=max_tokens,
-        top_p=top_p,
-        temperature=temperature,
-        repetition_penalty=repetition_penalty
-    )
+    generation_config = GenerationConfig(max_tokens=max_tokens,
+                                         top_p=top_p,
+                                         temperature=temperature,
+                                         repetition_penalty=repetition_penalty)

     return generation_config

+
 def on_btn_click():
     del st.session_state.messages
     st.session_state.file_content_found = False
     st.session_state.file_content_used = False

+
 user_avator = 'assets/user.png'
 robot_avator = 'assets/robot.png'

 st.title('InternLM2.5 File Chat 📝')

+
 def main(base_url):
     # Initialize the client for the model
-    client = OpenAI(
-        base_url=base_url,
-        timeout=12000
-    )
+    client = OpenAI(base_url=base_url, timeout=12000)

     # Get the model ID
     model_name = client.models.list().data[0].id
-    st.session_state["model_name"] = model_name
+    st.session_state['model_name'] = model_name

     # Get the generation config
     generation_config = prepare_generation_config()

     # Initialize session state
-    if "messages" not in st.session_state:
+    if 'messages' not in st.session_state:
         st.session_state.messages = []
-    if "file_content_found" not in st.session_state:
+    if 'file_content_found' not in st.session_state:
         st.session_state.file_content_found = False
         st.session_state.file_content_used = False
-        st.session_state.file_name = ""
+        st.session_state.file_name = ''

     # Handle file upload
     if not st.session_state.file_content_found:
-        uploaded_file = st.file_uploader("Upload an article", type=("txt", "md", "pdf"))
-        file_content = ""
+        uploaded_file = st.file_uploader('Upload an article',
+                                         type=('txt', 'md', 'pdf'))
+        file_content = ''
         if uploaded_file is not None:
-            if uploaded_file.type == "application/pdf":
-                with open("uploaded_file.pdf", "wb") as f:
+            if uploaded_file.type == 'application/pdf':
+                with open('uploaded_file.pdf', 'wb') as f:
                     f.write(uploaded_file.getbuffer())
                 converter = DocConverter(s3_config=None)
-                file_content, time_cost = converter.convert("uploaded_file.pdf", conv_timeout=300)
-                st.session_state.file_content_found = True  # Reset flag when a new file is uploaded
-                st.session_state.file_content = file_content  # Store the file content in session state
-                st.session_state.file_name = uploaded_file.name  # Store the file name in session state
+                file_content, time_cost = converter.convert(
+                    'uploaded_file.pdf', conv_timeout=300)
+                # Reset flag when a new file is uploaded
+                st.session_state.file_content_found = True
+                # Store the file content in session state
+                st.session_state.file_content = file_content
+                # Store the file name in session state
+                st.session_state.file_name = uploaded_file.name
             else:
-                file_content = uploaded_file.read().decode("utf-8")
-                st.session_state.file_content_found = True  # Reset flag when a new file is uploaded
-                st.session_state.file_content = file_content  # Store the file content in session state
-                st.session_state.file_name = uploaded_file.name  # Store the file name in session state
+                file_content = uploaded_file.read().decode('utf-8')
+                # Reset flag when a new file is uploaded
+                st.session_state.file_content_found = True
+                # Store the file content in session state
+                st.session_state.file_content = file_content
+                # Store the file name in session state
+                st.session_state.file_name = uploaded_file.name

     if st.session_state.file_content_found:
-        st.success(f"File '{st.session_state.file_name}' has been successfully uploaded!")
+        st.success(f"File '{st.session_state.file_name}' "
+                   'has been successfully uploaded!')

     # Display chat messages
     for message in st.session_state.messages:
-        with st.chat_message(message["role"], avatar=message.get("avatar")):
-            st.markdown(message["content"])
+        with st.chat_message(message['role'], avatar=message.get('avatar')):
+            st.markdown(message['content'])

     # Handle user input and response generation
     if prompt := st.chat_input("What's up?"):
-        turn = {"role": "user", "content": prompt, "avatar": user_avator}
-        if st.session_state.file_content_found and not st.session_state.file_content_used:
+        turn = {'role': 'user', 'content': prompt, 'avatar': user_avator}
+        if (st.session_state.file_content_found
+                and not st.session_state.file_content_used):
             assert st.session_state.file_content is not None
-            merged_prompt = f"{st.session_state.file_content}\n\n{prompt}"
-            st.session_state.file_content_used = True  # Set flag to indicate file content has been used
-            turn["merged_content"] = merged_prompt
+            merged_prompt = f'{st.session_state.file_content}\n\n{prompt}'
+            # Set flag to indicate file content has been used
+            st.session_state.file_content_used = True
+            turn['merged_content'] = merged_prompt

         st.session_state.messages.append(turn)

-        with st.chat_message("user", avatar=user_avator):
+        with st.chat_message('user', avatar=user_avator):
             st.markdown(prompt)

-        with st.chat_message("assistant", avatar=robot_avator):
-            messages = [
-                {
-                    "role": m["role"],
-                    "content": m["merged_content"] if "merged_content" in m else m["content"],
-                }
-                for m in st.session_state.messages
-            ]
+        with st.chat_message('assistant', avatar=robot_avator):
+            messages = [{
+                'role':
+                m['role'],
+                'content':
+                m['merged_content'] if 'merged_content' in m else m['content'],
+            } for m in st.session_state.messages]
             # Log messages to the terminal
             for m in messages:
-                logging.info(f"\n\n*** [{m['role']}] ***\n\n\t{m['content']}\n\n")
+                logging.info(
+                    f"\n\n*** [{m['role']}] ***\n\n\t{m['content']}\n\n")
             stream = generate(client, messages, generation_config)
             response = st.write_stream(stream)
-            st.session_state.messages.append({"role": "assistant", "content": response, "avatar": robot_avator})
+            st.session_state.messages.append({
+                'role': 'assistant',
+                'content': response,
+                'avatar': robot_avator
+            })

-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run Streamlit app with OpenAI client.")
-    parser.add_argument("--base_url", type=str, required=True, help="Base URL for the OpenAI client")
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='Run Streamlit app with OpenAI client.')
+    parser.add_argument('--base_url',
+                        type=str,
+                        required=True,
+                        help='Base URL for the OpenAI client')
     args = parser.parse_args()
     main(args.base_url)
diff --git a/model_cards/internlm2.5_7b.md b/model_cards/internlm2.5_7b.md
index 60974b3..35ddea6 100644
--- a/model_cards/internlm2.5_7b.md
+++ b/model_cards/internlm2.5_7b.md
@@ -42,16 +42,15 @@ We have evaluated InternLM2.5 on several important benchmarks using the open-sou
 ### Chat Model

 | Benchmark | InternLM2-Chat-7B | LLaMA-3-8B-Instruct | Yi-1.5-9B-Chat | GLM-4-9B-Chat | Qwen2-7B-Instruct | Gemma2-9B-IT | InternLM2.5-7B-Chat | Llama-3-70B-Instruct |
-| ----------------- | ----------------- | ------------------- | -------------- | ------------- | ----------------- | ------------ | ------------------- |------------------- |
-| MMLU(5-shot) | 62.3 | 68.4 | 71.0 | 71.4 | 70.8 | 70.9 | 72.8 | 80.5 |
-| CMMLU(5-shot) | 62.4 | 53.3 | 74.5 | 74.5 | 80.9 | 60.3 | 78.0 | 70.1 |
-| BBH(3-shot CoT) | 59.0 | 54.4 | 69.6 | 69.6 | 65.0 | 68.2\* | 71.6 | 80.5 |
-| MATH(0-shot CoT) | 27.6 | 27.9 | 51.1 | 51.1 | 48.6 | 46.9 | 60.1 | 47.1 |
-| GSM8K(0-shot CoT) | 72.5 | 72.9 | 80.1 | 85.3 | 82.9 | 88.9 | 86.0 | 92.8 |
-| GPQA(0-shot) | 29.8 | 26.1 | 37.9 | 36.9 | 38.4 | 33.8 | 38.4 | 38.9 |
-
+| ----------------- | ----------------- | ------------------- | -------------- | ------------- | ----------------- | ------------ | ------------------- | -------------------- |
+| MMLU(5-shot)      | 62.3              | 68.4                | 71.0           | 71.4          | 70.8              | 70.9         | 72.8                | 80.5                 |
+| CMMLU(5-shot)     | 62.4              | 53.3                | 74.5           | 74.5          | 80.9              | 60.3         | 78.0                | 70.1                 |
+| BBH(3-shot CoT)   | 59.0              | 54.4                | 69.6           | 69.6          | 65.0              | 68.2\*       | 71.6                | 80.5                 |
+| MATH(0-shot CoT)  | 27.6              | 27.9                | 51.1           | 51.1          | 48.6              | 46.9         | 60.1                | 47.1                 |
+| GSM8K(0-shot CoT) | 72.5              | 72.9                | 80.1           | 85.3          | 82.9              | 88.9         | 86.0                | 92.8                 |
+| GPQA(0-shot)      | 29.8              | 26.1                | 37.9           | 36.9          | 38.4              | 33.8         | 38.4                | 38.9                 |

 - We use `ppl` for the MCQ evaluation on base model.
 - The evaluation results were obtained from [OpenCompass](https://github.com/open-compass/opencompass) , and evaluation configuration can be found in the configuration files provided by [OpenCompass](https://github.com/open-compass/opencompass).
 - The evaluation data may have numerical differences due to the version iteration of [OpenCompass](https://github.com/open-compass/opencompass), so please refer to the latest evaluation results of [OpenCompass](https://github.com/open-compass/opencompass).
-- \* means the result is copied from the original paper.
\ No newline at end of file
+- \* means the result is copied from the original paper.
diff --git a/model_cards/internlm2_1.8b.md b/model_cards/internlm2_1.8b.md
index b811ca3..a92438a 100644
--- a/model_cards/internlm2_1.8b.md
+++ b/model_cards/internlm2_1.8b.md
@@ -5,7 +5,7 @@
 InternLM2-1.8B is the 1.8 billion parameter version of the second generation InternLM series. In order to facilitate user use and research, InternLM2-1.8B has three versions of open-source models. They are:

 - InternLM2-1.8B: Foundation models with high quality and high adaptation flexibility, which serve as a good starting point for downstream deep adaptations.
-- InternLM2-Chat-1.8B-SFT: Chat model after supervised fine-tuning (SFT) on InternLM2-1.8B. 
+- InternLM2-Chat-1.8B-SFT: Chat model after supervised fine-tuning (SFT) on InternLM2-1.8B.
 - InternLM2-Chat-1.8B: Further aligned on top of InternLM2-Chat-1.8B-SFT through online RLHF. InternLM2-Chat-1.8B exhibits better instruction following, chat experience, and function calling, which is recommended for downstream applications.

 The base model of InternLM2 has the following technical features:
@@ -15,26 +15,25 @@ The base model of InternLM2 has the following technical features:

 ## Model Zoo

-| Model | Transformers(HF) | ModelScope(HF) | OpenXLab(HF) | OpenXLab(Origin) | Release Date |
-| -------------------------- | ------------------------------------------ | ---------------------------------------- | -------------------------------------- | ------------------------------------------ | ------------ |
-| **InternLM2-1.8B** | [🤗internlm2-1.8b](https://huggingface.co/internlm/internlm2-1_8b) | [ internlm2-1.8b](https://www.modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-1_8b/summary) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-base-1.8b) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-base-1.8b-original) | 2024-01-31 |
-| **InternLM2-Chat-1.8B-SFT** | [🤗internlm2-chat-1.8b-sft](https://huggingface.co/internlm/internlm2-chat-1_8b-sft) | [ internlm2-chat-1.8b-sft](https://www.modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-1_8b-sft/summary) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-chat-1.8b-sft) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-chat-1.8b-sft-original) | 2024-01-31 |
-| **InternLM2-Chat-1.8B** | [🤗internlm2-chat-1.8b](https://huggingface.co/internlm/internlm2-chat-1_8b) | [ internlm2-chat-1.8b](https://www.modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-1_8b/summary) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-chat-1.8b) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-chat-1.8b-original) | 2024-02-19 |
+| Model | Transformers(HF) | ModelScope(HF) | OpenXLab(HF) | OpenXLab(Origin) | Release Date |
+| --------------------------- | ----------------------------------------- | ---------------------------------------- | -------------------------------------- | ------------------------------------------ | ------------ |
+| **InternLM2-1.8B** | [🤗internlm2-1.8b](https://huggingface.co/internlm/internlm2-1_8b) | [ internlm2-1.8b](https://www.modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-1_8b/summary) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-base-1.8b) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-base-1.8b-original) | 2024-01-31 |
+| **InternLM2-Chat-1.8B-SFT** | [🤗internlm2-chat-1.8b-sft](https://huggingface.co/internlm/internlm2-chat-1_8b-sft) | [ internlm2-chat-1.8b-sft](https://www.modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-1_8b-sft/summary) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-chat-1.8b-sft) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-chat-1.8b-sft-original) | 2024-01-31 |
+| **InternLM2-Chat-1.8B** | [🤗internlm2-chat-1.8b](https://huggingface.co/internlm/internlm2-chat-1_8b) | [ internlm2-chat-1.8b](https://www.modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-1_8b/summary) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-chat-1.8b) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-chat-1.8b-original) | 2024-02-19 |

 ## Performance Evaluation

 We have evaluated InternLM2 on several important benchmarks using the open-source evaluation tool [OpenCompass](https://github.com/open-compass/opencompass). Some of the evaluation results are shown in the table below. You are welcome to visit the [OpenCompass Leaderboard](https://opencompass.org.cn/rank) for more evaluation results.

-| Dataset\Models | InternLM2-1.8B | InternLM2-Chat-1.8B-SFT | InternLM2-Chat-1.8B | InternLM2-7B | InternLM2-Chat-7B |
-| :---: | :---: | :---: | :---: | :---: | :---: |
-| MMLU | 46.9 | 47.1 | 44.1 | 65.8 | 63.7 |
-| AGIEval | 33.4 | 38.8 | 34.6 | 49.9 | 47.2 |
-| BBH | 37.5 | 35.2 | 34.3 | 65.0 | 61.2 |
-| GSM8K | 31.2 | 39.7 | 34.3 | 70.8 | 70.7 |
-| MATH | 5.6 | 11.8 | 10.7 | 20.2 | 23.0 |
-| HumanEval | 25.0 | 32.9 | 29.3 | 43.3 | 59.8 |
-| MBPP(Sanitized) | 22.2 | 23.2 | 27.0 | 51.8 | 51.4 |
+| Dataset\\Models | InternLM2-1.8B | InternLM2-Chat-1.8B-SFT | InternLM2-Chat-1.8B | InternLM2-7B | InternLM2-Chat-7B |
+| :-------------: | :------------: | :---------------------: | :-----------------: | :----------: | :---------------: |
+| MMLU            | 46.9           | 47.1                    | 44.1                | 65.8         | 63.7              |
+| AGIEval         | 33.4           | 38.8                    | 34.6                | 49.9         | 47.2              |
+| BBH             | 37.5           | 35.2                    | 34.3                | 65.0         | 61.2              |
+| GSM8K           | 31.2           | 39.7                    | 34.3                | 70.8         | 70.7              |
+| MATH            | 5.6            | 11.8                    | 10.7                | 20.2         | 23.0              |
+| HumanEval       | 25.0           | 32.9                    | 29.3                | 43.3         | 59.8              |
+| MBPP(Sanitized) | 22.2           | 23.2                    | 27.0                | 51.8         | 51.4              |

-
-- The evaluation results were obtained from [OpenCompass](https://github.com/open-compass/opencompass) , and evaluation configuration can be found in the configuration files provided by [OpenCompass](https://github.com/open-compass/opencompass). 
+- The evaluation results were obtained from [OpenCompass](https://github.com/open-compass/opencompass) , and evaluation configuration can be found in the configuration files provided by [OpenCompass](https://github.com/open-compass/opencompass).
 - The evaluation data may have numerical differences due to the version iteration of [OpenCompass](https://github.com/open-compass/opencompass), so please refer to the latest evaluation results of [OpenCompass](https://github.com/open-compass/opencompass).
diff --git a/tests/test_hf_model.py b/tests/test_hf_model.py
index 8a47e01..8ce810d 100644
--- a/tests/test_hf_model.py
+++ b/tests/test_hf_model.py
@@ -21,10 +21,10 @@ class TestChat:
     @pytest.mark.parametrize(
         'model_name',
         [
-            'internlm/internlm2_5-7b-chat',
-            'internlm/internlm2-chat-7b', 'internlm/internlm2-chat-7b-sft',
-            'internlm/internlm2-chat-20b', 'internlm/internlm2-chat-20b-sft',
-            'internlm/internlm2-chat-1_8b', 'internlm/internlm2-chat-1_8b-sft'
+            'internlm/internlm2_5-7b-chat', 'internlm/internlm2-chat-7b',
+            'internlm/internlm2-chat-7b-sft', 'internlm/internlm2-chat-20b',
+            'internlm/internlm2-chat-20b-sft', 'internlm/internlm2-chat-1_8b',
+            'internlm/internlm2-chat-1_8b-sft'
         ],
     )
     @pytest.mark.parametrize(
@@ -128,8 +128,10 @@ class TestMath:

     @pytest.mark.parametrize(
         'model_name',
-        ['internlm/internlm2-math-7b', 'internlm/internlm2-math-base-7b',
-         'internlm/internlm2-math-plus-1_8b', 'internlm/internlm2-math-plus-7b'
+        [
+            'internlm/internlm2-math-7b', 'internlm/internlm2-math-base-7b',
+            'internlm/internlm2-math-plus-1_8b',
+            'internlm/internlm2-math-plus-7b'
         ],
     )
     @pytest.mark.parametrize(
diff --git a/tools/convert2llama.py b/tools/convert2llama.py
index de65b58..f9d1d94 100644
--- a/tools/convert2llama.py
+++ b/tools/convert2llama.py
@@ -59,8 +59,8 @@ def convert(src, tgt):
     assert not config.bias, 'Cannot convert InternLM Model with bias to LLaMA.'

     head_dim = config.hidden_size // config.num_attention_heads
-    num_key_value_groups = config.num_attention_heads \
-        // config.num_key_value_heads
+    num_key_value_groups = \
+        config.num_attention_heads // config.num_key_value_heads

     # load index json file
     index_file = 'pytorch_model.bin.index.json'
@@ -140,7 +140,9 @@ def convert(src, tgt):
         print(f'Saving to {os.path.join(tgt, filename)}...', flush=True)
         if filename.endswith('.safetensors'):
             from safetensors.torch import save_file
-            save_file(llama_states, os.path.join(tgt, filename), metadata={"format": "pt"})
+            save_file(llama_states,
+                      os.path.join(tgt, filename),
+                      metadata={'format': 'pt'})
         else:
             torch.save(llama_states, os.path.join(tgt, filename))
         del states