pull/752/head
RangiLyu 2024-07-03 20:22:19 +08:00
parent 5e594c1699
commit 6908fa5e16
9 changed files with 140 additions and 106 deletions

View File

@@ -189,8 +189,8 @@ def generate_interactive(
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
if not has_default_max_length:
logger.warn( # pylint: disable=W4902
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
f'Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(='
f'{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. '
'Please refer to the documentation for more information. '
'(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)',
UserWarning,
@@ -199,8 +199,8 @@ def generate_interactive(
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = 'input_ids'
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
f'Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to'
f' {generation_config.max_length}. This can lead to unexpected behavior. You should consider'
' increasing `max_new_tokens`.')
# 2. Set generation parameters if not already defined
@@ -510,7 +510,7 @@ def main():
interface.clear_history()
f.flush()
print(f"{args.model}: Accuracy - {sum(scores) / len(scores)}")
print(f'{args.model}: Accuracy - {sum(scores) / len(scores)}')
torch.cuda.empty_cache()
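For readers skimming the hunks above: the two warnings implement a simple precedence rule in which `max_new_tokens` overrides `max_length`. A minimal, self-contained sketch of that rule follows; the function and variable names are illustrative, not taken from the repository.

```python
def resolve_max_length(max_new_tokens, max_length, input_ids_seq_length):
    """Sketch of how `max_new_tokens` takes precedence over `max_length`.

    Mirrors the logic shown in the hunks above; names are illustrative.
    """
    if max_new_tokens is not None:
        # `max_new_tokens` counts only generated tokens, so the effective
        # total length is the prompt length plus the new-token budget.
        max_length = max_new_tokens + input_ids_seq_length
    if input_ids_seq_length >= max_length:
        # The prompt alone already fills the budget; generation would stop
        # immediately, hence the warning to increase `max_new_tokens`.
        print('Warning: prompt length >= max_length; increase max_new_tokens.')
    return max_length


# Example: a 100-token prompt with a 32-token generation budget.
print(resolve_max_length(max_new_tokens=32, max_length=2048,
                         input_ids_seq_length=100))  # 132
```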

View File

@@ -77,7 +77,8 @@ def generate_interactive(
'max_length') is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
f"Using 'max_length''s default ({repr(generation_config.max_length)}) \
f"Using 'max_length''s default \
({repr(generation_config.max_length)}) \
to control the generation length. "
'This behaviour is deprecated and will be removed from the \
config in v5 of Transformers -- we'
@@ -102,7 +103,7 @@ def generate_interactive(
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = 'input_ids'
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, "
f'Input length of {input_ids_string} is {input_ids_seq_length}, '
f"but 'max_length' is set to {generation_config.max_length}. "
'This can lead to unexpected behavior. You should consider'
" increasing 'max_new_tokens'.")
@@ -180,9 +181,9 @@ def on_btn_click():
@st.cache_resource
def load_model():
model = (AutoModelForCausalLM.from_pretrained('internlm/internlm2_5-7b-chat',
trust_remote_code=True).to(
torch.bfloat16).cuda())
model = (AutoModelForCausalLM.from_pretrained(
'internlm/internlm2_5-7b-chat',
trust_remote_code=True).to(torch.bfloat16).cuda())
tokenizer = AutoTokenizer.from_pretrained('internlm/internlm2_5-7b-chat',
trust_remote_code=True)
return model, tokenizer
@@ -216,7 +217,7 @@ def combine_history(prompt):
meta_instruction = ('You are InternLM (书生·浦语), a helpful, honest, '
'and harmless AI assistant developed by Shanghai '
'AI Laboratory (上海人工智能实验室).')
total_prompt = f"<s><|im_start|>system\n{meta_instruction}<|im_end|>\n"
total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
for message in messages:
cur_content = message['content']
if message['role'] == 'user':
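The `combine_history` hunk above only shows the system turn of the prompt. As an illustration, here is a minimal sketch of the full prompt layout such a function produces; the user/assistant turn format is an assumption based on the same `<|im_start|>`/`<|im_end|>` convention, not code from this commit.

```python
# Sketch of the chat prompt combine_history assembles. Only the system turn
# appears in the hunk above; the user/assistant turn format below is an
# assumption based on the same <|im_start|>/<|im_end|> convention.
meta_instruction = ('You are InternLM (书生·浦语), a helpful, honest, '
                    'and harmless AI assistant developed by Shanghai '
                    'AI Laboratory (上海人工智能实验室).')
messages = [
    {'role': 'user', 'content': 'Hello!'},
    {'role': 'robot', 'content': 'Hi, how can I help you today?'},
]

total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
for message in messages:
    cur_content = message['content']
    if message['role'] == 'user':
        total_prompt += f'<|im_start|>user\n{cur_content}<|im_end|>\n'
    else:  # assistant ("robot") turns
        total_prompt += f'<|im_start|>assistant\n{cur_content}<|im_end|>\n'
total_prompt += '<|im_start|>assistant\n'  # cue the model to respond
print(total_prompt)
```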

View File

@@ -35,6 +35,7 @@ Currently, we support PDF, TXT, and Markdown files, with more file types to be supported.
### Installation
To get started, install the required packages:
```bash
pip install "fairy-doc[cpu]"
pip install streamlit
```
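For context on what `fairy-doc` provides: the file-chat app added later in this commit converts an uploaded PDF to text with `magic_doc`'s `DocConverter`, roughly as sketched below; the file path is illustrative.

```python
# Minimal sketch of the PDF-to-text step used by the file-chat demo.
# Assumes `uploaded_file.pdf` exists in the working directory.
from magic_doc.docconv import DocConverter

converter = DocConverter(s3_config=None)
# Returns the extracted text plus the conversion time, as in the app below.
file_content, time_cost = converter.convert('uploaded_file.pdf', conv_timeout=300)
print(f'Converted in {time_cost}s; extracted {len(file_content)} characters.')
```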
@@ -77,8 +78,6 @@ The effect is demonstrated in the video below.
https://github.com/libowen2121/InternLM/assets/19970308/1d7f9b87-d458-4f24-9f7a-437a4da3fa6e
## 🔜 Stay Tuned for More
We are continuously enhancing our models to better understand and reason over long inputs. Expect new features, improved performance, and expanded capabilities in upcoming updates!

View File

@@ -32,6 +32,7 @@
### Installation
Before you begin, install the required dependencies:
```bash
pip install "fairy-doc[cpu]"
pip install streamlit
```
@@ -76,4 +77,4 @@ https://github.com/libowen2121/InternLM/assets/19970308/1d7f9b87-d458-4f24-9f7a-
## 🔜 Stay Tuned for More
We will keep optimizing and updating the long-context model to improve its ability to understand and analyze long texts. Stay tuned!

View File

@@ -1,12 +1,15 @@
import argparse
import logging
from dataclasses import dataclass
import streamlit as st
from openai import OpenAI
from magic_doc.docconv import DocConverter
import argparse
from openai import OpenAI
# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s')
@dataclass
class GenerationConfig:
@@ -16,13 +19,14 @@ class GenerationConfig:
temperature: float = 0.1
repetition_penalty: float = 1.005
def generate(
client,
messages,
generation_config,
):
stream = client.chat.completions.create(
model=st.session_state["model_name"],
model=st.session_state['model_name'],
messages=messages,
stream=True,
temperature=generation_config.temperature,
@@ -32,113 +36,140 @@ def generate(
)
return stream
def prepare_generation_config():
with st.sidebar:
max_tokens = st.number_input("Max Tokens", min_value=100, max_value=4096, value=1024)
top_p = st.number_input("Top P", 0.0, 1.0, 1.0, step=0.01)
temperature = st.number_input("Temperature", 0.0, 1.0, 0.05, step=0.01)
repetition_penalty = st.number_input("Repetition Penalty", 0.8, 1.2, 1.02, step=0.001, format="%0.3f")
st.button("Clear Chat History", on_click=on_btn_click)
max_tokens = st.number_input('Max Tokens',
min_value=100,
max_value=4096,
value=1024)
top_p = st.number_input('Top P', 0.0, 1.0, 1.0, step=0.01)
temperature = st.number_input('Temperature', 0.0, 1.0, 0.05, step=0.01)
repetition_penalty = st.number_input('Repetition Penalty',
0.8,
1.2,
1.02,
step=0.001,
format='%0.3f')
st.button('Clear Chat History', on_click=on_btn_click)
generation_config = GenerationConfig(
max_tokens=max_tokens,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty
)
generation_config = GenerationConfig(max_tokens=max_tokens,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty)
return generation_config
def on_btn_click():
del st.session_state.messages
st.session_state.file_content_found = False
st.session_state.file_content_used = False
user_avator = 'assets/user.png'
robot_avator = 'assets/robot.png'
st.title('InternLM2.5 File Chat 📝')
def main(base_url):
# Initialize the client for the model
client = OpenAI(
base_url=base_url,
timeout=12000
)
client = OpenAI(base_url=base_url, timeout=12000)
# Get the model ID
model_name = client.models.list().data[0].id
st.session_state["model_name"] = model_name
st.session_state['model_name'] = model_name
# Get the generation config
generation_config = prepare_generation_config()
# Initialize session state
if "messages" not in st.session_state:
if 'messages' not in st.session_state:
st.session_state.messages = []
if "file_content_found" not in st.session_state:
if 'file_content_found' not in st.session_state:
st.session_state.file_content_found = False
st.session_state.file_content_used = False
st.session_state.file_name = ""
st.session_state.file_name = ''
# Handle file upload
if not st.session_state.file_content_found:
uploaded_file = st.file_uploader("Upload an article", type=("txt", "md", "pdf"))
file_content = ""
uploaded_file = st.file_uploader('Upload an article',
type=('txt', 'md', 'pdf'))
file_content = ''
if uploaded_file is not None:
if uploaded_file.type == "application/pdf":
with open("uploaded_file.pdf", "wb") as f:
if uploaded_file.type == 'application/pdf':
with open('uploaded_file.pdf', 'wb') as f:
f.write(uploaded_file.getbuffer())
converter = DocConverter(s3_config=None)
file_content, time_cost = converter.convert("uploaded_file.pdf", conv_timeout=300)
st.session_state.file_content_found = True # Reset flag when a new file is uploaded
st.session_state.file_content = file_content # Store the file content in session state
st.session_state.file_name = uploaded_file.name # Store the file name in session state
file_content, time_cost = converter.convert(
'uploaded_file.pdf', conv_timeout=300)
# Reset flag when a new file is uploaded
st.session_state.file_content_found = True
# Store the file content in session state
st.session_state.file_content = file_content
# Store the file name in session state
st.session_state.file_name = uploaded_file.name
else:
file_content = uploaded_file.read().decode("utf-8")
st.session_state.file_content_found = True # Reset flag when a new file is uploaded
st.session_state.file_content = file_content # Store the file content in session state
st.session_state.file_name = uploaded_file.name # Store the file name in session state
file_content = uploaded_file.read().decode('utf-8')
# Reset flag when a new file is uploaded
st.session_state.file_content_found = True
# Store the file content in session state
st.session_state.file_content = file_content
# Store the file name in session state
st.session_state.file_name = uploaded_file.name
if st.session_state.file_content_found:
st.success(f"File '{st.session_state.file_name}' has been successfully uploaded!")
st.success(f"File '{st.session_state.file_name}' "
'has been successfully uploaded!')
# Display chat messages
for message in st.session_state.messages:
with st.chat_message(message["role"], avatar=message.get("avatar")):
st.markdown(message["content"])
with st.chat_message(message['role'], avatar=message.get('avatar')):
st.markdown(message['content'])
# Handle user input and response generation
if prompt := st.chat_input("What's up?"):
turn = {"role": "user", "content": prompt, "avatar": user_avator}
if st.session_state.file_content_found and not st.session_state.file_content_used:
turn = {'role': 'user', 'content': prompt, 'avatar': user_avator}
if (st.session_state.file_content_found
and not st.session_state.file_content_used):
assert st.session_state.file_content is not None
merged_prompt = f"{st.session_state.file_content}\n\n{prompt}"
st.session_state.file_content_used = True # Set flag to indicate file content has been used
turn["merged_content"] = merged_prompt
merged_prompt = f'{st.session_state.file_content}\n\n{prompt}'
# Set flag to indicate file content has been used
st.session_state.file_content_used = True
turn['merged_content'] = merged_prompt
st.session_state.messages.append(turn)
with st.chat_message("user", avatar=user_avator):
with st.chat_message('user', avatar=user_avator):
st.markdown(prompt)
with st.chat_message("assistant", avatar=robot_avator):
messages = [
{
"role": m["role"],
"content": m["merged_content"] if "merged_content" in m else m["content"],
}
for m in st.session_state.messages
]
with st.chat_message('assistant', avatar=robot_avator):
messages = [{
'role':
m['role'],
'content':
m['merged_content'] if 'merged_content' in m else m['content'],
} for m in st.session_state.messages]
# Log messages to the terminal
for m in messages:
logging.info(f"\n\n*** [{m['role']}] ***\n\n\t{m['content']}\n\n")
logging.info(
f"\n\n*** [{m['role']}] ***\n\n\t{m['content']}\n\n")
stream = generate(client, messages, generation_config)
response = st.write_stream(stream)
st.session_state.messages.append({"role": "assistant", "content": response, "avatar": robot_avator})
st.session_state.messages.append({
'role': 'assistant',
'content': response,
'avatar': robot_avator
})
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run Streamlit app with OpenAI client.")
parser.add_argument("--base_url", type=str, required=True, help="Base URL for the OpenAI client")
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Run Streamlit app with OpenAI client.')
parser.add_argument('--base_url',
type=str,
required=True,
help='Base URL for the OpenAI client')
args = parser.parse_args()
main(args.base_url)
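To summarize the client side of the app above: it talks to any OpenAI-compatible endpoint, picks the first model the server reports, and streams chat completions. A minimal standalone sketch follows; the base URL, `api_key` value, and prompt are placeholders rather than values from this commit.

```python
# Sketch of the client calls made by the file-chat app above. Point base_url
# at whatever OpenAI-compatible server hosts the model; the URL, api_key
# value, and prompt below are placeholders.
from openai import OpenAI

client = OpenAI(base_url='http://0.0.0.0:23333/v1', api_key='none')
model_name = client.models.list().data[0].id  # first model the server reports

stream = client.chat.completions.create(
    model=model_name,
    messages=[{'role': 'user', 'content': 'Summarize the uploaded file.'}],
    stream=True,
    temperature=0.1,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end='', flush=True)
```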

View File

@@ -42,16 +42,15 @@ We have evaluated InternLM2.5 on several important benchmarks using the open-source evaluation tool [OpenCompass](https://github.com/open-compass/opencompass).
### Chat Model
| Benchmark | InternLM2-Chat-7B | LLaMA-3-8B-Instruct | Yi-1.5-9B-Chat | GLM-4-9B-Chat | Qwen2-7B-Instruct | Gemma2-9B-IT | InternLM2.5-7B-Chat | Llama-3-70B-Instruct |
| ----------------- | ----------------- | ------------------- | -------------- | ------------- | ----------------- | ------------ | ------------------- | -------------------- |
| MMLU(5-shot) | 62.3 | 68.4 | 71.0 | 71.4 | 70.8 | 70.9 | 72.8 | 80.5 |
| CMMLU(5-shot) | 62.4 | 53.3 | 74.5 | 74.5 | 80.9 | 60.3 | 78.0 | 70.1 |
| BBH(3-shot CoT) | 59.0 | 54.4 | 69.6 | 69.6 | 65.0 | 68.2\* | 71.6 | 80.5 |
| MATH(0-shot CoT) | 27.6 | 27.9 | 51.1 | 51.1 | 48.6 | 46.9 | 60.1 | 47.1 |
| GSM8K(0-shot CoT) | 72.5 | 72.9 | 80.1 | 85.3 | 82.9 | 88.9 | 86.0 | 92.8 |
| GPQA(0-shot) | 29.8 | 26.1 | 37.9 | 36.9 | 38.4 | 33.8 | 38.4 | 38.9 |
- We use `ppl` for the MCQ evaluation on the base model.
- The evaluation results were obtained from [OpenCompass](https://github.com/open-compass/opencompass), and the evaluation configuration can be found in the configuration files provided by [OpenCompass](https://github.com/open-compass/opencompass).
- The evaluation data may show numerical differences across versions of [OpenCompass](https://github.com/open-compass/opencompass), so please refer to the latest evaluation results of [OpenCompass](https://github.com/open-compass/opencompass).
- \* means the result is copied from the original paper.

View File

@@ -5,7 +5,7 @@
InternLM2-1.8B is the 1.8-billion-parameter version of the second-generation InternLM series. To facilitate use and research, InternLM2-1.8B is released as three open-source models:
- InternLM2-1.8B: Foundation models with high quality and high adaptation flexibility, which serve as a good starting point for downstream deep adaptations.
- InternLM2-Chat-1.8B-SFT: Chat model after supervised fine-tuning (SFT) on InternLM2-1.8B.
- InternLM2-Chat-1.8B: Further aligned on top of InternLM2-Chat-1.8B-SFT through online RLHF. InternLM2-Chat-1.8B exhibits better instruction following, a better chat experience, and improved function calling, and is recommended for downstream applications.
The base model of InternLM2 has the following technical features:
@@ -15,26 +15,25 @@ The base model of InternLM2 has the following technical features:
## Model Zoo
| Model | Transformers(HF) | ModelScope(HF) | OpenXLab(HF) | OpenXLab(Origin) | Release Date |
| --------------------------- | ----------------------------------------- | ---------------------------------------- | -------------------------------------- | ------------------------------------------ | ------------ |
| **InternLM2-1.8B** | [🤗internlm2-1.8b](https://huggingface.co/internlm/internlm2-1_8b) | [<img src="./assets/modelscope_logo.png" width="20px" /> internlm2-1.8b](https://www.modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-1_8b/summary) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-base-1.8b) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-base-1.8b-original) | 2024-01-31 |
| **InternLM2-Chat-1.8B-SFT** | [🤗internlm2-chat-1.8b-sft](https://huggingface.co/internlm/internlm2-chat-1_8b-sft) | [<img src="./assets/modelscope_logo.png" width="20px" /> internlm2-chat-1.8b-sft](https://www.modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-1_8b-sft/summary) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-chat-1.8b-sft) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-chat-1.8b-sft-original) | 2024-01-31 |
| **InternLM2-Chat-1.8B** | [🤗internlm2-chat-1.8b](https://huggingface.co/internlm/internlm2-chat-1_8b) | [<img src="./assets/modelscope_logo.png" width="20px" /> internlm2-chat-1.8b](https://www.modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-1_8b/summary) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-chat-1.8b) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/OpenLMLab/internlm2-chat-1.8b-original) | 2024-02-19 |
## Performance Evaluation
We have evaluated InternLM2 on several important benchmarks using the open-source evaluation tool [OpenCompass](https://github.com/open-compass/opencompass). Some of the evaluation results are shown in the table below. You are welcome to visit the [OpenCompass Leaderboard](https://opencompass.org.cn/rank) for more evaluation results.
| Dataset\\Models | InternLM2-1.8B | InternLM2-Chat-1.8B-SFT | InternLM2-Chat-1.8B | InternLM2-7B | InternLM2-Chat-7B |
| :-------------: | :------------: | :---------------------: | :-----------------: | :----------: | :---------------: |
| MMLU | 46.9 | 47.1 | 44.1 | 65.8 | 63.7 |
| AGIEval | 33.4 | 38.8 | 34.6 | 49.9 | 47.2 |
| BBH | 37.5 | 35.2 | 34.3 | 65.0 | 61.2 |
| GSM8K | 31.2 | 39.7 | 34.3 | 70.8 | 70.7 |
| MATH | 5.6 | 11.8 | 10.7 | 20.2 | 23.0 |
| HumanEval | 25.0 | 32.9 | 29.3 | 43.3 | 59.8 |
| MBPP(Sanitized) | 22.2 | 23.2 | 27.0 | 51.8 | 51.4 |
- The evaluation results were obtained from [OpenCompass](https://github.com/open-compass/opencompass), and the evaluation configuration can be found in the configuration files provided by [OpenCompass](https://github.com/open-compass/opencompass).
- The evaluation data may show numerical differences across versions of [OpenCompass](https://github.com/open-compass/opencompass), so please refer to the latest evaluation results of [OpenCompass](https://github.com/open-compass/opencompass).
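For completeness, the chat models listed above can be loaded with Hugging Face `transformers` in the same way as the Streamlit demo earlier in this commit; the snippet below is a sketch, and the `bfloat16`/`.cuda()` placement mirrors that demo rather than anything stated in this README.

```python
# Sketch of loading one of the chat models from the Model Zoo table above,
# following the load_model pattern used by the Streamlit demo in this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = 'internlm/internlm2-chat-1_8b'
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True).to(torch.bfloat16).cuda()
model = model.eval()
```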

View File

@@ -21,10 +21,10 @@ class TestChat:
@pytest.mark.parametrize(
'model_name',
[
'internlm/internlm2_5-7b-chat',
'internlm/internlm2-chat-7b', 'internlm/internlm2-chat-7b-sft',
'internlm/internlm2-chat-20b', 'internlm/internlm2-chat-20b-sft',
'internlm/internlm2-chat-1_8b', 'internlm/internlm2-chat-1_8b-sft'
'internlm/internlm2_5-7b-chat', 'internlm/internlm2-chat-7b',
'internlm/internlm2-chat-7b-sft', 'internlm/internlm2-chat-20b',
'internlm/internlm2-chat-20b-sft', 'internlm/internlm2-chat-1_8b',
'internlm/internlm2-chat-1_8b-sft'
],
)
@pytest.mark.parametrize(
@@ -128,8 +128,10 @@ class TestMath:
@pytest.mark.parametrize(
'model_name',
['internlm/internlm2-math-7b', 'internlm/internlm2-math-base-7b',
'internlm/internlm2-math-plus-1_8b', 'internlm/internlm2-math-plus-7b'
[
'internlm/internlm2-math-7b', 'internlm/internlm2-math-base-7b',
'internlm/internlm2-math-plus-1_8b',
'internlm/internlm2-math-plus-7b'
],
)
@pytest.mark.parametrize(

View File

@@ -59,8 +59,8 @@ def convert(src, tgt):
assert not config.bias, 'Cannot convert InternLM Model with bias to LLaMA.'
head_dim = config.hidden_size // config.num_attention_heads
num_key_value_groups = config.num_attention_heads \
// config.num_key_value_heads
num_key_value_groups = \
config.num_attention_heads // config.num_key_value_heads
# load index json file
index_file = 'pytorch_model.bin.index.json'
@@ -140,7 +140,9 @@ def convert(src, tgt):
print(f'Saving to {os.path.join(tgt, filename)}...', flush=True)
if filename.endswith('.safetensors'):
from safetensors.torch import save_file
save_file(llama_states, os.path.join(tgt, filename), metadata={"format": "pt"})
save_file(llama_states,
os.path.join(tgt, filename),
metadata={'format': 'pt'})
else:
torch.save(llama_states, os.path.join(tgt, filename))
del states
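As a quick sanity check on the two quantities computed in the hunk above: `head_dim` is the per-head width of the hidden state, and `num_key_value_groups` is how many query heads share each key/value head under grouped-query attention. The numbers below are illustrative, not read from an actual InternLM config.

```python
# Worked example of the arithmetic in the convert script above; the config
# values are illustrative placeholders, not taken from a real checkpoint.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8  # grouped-query attention: several query heads share one KV head

head_dim = hidden_size // num_attention_heads                      # 128
num_key_value_groups = num_attention_heads // num_key_value_heads  # 4 query heads per KV head

print(head_dim, num_key_value_groups)
```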