mirror of https://github.com/InternLM/InternLM
feat(tools): add passkey retrieval test
parent
e611817442
commit
d2cc5f61e3
|
@ -1,4 +1,4 @@
|
|||
本目录提供辅助模型训练的一些工具,文件结构如下所示:
|
||||
本目录提供辅助模型训练和推理的一些工具,文件结构如下所示:
|
||||
|
||||
```bash
|
||||
├── transformers # 适配hugging face的transformers的一些工具
|
||||
|
@ -6,6 +6,7 @@
|
|||
│ ├── modeling_internlm.py # model适配工具
|
||||
│ ├── tokenization_internlm.py # tokenizer适配工具
|
||||
│ └── convert2hf.py # 模型适配hugging face工具
|
||||
├── passkey_retrieval.py # 长文本检索测试工具
|
||||
└── tokenizer.py # 将原始数据转换成bin和meta文件的工具
|
||||
```
|
||||
|
||||
|
@ -109,3 +110,24 @@ InternLM 在 GSM8K 数据集中带工具和不带工具的性能表现:
|
|||
| -------- | -------------------- |
|
||||
| w/o tool | 34.5 |
|
||||
| w tool | 39.2 |
|
||||
|
||||
# passkey_retrieval.py
|
||||
用于测试模型输入不同长度文本时,提取细节的能力。测试方法来自 [这篇论文](https://arxiv.org/pdf/2305.16300.pdf)。使用方法:
|
||||
|
||||
```bash
|
||||
python3 tools/passkey_retrieval.py [--max_tokens <max_token>] [--interval <interval>] [--num_tests <num_tests>]
|
||||
|
||||
# 可选参数:
|
||||
# --max-tokens <max_token> 最大输入文本长度(默认:4096)。
|
||||
# --interval <interval> 每隔多大长度测试一轮(默认:1024)。
|
||||
# --num_tests <num_tests> 每轮测多少次推理(默认:20)。
|
||||
```
|
||||
|
||||
以下是使用示例:
|
||||
```bash
|
||||
python3 tools/passkey_retrieval.py
|
||||
```
|
||||
输出是不同 token 长度下检索 passkey 的精度。
|
||||
```bash
|
||||
accuries over tokens {'881': 1.0, '1973': 0.8, '2792': 1.0, '3885': 0.8}
|
||||
```
|
|
@ -6,6 +6,7 @@ This directory provide some tools for model training with the following file str
|
|||
│ ├── modeling_internlm.py # tools for adapting model
|
||||
│ └── tokenization_internlm.py # tools for adapting tokenizer
|
||||
│ └── convert2hf.py # tools for adapting models to Hugging Face's format
|
||||
├── passkey_retrieval.py # tools for testing handle long context
|
||||
└── tokenizer.py # tools for generating `bin` and `meta` file for raw data
|
||||
```
|
||||
|
||||
|
@ -107,3 +108,24 @@ InternLM performance in the GSM8K dataset with and without tools:
|
|||
| -------- | -------------------- |
|
||||
| w/o tool | 34.5 |
|
||||
| w tool | 39.2 |
|
||||
|
||||
# passkey_retrieval.py
|
||||
Test the ability to extract details when inputting long text. This test method comes from [this paper](https://arxiv.org/pdf/2305.16300.pdf).
|
||||
|
||||
```bash
|
||||
python3 tools/passkey_retrieval.py [--max_tokens <max_token>] [--interval <interval>] [--num_tests <num_tests>]
|
||||
|
||||
# Optional parameters:
|
||||
# --max-tokens <max_token> Maximum input text length (default: 4096).
|
||||
# --interval <interval> The length of the test (default: 1024).
|
||||
# --num_tests <num_tests> How many times to test inference per round (default: 20).
|
||||
```
|
||||
|
||||
Below is an example of usage:
|
||||
```bash
|
||||
python3 tools/passkey_retrieval.py
|
||||
```
|
||||
The output is the accuracy of retrieving passkey under different token lengths.
|
||||
```bash
|
||||
accuries over tokens {'881': 1.0, '1973': 0.8, '2792': 1.0, '3885': 0.8}
|
||||
```
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
import argparse
|
||||
import random
|
||||
from numpy import random
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import pdb
|
||||
|
||||
def parse_config():
|
||||
parser = argparse.ArgumentParser(description='arg parser')
|
||||
parser.add_argument('--max_tokens', type=int, default=4000, help='maximum token length for evaluation')
|
||||
parser.add_argument('--interval', type=int, default=1000, help='interval for evaluation')
|
||||
parser.add_argument('--num_tests', type=int, default=20, help='number of repeat testing for each length')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
# copy from https://github.com/dvlab-research/LongLoRA/blob/main/passkey_retrivial.py
|
||||
def generate_prompt_landmark(n_garbage=60000, seed=666):
|
||||
"""Generates a text file and inserts an passkey at a random position."""
|
||||
rnd_state = random.get_state()
|
||||
random.seed(seed)
|
||||
n_garbage_prefix = random.randint(0, n_garbage)
|
||||
n_garbage_suffix = n_garbage - n_garbage_prefix
|
||||
|
||||
task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
|
||||
garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
|
||||
garbage_inf = " ".join([garbage] * 5000)
|
||||
assert len(garbage_inf) >= n_garbage
|
||||
garbage_prefix = garbage_inf[:n_garbage_prefix]
|
||||
garbage_suffix = garbage_inf[:n_garbage_suffix]
|
||||
pass_key = random.randint(1, 50000)
|
||||
information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
|
||||
final_question = "What is the pass key? The pass key is"
|
||||
lines = [
|
||||
task_description,
|
||||
garbage_prefix,
|
||||
information_line,
|
||||
garbage_suffix,
|
||||
final_question,
|
||||
]
|
||||
random.set_state(rnd_state)
|
||||
return "\n".join(lines), str(pass_key)
|
||||
|
||||
|
||||
def main(args):
|
||||
# Load model and tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-20b-chat", trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained("internlm/internlm-20b-chat", trust_remote_code=True, device_map="auto")
|
||||
|
||||
total_test_points = args.max_tokens // args.interval
|
||||
all_accuries = {}
|
||||
for i in range(total_test_points):
|
||||
# This is a rough ratio to control the number of texts and tokens
|
||||
n_garbage = int(3.75 * (i + 1) * args.interval // 1024 * 1024)
|
||||
passed_tests = 0
|
||||
total_tokens = 0
|
||||
|
||||
for j in range(args.num_tests):
|
||||
prompt, pass_key = generate_prompt_landmark(n_garbage=n_garbage, seed=j)
|
||||
response, _ = model.chat(tokenizer, prompt, history=[])
|
||||
if pass_key in response:
|
||||
passed_tests += 1
|
||||
total_tokens += len(tokenizer(prompt).input_ids)
|
||||
avg_tokens = total_tokens//args.num_tests
|
||||
accuracy = passed_tests/args.num_tests
|
||||
print("accuracy on the token length %d is %f"%(avg_tokens, accuracy))
|
||||
all_accuries[str(avg_tokens)] = accuracy
|
||||
print("accuries over tokens", all_accuries)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_config()
|
||||
main(args)
|
Loading…
Reference in New Issue