mirror of https://github.com/THUDM/ChatGLM2-6B
duzx16
1 year ago
3 changed files with 71 additions and 1 deletion
@@ -0,0 +1,10 @@

First, download the preprocessed C-Eval dataset from [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/e84444333b6d434ea7b0) and extract it into the `evaluation` directory. Then run:
```shell
cd evaluation
python evaluate_ceval.py
```
This script runs predictions on the C-Eval validation set and prints the accuracy. To evaluate on the test set instead, change `./CEval/val/**/*.jsonl` in the code to `./CEval/test/**/*.jsonl`, then save the predictions in the format required by C-Eval and submit them on the [official website](https://cevalbenchmark.com/).
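If it helps, here is a minimal sketch of that submission step, assuming C-Eval's JSON layout of subject name → question id → choice letter; the `record_prediction` helper and the file name `submission.json` are illustrative assumptions, not part of this commit, and the exact schema should be verified against the C-Eval documentation:

```python
import json
from collections import defaultdict

# Maps subject name -> {question id (as a string) -> predicted choice letter}.
submission = defaultdict(dict)

def record_prediction(subject, question_id, choice_letter):
    # Hypothetical helper: call this inside the test-set evaluation loop,
    # deriving `subject` from the .jsonl file name.
    submission[subject][str(question_id)] = choice_letter

record_prediction("computer_network", 0, "A")  # illustrative only

# Write the file to upload at https://cevalbenchmark.com/
with open("submission.json", "w", encoding="utf-8") as f:
    json.dump(submission, f, ensure_ascii=False, indent=4)
```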
The reported results were obtained with an internal parallel evaluation framework, so the numbers may fluctuate slightly.
@@ -0,0 +1,60 @@

The new `evaluate_ceval.py` scores each question in two passes: a greedy generation pass that lets the model produce its reasoning, followed by a forced-choice pass that appends an extraction prompt and compares the logits the model assigns to the tokens A, B, C, and D.
```python
import glob
import json

import torch
import torch.utils.data
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).bfloat16().cuda()

# Token ids of the four answer letters, used to restrict scoring to A/B/C/D.
choices = ["A", "B", "C", "D"]
choice_tokens = [tokenizer.encode(choice, add_special_tokens=False)[0] for choice in choices]


def build_prompt(text):
    # Wrap the question in the single-round chat template used by ChatGLM2.
    return "[Round {}]\n\n问:{}\n\n答:".format(1, text)


# Appended after the model's free-form answer to elicit a final choice
# ("In summary, the correct option among ABCD is:").
extraction_prompt = '综上所述,ABCD中正确的选项是:'

accuracy_dict, count_dict = {}, {}
with torch.no_grad():
    for entry in glob.glob("./CEval/val/**/*.jsonl", recursive=True):
        # Each line of a .jsonl file is one question with its gold label.
        dataset = []
        with open(entry, encoding='utf-8') as file:
            for line in file:
                dataset.append(json.loads(line))
        correct = 0
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
        for batch in tqdm(dataloader):
            texts = batch["inputs_pretokenized"]
            queries = [build_prompt(query) for query in texts]

            # Pass 1: greedy generation of the model's reasoning.
            inputs = tokenizer(queries, padding=True, return_tensors="pt",
                               truncation=True, max_length=2048).to('cuda')
            outputs = model.generate(**inputs, do_sample=False, max_new_tokens=512)
            intermediate_outputs = []
            for idx in range(len(outputs)):
                # Drop the prompt tokens, keeping only the generated continuation.
                output = outputs.tolist()[idx][len(inputs["input_ids"][idx]):]
                response = tokenizer.decode(output)
                intermediate_outputs.append(response)

            # Pass 2: append the extraction prompt and compare the logits the
            # model assigns to the tokens A/B/C/D at the final position.
            # (return_last_logit is specific to ChatGLM2's remote code.)
            answer_texts = [text + intermediate + "\n" + extraction_prompt
                            for text, intermediate in zip(texts, intermediate_outputs)]
            input_tokens = [build_prompt(answer_text) for answer_text in answer_texts]
            inputs = tokenizer(input_tokens, padding=True, return_tensors="pt",
                               truncation=True, max_length=2048).to('cuda')
            outputs = model(**inputs, return_last_logit=True)
            logits = outputs.logits[:, -1]
            logits = logits[:, choice_tokens]  # keep only the four answer letters
            preds = logits.argmax(dim=-1)
            correct += (preds.cpu() == batch["label"]).sum().item()
        accuracy = correct / len(dataset)
        print(entry, accuracy)
        accuracy_dict[entry] = accuracy
        count_dict[entry] = len(dataset)

# Micro-average over all subjects, weighted by question count.
acc_total, count_total = 0.0, 0
for key in accuracy_dict:
    acc_total += accuracy_dict[key] * count_dict[key]
    count_total += count_dict[key]
print(acc_total / count_total)
```
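From the keys the script reads (`inputs_pretokenized` and `label`), each line of the processed .jsonl files is presumably shaped like the sample below, with `label` giving the index of the correct answer in `["A", "B", "C", "D"]`; the question text itself is elided here:

```json
{"inputs_pretokenized": "…question text followed by options A–D…", "label": 0}
```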