From d2cc5f61e3500c03d4178ea20de8135b025018ad Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Thu, 19 Oct 2023 16:18:42 +0800 Subject: [PATCH] feat(tools): add passkey retrieval test --- tools/README.md | 26 ++++++++++++-- tools/README_EN.md | 22 ++++++++++++ tools/passkey_retrieval.py | 72 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 2 deletions(-) create mode 100644 tools/passkey_retrieval.py diff --git a/tools/README.md b/tools/README.md index 0c78a56..8bb8c0a 100644 --- a/tools/README.md +++ b/tools/README.md @@ -1,11 +1,12 @@ -本目录提供辅助模型训练的一些工具,文件结构如下所示: +本目录提供辅助模型训练和推理的一些工具,文件结构如下所示: ```bash ├── transformers # 适配hugging face的transformers的一些工具 │ ├── configuration_internlm.py # config适配工具 │ ├── modeling_internlm.py # model适配工具 │ ├── tokenization_internlm.py # tokenizer适配工具 -│ └── convert2hf.py # 模型适配hugging face工具 +│ └── convert2hf.py # 模型适配hugging face工具 +├── passkey_retrieval.py # 长文本检索测试工具 └── tokenizer.py # 将原始数据转换成bin和meta文件的工具 ``` @@ -109,3 +110,24 @@ InternLM 在 GSM8K 数据集中带工具和不带工具的性能表现: | -------- | -------------------- | | w/o tool | 34.5 | | w tool | 39.2 | + +# passkey_retrieval.py +用于测试模型输入不同长度文本时,提取细节的能力。测试方法来自 [这篇论文](https://arxiv.org/pdf/2305.16300.pdf)。使用方法: + +```bash +python3 tools/passkey_retrieval.py [--max_tokens ] [--interval ] [--num_tests ] + +# 可选参数: +# --max-tokens 最大输入文本长度(默认:4096)。 +# --interval 每隔多大长度测试一轮(默认:1024)。 +# --num_tests 每轮测多少次推理(默认:20)。 +``` + +以下是使用示例: +```bash +python3 tools/passkey_retrieval.py +``` +输出是不同 token 长度下检索 passkey 的精度。 +```bash +accuries over tokens {'881': 1.0, '1973': 0.8, '2792': 1.0, '3885': 0.8} +``` \ No newline at end of file diff --git a/tools/README_EN.md b/tools/README_EN.md index 3105146..190fbf2 100644 --- a/tools/README_EN.md +++ b/tools/README_EN.md @@ -6,6 +6,7 @@ This directory provide some tools for model training with the following file str │ ├── modeling_internlm.py # tools for adapting model │ └── tokenization_internlm.py # tools for adapting tokenizer │ └── 
convert2hf.py # tools for adapting models to Hugging Face's format +├── passkey_retrieval.py # tools for testing handle long context └── tokenizer.py # tools for generating `bin` and `meta` file for raw data ``` @@ -107,3 +108,24 @@ InternLM performance in the GSM8K dataset with and without tools: | -------- | -------------------- | | w/o tool | 34.5 | | w tool | 39.2 | + +# passkey_retrieval.py +Test the ability to extract details when inputting long text. This test method comes from [this paper](https://arxiv.org/pdf/2305.16300.pdf). + +```bash +python3 tools/passkey_retrieval.py [--max_tokens ] [--interval ] [--num_tests ] + +# Optional parameters: +# --max-tokens Maximum input text length (default: 4096). +# --interval The length of the test (default: 1024). +# --num_tests How many times to test inference per round (default: 20). +``` + +Below is an example of usage: +```bash +python3 tools/passkey_retrieval.py +``` +The output is the accuracy of retrieving passkey under different token lengths. 
import argparse

# NOTE: numpy's random module (get_state/seed/randint/set_state) is used
# deliberately to mirror the upstream LongLoRA script this was copied from.
# The previous stdlib `import random` was dead code (immediately shadowed)
# and `import pdb` was leftover debug cruft; both removed.
from numpy import random


def parse_config():
    """Parse command-line options for the passkey retrieval benchmark.

    Defaults (4096 / 1024 / 20) match the values documented in
    tools/README.md and tools/README_EN.md.
    """
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--max_tokens', type=int, default=4096,
                        help='maximum token length for evaluation')
    parser.add_argument('--interval', type=int, default=1024,
                        help='interval for evaluation')
    parser.add_argument('--num_tests', type=int, default=20,
                        help='number of repeat testing for each length')

    args = parser.parse_args()
    return args


# copy from https://github.com/dvlab-research/LongLoRA/blob/main/passkey_retrivial.py
def generate_prompt_landmark(n_garbage=60000, seed=666):
    """Generate a long filler prompt with a passkey hidden at a random spot.

    Args:
        n_garbage: total number of filler characters surrounding the passkey.
        seed: RNG seed so each (length, test-index) pair is reproducible.

    Returns:
        Tuple ``(prompt, passkey)`` where ``passkey`` is the answer string
        the model is expected to retrieve from ``prompt``.
    """
    # Save and later restore the global numpy RNG state so this helper does
    # not perturb any other randomness in the caller.
    rnd_state = random.get_state()
    random.seed(seed)
    # Split the filler text into a random-length prefix and the remainder.
    n_garbage_prefix = random.randint(0, n_garbage)
    n_garbage_suffix = n_garbage - n_garbage_prefix

    task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
    garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
    garbage_inf = " ".join([garbage] * 5000)
    assert len(garbage_inf) >= n_garbage
    garbage_prefix = garbage_inf[:n_garbage_prefix]
    garbage_suffix = garbage_inf[:n_garbage_suffix]
    pass_key = random.randint(1, 50000)
    information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
    final_question = "What is the pass key? The pass key is"
    lines = [
        task_description,
        garbage_prefix,
        information_line,
        garbage_suffix,
        final_question,
    ]
    random.set_state(rnd_state)
    return "\n".join(lines), str(pass_key)


def main(args):
    """Measure passkey-retrieval accuracy at increasing context lengths."""
    # Import the heavy third-party dependency lazily so the pure prompt
    # generator above stays importable (and unit-testable) without the
    # full transformers/torch stack installed.
    from transformers import AutoTokenizer, AutoModelForCausalLM

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-20b-chat", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained("internlm/internlm-20b-chat", trust_remote_code=True, device_map="auto")

    total_test_points = args.max_tokens // args.interval
    all_accuracies = {}
    for i in range(total_test_points):
        # This is a rough ratio to control the number of texts and tokens:
        # ~3.75 filler characters per target token, rounded to a 1024 multiple.
        n_garbage = int(3.75 * (i + 1) * args.interval // 1024 * 1024)
        passed_tests = 0
        total_tokens = 0

        for j in range(args.num_tests):
            # Seed with the test index so every run uses the same prompts.
            prompt, pass_key = generate_prompt_landmark(n_garbage=n_garbage, seed=j)
            response, _ = model.chat(tokenizer, prompt, history=[])
            if pass_key in response:
                passed_tests += 1
            total_tokens += len(tokenizer(prompt).input_ids)
        avg_tokens = total_tokens // args.num_tests
        accuracy = passed_tests / args.num_tests
        print("accuracy on the token length %d is %f" % (avg_tokens, accuracy))
        all_accuracies[str(avg_tokens)] = accuracy
    # NOTE(review): "accuries" is a typo, but it is kept verbatim because the
    # README examples document this exact output line.
    print("accuries over tokens", all_accuracies)


if __name__ == "__main__":
    args = parse_config()
    main(args)