feat(tools): add passkey retrieval test

pull/426/head
tpoisonooo 2023-10-19 16:18:42 +08:00
parent e611817442
commit d2cc5f61e3
3 changed files with 118 additions and 2 deletions

@@ -1,4 +1,4 @@
This directory provides some tools to assist model training, with the file structure shown below:
This directory provides some tools to assist model training and inference, with the file structure shown below:
```bash
├── transformers # tools for adapting Hugging Face's transformers
@@ -6,6 +6,7 @@
│ ├── modeling_internlm.py # tools for adapting the model
│ ├── tokenization_internlm.py # tools for adapting the tokenizer
│ └── convert2hf.py # tool for converting models to Hugging Face's format
├── passkey_retrieval.py # long-context retrieval test tool
└── tokenizer.py # tool for converting raw data into `bin` and `meta` files
```
@@ -109,3 +110,24 @@ InternLM performance on the GSM8K dataset with and without tools:
| -------- | -------------------- |
| w/o tool | 34.5 |
| w tool | 39.2 |
# passkey_retrieval.py
Tests the model's ability to retrieve a planted detail from inputs of different lengths. The test method comes from [this paper](https://arxiv.org/pdf/2305.16300.pdf). Usage:
```bash
python3 tools/passkey_retrieval.py [--max_tokens <max_token>] [--interval <interval>] [--num_tests <num_tests>]
# Optional arguments:
# --max_tokens <max_token>  maximum input length in tokens (default: 4000)
# --interval <interval>     step size between tested lengths (default: 1000)
# --num_tests <num_tests>   number of inference runs per length (default: 20)
```
Example usage:
```bash
python3 tools/passkey_retrieval.py
```
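The defaults can also be overridden; the flag values below are purely illustrative:
```bash
python3 tools/passkey_retrieval.py --max_tokens 8000 --interval 2000 --num_tests 10
```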
The output is the passkey retrieval accuracy at each tested token length.
```bash
accuracies over tokens {'881': 1.0, '1973': 0.8, '2792': 1.0, '3885': 0.8}
```

@@ -6,6 +6,7 @@ This directory provides some tools for model training with the following file str
│ ├── modeling_internlm.py # tools for adapting model
│ ├── tokenization_internlm.py # tools for adapting tokenizer
│ └── convert2hf.py # tools for adapting models to Hugging Face's format
├── passkey_retrieval.py # tool for testing long-context retrieval
└── tokenizer.py # tools for generating `bin` and `meta` file for raw data
```
@@ -107,3 +108,24 @@ InternLM performance on the GSM8K dataset with and without tools:
| -------- | -------------------- |
| w/o tool | 34.5 |
| w tool | 39.2 |
# passkey_retrieval.py
Tests the model's ability to extract a planted detail from long inputs of different lengths. The test method comes from [this paper](https://arxiv.org/pdf/2305.16300.pdf).
```bash
python3 tools/passkey_retrieval.py [--max_tokens <max_token>] [--interval <interval>] [--num_tests <num_tests>]
# Optional parameters:
# --max_tokens <max_token>  Maximum input length in tokens (default: 4000).
# --interval <interval>     Step size between tested lengths (default: 1000).
# --num_tests <num_tests>   Number of inference runs per length (default: 20).
```
Below is an example of usage:
```bash
python3 tools/passkey_retrieval.py
```
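To sweep a longer range, or to trade fewer inference runs for a noisier estimate, the defaults can be overridden (the values below are purely illustrative):
```bash
python3 tools/passkey_retrieval.py --max_tokens 8000 --interval 2000 --num_tests 10
```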
The output is the passkey retrieval accuracy at each tested token length.
```bash
accuracies over tokens {'881': 1.0, '1973': 0.8, '2792': 1.0, '3885': 0.8}
```

@@ -0,0 +1,72 @@
import argparse

# numpy's random module provides the get_state/set_state calls used below
from numpy import random
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_config():
    parser = argparse.ArgumentParser(description='passkey retrieval test')
    parser.add_argument('--max_tokens', type=int, default=4000, help='maximum context length to evaluate, in tokens')
    parser.add_argument('--interval', type=int, default=1000, help='step size between evaluated context lengths')
    parser.add_argument('--num_tests', type=int, default=20, help='number of repeated tests for each length')
    args = parser.parse_args()
    return args

# copied from https://github.com/dvlab-research/LongLoRA/blob/main/passkey_retrivial.py
def generate_prompt_landmark(n_garbage=60000, seed=666):
    """Generates a long filler text and inserts a passkey at a random position."""
    # save the global RNG state so seeding here does not affect callers
    rnd_state = random.get_state()
    random.seed(seed)
    n_garbage_prefix = random.randint(0, n_garbage)
    n_garbage_suffix = n_garbage - n_garbage_prefix
    task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
    garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
    garbage_inf = " ".join([garbage] * 5000)
    assert len(garbage_inf) >= n_garbage
    garbage_prefix = garbage_inf[:n_garbage_prefix]
    garbage_suffix = garbage_inf[:n_garbage_suffix]
    pass_key = random.randint(1, 50000)
    information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
    final_question = "What is the pass key? The pass key is"
    lines = [
        task_description,
        garbage_prefix,
        information_line,
        garbage_suffix,
        final_question,
    ]
    # restore the caller's RNG state
    random.set_state(rnd_state)
    return "\n".join(lines), str(pass_key)

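# A quick illustrative check of the generator's contract (not executed by the tool):
#   prompt, key = generate_prompt_landmark(n_garbage=1000, seed=0)
#   assert key in prompt and prompt.endswith("The pass key is")
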
def main(args):
    # Load model and tokenizer. The checkpoint id below assumes the Hugging Face
    # repo of the InternLM 20B chat model; swap in another InternLM chat
    # checkpoint if needed.
    tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-20b", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-20b", trust_remote_code=True, device_map="auto")

    total_test_points = args.max_tokens // args.interval
    all_accuracies = {}
    for i in range(total_test_points):
        # A rough characters-per-token ratio to control the amount of filler text
        n_garbage = int(3.75 * (i + 1) * args.interval // 1024 * 1024)
        passed_tests = 0
        total_tokens = 0
        for j in range(args.num_tests):
            prompt, pass_key = generate_prompt_landmark(n_garbage=n_garbage, seed=j)
            response, _ = model.chat(tokenizer, prompt, history=[])
            if pass_key in response:
                passed_tests += 1
            total_tokens += len(tokenizer(prompt).input_ids)
        avg_tokens = total_tokens // args.num_tests
        accuracy = passed_tests / args.num_tests
        print("accuracy on the token length %d is %f" % (avg_tokens, accuracy))
        all_accuracies[str(avg_tokens)] = accuracy
    print("accuracies over tokens", all_accuracies)


if __name__ == "__main__":
    args = parse_config()
    main(args)