mirror of https://github.com/InternLM/InternLM
feat(tools): add passkey retrieval test
parent e611817442
commit d2cc5f61e3

@@ -1,4 +1,4 @@
This directory provides some tools to assist model training and inference. The file structure is as follows:

```bash
├── transformers  # tools for adapting to Hugging Face's transformers
@@ -6,6 +6,7 @@
│   ├── modeling_internlm.py  # model adaptation tool
│   ├── tokenization_internlm.py  # tokenizer adaptation tool
│   └── convert2hf.py  # tool for converting models to the Hugging Face format
├── passkey_retrieval.py  # long-context retrieval test tool
└── tokenizer.py  # tool for converting raw data into bin and meta files
```
@@ -109,3 +110,24 @@ InternLM performance on the GSM8K dataset with and without tools:
| -------- | -------------------- |
| w/o tool | 34.5 |
| w tool | 39.2 |

# passkey_retrieval.py

Tests the model's ability to retrieve a planted detail from inputs of different lengths. The test method comes from [this paper](https://arxiv.org/pdf/2305.16300.pdf). Usage:

```bash
python3 tools/passkey_retrieval.py [--max_tokens <max_token>] [--interval <interval>] [--num_tests <num_tests>]

# Optional arguments:
# --max_tokens <max_token>  Maximum input length in tokens (default: 4000).
# --interval <interval>     Step size between tested lengths (default: 1000).
# --num_tests <num_tests>   Number of inference runs per length (default: 20).
```

A usage example:

```bash
python3 tools/passkey_retrieval.py
```

The output is the passkey retrieval accuracy at each tested token length. With the defaults, `max_tokens // interval = 4` rounds are run, which is why the example below has four entries; each key is the average measured prompt length for that round:

```bash
accuracies over tokens {'881': 1.0, '1973': 0.8, '2792': 1.0, '3885': 0.8}
```

@@ -6,6 +6,7 @@ This directory provides some tools for model training with the following file str
│   ├── modeling_internlm.py  # tools for adapting model
│   ├── tokenization_internlm.py  # tools for adapting tokenizer
│   └── convert2hf.py  # tools for adapting models to Hugging Face's format
├── passkey_retrieval.py  # tool for testing long-context retrieval
└── tokenizer.py  # tools for generating `bin` and `meta` files from raw data
```

@@ -107,3 +108,24 @@ InternLM performance on the GSM8K dataset with and without tools:
| -------- | -------------------- |
| w/o tool | 34.5 |
| w tool | 39.2 |

# passkey_retrieval.py

Tests the model's ability to extract a planted detail from long inputs of different lengths. The test method comes from [this paper](https://arxiv.org/pdf/2305.16300.pdf).

```bash
python3 tools/passkey_retrieval.py [--max_tokens <max_token>] [--interval <interval>] [--num_tests <num_tests>]

# Optional parameters:
# --max_tokens <max_token>  Maximum input length in tokens (default: 4000).
# --interval <interval>     Step size between tested lengths (default: 1000).
# --num_tests <num_tests>   Number of inference runs per length (default: 20).
```

Below is an example of usage:

```bash
python3 tools/passkey_retrieval.py
```

The output is the passkey retrieval accuracy at each tested token length:

```bash
accuracies over tokens {'881': 1.0, '1973': 0.8, '2792': 1.0, '3885': 0.8}
```
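
For reference, the script below builds each test prompt by sandwiching a single passkey line between long runs of filler sentences. A minimal sketch of that layout (illustrative only, not part of the commit; the passkey here is hard-coded, whereas the script draws one at random):

```python
# Sketch of the prompt layout used by the passkey test.
filler = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. "
pass_key = 12345  # hypothetical fixed value for illustration

prompt = "\n".join([
    "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.",
    filler * 40,  # garbage prefix; its share of the total length varies per test
    f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key.",
    filler * 40,  # garbage suffix
    "What is the pass key? The pass key is",
])
print(prompt[:120], "...")
print("expected answer:", pass_key)
```

A model with reliable long-context retrieval should complete the final line with the planted number regardless of where the key sits in the filler.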
@@ -0,0 +1,72 @@
import argparse

from numpy import random
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_config():
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--max_tokens', type=int, default=4000, help='maximum token length for evaluation')
    parser.add_argument('--interval', type=int, default=1000, help='interval for evaluation')
    parser.add_argument('--num_tests', type=int, default=20, help='number of repeat testing for each length')
    args = parser.parse_args()
    return args


# copied from https://github.com/dvlab-research/LongLoRA/blob/main/passkey_retrivial.py
def generate_prompt_landmark(n_garbage=60000, seed=666):
    """Generates a long prompt and inserts a passkey at a random position."""
    # Save and restore the global numpy RNG state so that seeding here
    # does not affect callers.
    rnd_state = random.get_state()
    random.seed(seed)
    n_garbage_prefix = random.randint(0, n_garbage)
    n_garbage_suffix = n_garbage - n_garbage_prefix

    task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
    garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
    garbage_inf = " ".join([garbage] * 5000)
    assert len(garbage_inf) >= n_garbage
    garbage_prefix = garbage_inf[:n_garbage_prefix]
    garbage_suffix = garbage_inf[:n_garbage_suffix]
    pass_key = random.randint(1, 50000)
    information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
    final_question = "What is the pass key? The pass key is"
    lines = [
        task_description,
        garbage_prefix,
        information_line,
        garbage_suffix,
        final_question,
    ]
    random.set_state(rnd_state)
    return "\n".join(lines), str(pass_key)


def main(args):
    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-20b-chat", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained("internlm/internlm-20b-chat", trust_remote_code=True, device_map="auto")

    total_test_points = args.max_tokens // args.interval
    all_accuracies = {}
    for i in range(total_test_points):
        # Rough character-per-token ratio (~3.75) to hit the target token
        # length, rounded down to a multiple of 1024 characters.
        n_garbage = int(3.75 * (i + 1) * args.interval // 1024 * 1024)
        passed_tests = 0
        total_tokens = 0

        for j in range(args.num_tests):
            prompt, pass_key = generate_prompt_landmark(n_garbage=n_garbage, seed=j)
            response, _ = model.chat(tokenizer, prompt, history=[])
            if pass_key in response:
                passed_tests += 1
            total_tokens += len(tokenizer(prompt).input_ids)
        avg_tokens = total_tokens // args.num_tests
        accuracy = passed_tests / args.num_tests
        print("accuracy on the token length %d is %f" % (avg_tokens, accuracy))
        all_accuracies[str(avg_tokens)] = accuracy
    print("accuracies over tokens", all_accuracies)


if __name__ == "__main__":
    args = parse_config()
    main(args)
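
To exercise the prompt generator without downloading the 20B checkpoint, one can import it directly. A small smoke test, assuming the repository root is on `sys.path` so that `tools/` is importable (this snippet is not part of the commit):

```python
# Hypothetical smoke test: run the generator alone and confirm the key
# really is embedded in the prompt text.
from tools.passkey_retrieval import generate_prompt_landmark

prompt, pass_key = generate_prompt_landmark(n_garbage=1000, seed=0)
assert pass_key in prompt  # the key line sits somewhere inside the filler
print("prompt length (chars):", len(prompt))
print("expected answer:", pass_key)
```

For a cheaper end-to-end run, swapping the hard-coded model name in `main()` for a smaller chat checkpoint (for example, `internlm/internlm-chat-7b` on Hugging Face) should also work, since the evaluation loop only relies on the `model.chat()` interface.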