mirror of https://github.com/hpcaitech/ColossalAI
parent
633bac2f58
commit
c4709d34cf
|
@ -373,6 +373,13 @@ Thanks so much to all of our amazing contributors!
- Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU
- Keep the running speed sufficiently high

| Model Pair    | Alpaca-7B ⚔ Coati-7B | Coati-7B ⚔ Alpaca-7B |
| :-----------: | :------------------: | :------------------: |
| Better Cases  | 38 ⚔ **41**          | **45** ⚔ 33          |
| Win Rate      | 48% ⚔ **52%**        | **58%** ⚔ 42%        |
| Average Score | 7.06 ⚔ **7.13**      | **7.31** ⚔ 6.82      |

- Our Coati-7B model performs better than Alpaca-7B when GPT-4 is used to evaluate model performance. The Coati-7B model we evaluated is an older version trained a few weeks ago; a new version is around the corner.

## Authors

Coati is developed by the ColossalAI Team:

@ -0,0 +1,181 @@
# Evaluation

In this directory we introduce how you can evaluate your model with GPT-4.

## Evaluation Pipeline

The evaluation process consists of two steps.

1. Generate answers from different models: use `generate_gpt35_answers.py` to generate answers from GPT-3.5 and `generate_answers.py` to generate answers from your own models.
2. Evaluate models using GPT-4: use `evaluate.py` to score the model answers with GPT-4.

### Generate Answers

To generate answers, you should first format [FastChat's](https://github.com/lm-sys/FastChat/blob/main/fastchat/eval/table/question.jsonl) `question.jsonl` file. We do this formatting because we would like to add more questions later, and the pipeline for generating new questions may follow that of Self-Instruct and Stanford Alpaca. An example script is given as follows, and a sample input record is shown after it.

```shell
python format_questions.py \
--questions_path "path to FastChat's question.jsonl" \
--save_path "path to the formatted file"
```
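
For reference, each line of FastChat's `question.jsonl` is a JSON object with `question_id`, `text` and `category` fields, which `format_questions.py` maps to our format (`text` becomes `instruction`, `question_id` becomes `id`, `category` is kept). The question below is a made-up illustration:

```json
{"question_id": 1, "text": "How can I improve my time management skills?", "category": "generic"}
```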

In `generate_answers.py`, the model generates answers in batches, and each GPU process runs inference on its own shard of the given questions. Once all GPU processes have generated their answers, `merge.py` merges the per-rank shards (`{model_name}_answers_rank{i}.json`) into a single answer file (`{model_name}_answers.json`). Finally, the script removes the answer shards. An example script is given as follows.

```shell
device_number=number of your devices
model_name="name of your model"
model_path="path to your model"
dataset="path to the question dataset"
answer_path="path to save the model answers"

torchrun --standalone --nproc_per_node=$device_number generate_answers.py \
--model 'llama' \
--strategy ddp \
--model_path $model_path \
--model_name $model_name \
--dataset $dataset \
--batch_size 8 \
--max_datasets_size 80 \
--answer_path $answer_path \
--max_length 512

python merge.py \
--model_name $model_name \
--shards $device_number \
--answer_path $answer_path

for (( i=0; i<device_number; i++ )) do
rm -rf "${answer_path}/${model_name}_answers_rank${i}.json"
done
```

`generate_gpt35_answers.py` generates answers from GPT-3.5. An example script is given as follows.

```shell
python generate_gpt35_answers.py \
--dataset "path to the question dataset" \
--answer_path "path to answer folder" \
--num_workers 4 \
--openai_key "your openai key" \
--max_tokens 512
```

### Evaluate Answers

In `evaluate.py`, GPT-4 helps review and score the answers of two different models. Here `Model 1` refers to the first model you specify in `--answer_file_list` and `Model 2` refers to the second model. The script prints several metrics and outputs the corresponding JSON files.

The metrics include:

- `Invalid Count`: The number of reviews where the program fails to parse the score pair.
- `Better Count`: The number of reviews where Model 2 receives a higher score.
- `Worse Count`: The number of reviews where Model 2 receives a lower score.
- `Tie Count`: The number of reviews where the two models receive the same score.
- `Win Rate of Model 2`: Win rate of Model 2, computed as `Better Count / (Better Count + Worse Count + Tie Count)`, i.e. invalid reviews are excluded.
- `Model 1 Average Score`: Average score of Model 1 over the valid reviews.
- `Model 2 Average Score`: Average score of Model 2 over the valid reviews.

Besides the `review` file, which contains all reviews, and the `results` file, the script also outputs `invalid`, `better`, `worse` and `tie` JSON files that contain only the corresponding reviews.

```shell
python evaluate.py \
--answer_file_list "path to answers of model 1" "path to answers of model 2" \
--prompt_file "path to prompt file" \
--reviewer_file "path to reviewer file" \
--output_folder "path to output folder" \
--openai_key "your openai key" \
--model "the gpt model" \
--num_workers 8 \
--max_tokens 512
```

## Results

We compare our model with Alpaca and Vicuna. The results are shown below. Note that the better cases do not add up to 80 because there are reviews the program cannot parse to obtain the score pair. Our Coati-7B model performs better than Alpaca-7B. The Coati-7B model we evaluated is an older version trained a few weeks ago; a new version is around the corner.

| Model Pair    | Alpaca-7B ⚔ Coati-7B | Coati-7B ⚔ Alpaca-7B |
| :-----------: | :------------------: | :------------------: |
| Better Cases  | 38 ⚔ **41**          | **45** ⚔ 33          |
| Win Rate      | 48% ⚔ **52%**        | **58%** ⚔ 42%        |
| Average Score | 7.06 ⚔ **7.13**      | **7.31** ⚔ 6.82      |

We would like to mention that evaluating model answers with GPT-3.5 is not reliable: GPT-3.5 tends to give a higher score to the second answer (`{answer2}` in the prompt). In our evaluation, which uses GPT-4, we still swap the two model answers. As can be seen from the table, GPT-4 produces consistent results and is less biased than GPT-3.5.

## Data Format

### Questions

We store questions in `questions.json`. The JSON file contains one list. Each element in the list is a question record.

A question record has the following fields (an example follows the list):

* `category` (str): The category of the question.
* `instruction` (str): The question.
* `input` (str): This is empty if you only use [FastChat's](https://github.com/lm-sys/FastChat/blob/main/fastchat/eval/table/question.jsonl) questions.
* `output` (str): This is empty.
* `id` (int): The question id.

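Continuing the hypothetical question from the formatting example above, a question record looks like this:

```json
{
    "category": "generic",
    "instruction": "How can I improve my time management skills?",
    "input": "",
    "output": "",
    "id": 1
}
```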

### Answers

We store model answers in `{model_name}_answers.json`. The JSON file contains one list. Each element in the list is an answer record to one question.

An answer record has the following fields (an example follows the list):

* `category` (str): The category of the question.
* `instruction` (str): The question.
* `input` (str): This is empty if you only use [FastChat's](https://github.com/lm-sys/FastChat/blob/main/fastchat/eval/table/question.jsonl) questions.
* `output` (str): The answer to the question.
* `id` (int): The question id.

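A hypothetical answer record (the answer text is made up) might look like this:

```json
{
    "category": "generic",
    "instruction": "How can I improve my time management skills?",
    "input": "",
    "output": "Start by keeping a daily to-do list and ranking tasks by priority...",
    "id": 1
}
```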

### Results

We store evaluation results in `results.json`. The JSON file contains one dictionary. The keys of the dictionary are formatted as `{model 1}_vs_{model 2}` and each value is another dictionary that contains metrics of the evaluation.

The value has the following fields (an example follows the list):

* `model` (list): The names of the two models.
* `better` (int): The number of reviews where Model 2 receives a higher score.
* `worse` (int): The number of reviews where Model 2 receives a lower score.
* `tie` (int): The number of reviews where the two models receive the same score.
* `win_rate` (float): Win rate of Model 2.
* `score` (list): Average scores of the two models.

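Assuming the two answer files are named `alpaca_answers.json` and `coati_answers.json` (so the key becomes `alpaca_vs_coati`), a sketch of an entry with made-up numbers would be:

```json
{
    "alpaca_vs_coati": {
        "model": ["alpaca", "coati"],
        "better": 40,
        "worse": 32,
        "tie": 8,
        "win_rate": 0.5,
        "score": [7.0, 7.2]
    }
}
```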

### Better, Worse, Tie, Invalid, Review

To help compare the model answers, we store JSON files whose names end with `_better`, `_worse`, `_tie`, `_invalid` or `_review`. Each JSON file contains one list, and each element in the list is a record of a better, worse, tie or invalid case (the `_review` file contains all cases).

A record has the following fields (an example follows the list):

* `review_id` (str): Random UUID, not in use.
* `id` (int): The question id.
* `reviewer_id` (int): A unique ID for a reviewer. Different reviewer IDs use different prompts.
* `metadata` (dict): It is empty.
* `review` (str): GPT-4's review.
* `score` (list): The scores of the two models.

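A hypothetical record (the review text, UUID and scores are made up):

```json
{
    "review_id": "3KqCizqsEGiMLEjnMNMJZ8",
    "id": 1,
    "reviewer_id": 1,
    "metadata": {},
    "review": "Assistant 1 gives practical advice... I give Assistant 1 a score of 7 and Assistant 2 a score of 8.",
    "score": [7.0, 8.0]
}
```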

### Prompts

The data format is the same as [FastChat's](https://github.com/lm-sys/FastChat/blob/main/fastchat/eval/table/prompt.jsonl) prompts.

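`evaluate.py` reads the fields `prompt_id`, `system_prompt`, `prompt_template` and `defaults` from each line; the template is filled with `{question}`, `{answer_1}`, `{answer_2}` and the values in `defaults`. The line below is an abbreviated illustration, not the actual FastChat prompt:

```json
{"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[Assistant 2]\n{answer_2}\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Please rate the helpfulness, relevance, accuracy and level of detail of the two answers."}}
```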

### Reviewer

The data format is the same as [FastChat's](https://github.com/lm-sys/FastChat/blob/main/fastchat/eval/table/reviewer.jsonl) reviewers.

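Note that `evaluate.py` matches a reviewer to a question by its `category` field and uses the reviewer's `prompt_id` to pick a prompt; other fields (such as `reviewer_id` or `metadata`) may be present as in FastChat but are not read. An illustrative line:

```json
{"category": "general", "prompt_id": 1}
```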

## Plan

- [ ] Extend the questions

## Citations

```bibtex
@misc{vicuna2023,
    title = {Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90\%* ChatGPT Quality},
    url = {https://vicuna.lmsys.org},
    author = {Chiang, Wei-Lin and Li, Zhuohan and Lin, Zi and Sheng, Ying and Wu, Zhanghao and Zhang, Hao and Zheng, Lianmin and Zhuang, Siyuan and Zhuang, Yonghao and Gonzalez, Joseph E. and Stoica, Ion and Xing, Eric P.},
    month = {March},
    year = {2023}
}
```
@ -0,0 +1,256 @@
|
|||
# Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/eval/eval_gpt_review.py
|
||||
# Copyright 2023 LM-SYS@FastChat
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import concurrent.futures
|
||||
|
||||
import openai
|
||||
import tqdm
|
||||
import shortuuid
|
||||
import logging
|
||||
|
||||
from utils import jload, jdump, get_json_list
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_API_RETRY = 3
|
||||
|
||||
|
||||
def get_eval(sys_prompt, user_prompt: str, answer_id: int, max_tokens: int, model: str):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for _ in range(MAX_API_RETRY):
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
model=model,
|
||||
messages=[{
|
||||
'role': 'system',
|
||||
'content': sys_prompt
|
||||
}, {
|
||||
'role': 'user',
|
||||
'content': user_prompt,
|
||||
}],
|
||||
temperature=0.2,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
review = response['choices'][0]['message']['content']
|
||||
return {"review": review, 'id': answer_id}
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
time.sleep(1)
|
||||
logger.error(f' Review {answer_id} failed after {MAX_API_RETRY} retries.')
|
||||
return 'error'
|
||||
|
||||
|
||||
def parse_score(review):
|
||||
try:
|
||||
pattern = re.compile('([0-9]|10) out of 10')
|
||||
sp = re.findall(pattern, review)
|
||||
if len(re.findall(pattern, review)) == 2:
|
||||
return [float(sp[0]), float(sp[1])]
|
||||
|
||||
pattern = re.compile('a score of ([0-9]|10)')
|
||||
sp = re.findall(pattern, review)
|
||||
if len(re.findall(pattern, review)) == 2:
|
||||
return [float(sp[0]), float(sp[1])]
|
||||
|
||||
pattern = re.compile('([0-9]|10)/10')
|
||||
sp = re.findall(pattern, review)
|
||||
if len(re.findall(pattern, review)) == 2:
|
||||
return [float(sp[0]), float(sp[1])]
|
||||
|
||||
score_pair = review.split('\n')[0]
|
||||
score_pair = score_pair.replace(',', ' ')
|
||||
sp = score_pair.split(' ')
|
||||
if len(sp) == 2:
|
||||
return [float(sp[0]), float(sp[1])]
|
||||
else:
|
||||
raise Exception('Invalid score pair.')
|
||||
except Exception as e:
|
||||
return [-1, -1]
|
||||
|
||||
|
||||
def gen_prompt(reviewer_jsons, prompt_jsons, cat, ques, ans1, ans2):
|
||||
reviewer_idx = 0
|
||||
for idx, reviewer in enumerate(reviewer_jsons):
|
||||
if reviewer['category'] == cat:
|
||||
reviewer_idx = idx
|
||||
break
|
||||
prompt_id = reviewer_jsons[reviewer_idx]['prompt_id']
|
||||
prompt_json = prompt_jsons[prompt_id-1]
|
||||
assert prompt_json['prompt_id'] == prompt_id
|
||||
|
||||
sys_prompt = prompt_json['system_prompt']
|
||||
prompt_template = prompt_json['prompt_template']
|
||||
defaults = prompt_json['defaults']
|
||||
prompt = prompt_template.format(
|
||||
question=ques, answer_1=ans1, answer_2=ans2, **defaults)
|
||||
|
||||
return sys_prompt, prompt, reviewer_idx+1
|
||||
|
||||
|
||||
def evaluate(args):
|
||||
answer1_jsons = jload(args.answer_file_list[0])
|
||||
answer2_jsons = jload(args.answer_file_list[1])
|
||||
reviewer_jsons = get_json_list(args.reviewer_file)
|
||||
prompt_jsons = get_json_list(args.prompt_file)
|
||||
|
||||
assert len(answer1_jsons) == len(answer2_jsons)
|
||||
|
||||
handles = []
|
||||
review_jsons = []
|
||||
|
||||
total_len = len(answer1_jsons)
|
||||
question_idx_list = list(range(total_len))
|
||||
|
||||
logger.info(
|
||||
f' Total number of answers: {len(answer2_jsons)}.')
|
||||
|
||||
reviews = []
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_workers) as executor:
|
||||
futures = []
|
||||
for i in question_idx_list:
|
||||
assert answer1_jsons[i]['id'] == answer2_jsons[i]['id']
|
||||
answer_id = answer1_jsons[i]['id']
|
||||
|
||||
ques = answer1_jsons[i]['instruction'] if answer1_jsons[i]['input'] == "" else answer1_jsons[i]['instruction'] + \
|
||||
" " + answer1_jsons[i]['input']
|
||||
cat = answer1_jsons[i]['category']
|
||||
ans1 = answer1_jsons[i]['output']
|
||||
ans2 = answer2_jsons[i]['output']
|
||||
|
||||
sys_prompt, prompt, reviewer_id = gen_prompt(
|
||||
reviewer_jsons, prompt_jsons, cat, ques, ans1, ans2)
|
||||
|
||||
review_id = shortuuid.uuid()
|
||||
review_jsons.append({
|
||||
'review_id': review_id,
|
||||
'id': answer_id,
|
||||
'reviewer_id': reviewer_id,
|
||||
'metadata': {}
|
||||
})
|
||||
|
||||
future = executor.submit(
|
||||
get_eval, sys_prompt, prompt, answer_id, args.max_tokens, args.model)
|
||||
futures.append(future)
|
||||
|
||||
for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
|
||||
reviews.append(future.result())
|
||||
|
||||
reviews.sort(key=lambda x: x['id'])
|
||||
review_jsons.sort(key=lambda x: x['id'])
|
||||
|
||||
ans1_score = 0
|
||||
ans2_score = 0
|
||||
better_count = 0
|
||||
worse_count = 0
|
||||
tie_count = 0
|
||||
invalid_count = 0
|
||||
|
||||
better_file = []
|
||||
worse_file = []
|
||||
tie_file = []
|
||||
invalid_file = []
|
||||
output_review_file = []
|
||||
|
||||
for idx, review in enumerate(reviews):
|
||||
scores = parse_score(review['review'])
|
||||
review_jsons[idx]['review'] = review['review']
|
||||
review_jsons[idx]['score'] = scores
|
||||
|
||||
if scores[0] == -1 and scores[1] == -1:
|
||||
invalid_count += 1
|
||||
invalid_file.append(review_jsons[idx])
|
||||
logger.info(f' Invalid score pair: {review_jsons[idx]["id"]}.')
|
||||
else:
|
||||
if scores[0] > scores[1]:
|
||||
worse_count += 1
|
||||
worse_file.append(review_jsons[idx])
|
||||
elif scores[0] < scores[1]:
|
||||
better_count += 1
|
||||
better_file.append(review_jsons[idx])
|
||||
else:
|
||||
tie_count += 1
|
||||
tie_file.append(review_jsons[idx])
|
||||
ans1_score += scores[0]
|
||||
ans2_score += scores[1]
|
||||
|
||||
output_review_file.append(review_jsons[idx])
|
||||
|
||||
better_file.sort(key=lambda x: x['id'])
|
||||
worse_file.sort(key=lambda x: x['id'])
|
||||
tie_file.sort(key=lambda x: x['id'])
|
||||
invalid_file.sort(key=lambda x: x['id'])
|
||||
output_review_file.sort(key=lambda x: x['id'])
|
||||
|
||||
name1 = os.path.basename(args.answer_file_list[0]).split("_answers")[0]
|
||||
name2 = os.path.basename(args.answer_file_list[1]).split("_answers")[0]
|
||||
prefix = f"{name1}_vs_{name2}"
|
||||
|
||||
jdump(better_file, os.path.join(
|
||||
args.output_folder, prefix, f"{prefix}_better.json"))
|
||||
jdump(worse_file, os.path.join(
|
||||
args.output_folder, prefix, f"{prefix}_worse.json"))
|
||||
jdump(tie_file, os.path.join(
|
||||
args.output_folder, prefix, f"{prefix}_tie.json"))
|
||||
jdump(invalid_file, os.path.join(
|
||||
args.output_folder, prefix, f"{prefix}_invalid.json"))
|
||||
jdump(output_review_file, os.path.join(
|
||||
args.output_folder, prefix, f"{prefix}_review.json"))
|
||||
|
||||
if os.path.exists(os.path.join(args.output_folder, "results.json")):
|
||||
results = jload(os.path.join(args.output_folder, "results.json"))
|
||||
else:
|
||||
results = {}
|
||||
results[prefix] = {'model': [name1, name2], 'better': better_count, 'worse': worse_count, 'tie': tie_count, 'win_rate': better_count /
|
||||
(len(reviews)-invalid_count), 'score': [ans1_score/(len(reviews)-invalid_count), ans2_score/(len(reviews)-invalid_count)]}
|
||||
jdump(results, os.path.join(args.output_folder, "results.json"))
|
||||
|
||||
logger.info(f' Total {invalid_count} invalid score pair(s).')
|
||||
logger.info(f' Model {name2} has {better_count} better answer(s).')
|
||||
logger.info(f' Model {name2} has {worse_count} worse answer(s).')
|
||||
logger.info(f' {tie_count} answer(s) play(s) to a tie.')
|
||||
logger.info(
|
||||
f' Win rate of model {name2}: {better_count/(len(reviews)-invalid_count):.2f}')
|
||||
logger.info(
|
||||
f' Model {name1} average score: {ans1_score/(len(reviews)-invalid_count):.2f}')
|
||||
logger.info(
|
||||
f' Model {name2} average score: {ans2_score/(len(reviews)-invalid_count):.2f}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Model evaluation.')
|
||||
parser.add_argument('--answer_file_list', nargs='+', default=[])
|
||||
parser.add_argument('--prompt_file')
|
||||
parser.add_argument('--reviewer_file')
|
||||
parser.add_argument('--output_folder', type=str, default="./output")
|
||||
parser.add_argument('--openai_key', type=str, default=None)
|
||||
parser.add_argument('--model', type=str, default="gpt-4")
|
||||
parser.add_argument('--num_workers', type=int, default=8)
|
||||
parser.add_argument('--max_tokens', type=int, default=512,
|
||||
help='maximum number of tokens produced in the output')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.openai_key is not None:
|
||||
os.environ["OPENAI_API_KEY"] = args.openai_key
|
||||
openai.api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
evaluate(args)
|
|
@ -0,0 +1,9 @@
|
|||
python evaluate.py \
|
||||
--answer_file_list "path to answers of model 1" "path to answers of model 2" \
|
||||
--prompt_file "path to prompt file" \
|
||||
--reviewer_file "path to reviewer file" \
|
||||
--output_folder "path to output folder" \
|
||||
--openai_key "your openai key" \
|
||||
--model "gpt-4" \
|
||||
--num_workers 8 \
|
||||
--max_tokens 512
|
|
@ -0,0 +1,31 @@
|
|||
import argparse
|
||||
import os
|
||||
import json
|
||||
import copy
|
||||
|
||||
from utils import jdump, get_json_list
|
||||
|
||||
|
||||
def format_questions(args):
|
||||
questions = get_json_list(args.questions_path)
|
||||
keys=questions[0].keys()
|
||||
|
||||
formatted_questions=copy.deepcopy(questions)
|
||||
for i in range(len(formatted_questions)):
|
||||
formatted_questions[i]['instruction']=questions[i]['text']
|
||||
formatted_questions[i]['input']=""
|
||||
formatted_questions[i]['output']=""
|
||||
formatted_questions[i]['id']=questions[i]['question_id']
|
||||
for key in keys:
|
||||
if key=="category":
|
||||
continue
|
||||
del formatted_questions[i][key]
|
||||
|
||||
jdump(formatted_questions, args.save_path)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--questions_path', type=str, default='table/question.jsonl')
|
||||
parser.add_argument('--save_path', type=str, default="table/questions.json")
|
||||
args = parser.parse_args()
|
||||
format_questions(args)
|
|
@ -0,0 +1,3 @@
|
|||
python format_questions.py \
|
||||
--questions_path "path to FastChat's question.jsonl" \
|
||||
--save_path "path to the formatted file"
|
|
@ -0,0 +1,173 @@
|
|||
import argparse
|
||||
import os
|
||||
import random
|
||||
import copy
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
|
||||
from coati.models.bloom import BLOOMActor
|
||||
from coati.models.gpt import GPTActor
|
||||
from coati.models.opt import OPTActor
|
||||
from coati.models.roberta import RoBERTaActor
|
||||
from coati.models.llama import LlamaActor
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||
from transformers import AutoTokenizer, RobertaTokenizer
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from utils import jload, jdump, is_rank_0
|
||||
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
PROMPT_DICT = {
|
||||
"prompt_input":
|
||||
("Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
|
||||
"prompt_no_input": ("Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Response:"),
|
||||
}
|
||||
|
||||
|
||||
def generate(args):
|
||||
# torch.cuda.set_per_process_memory_fraction(0.4)
|
||||
if args.strategy == 'naive':
|
||||
strategy = NaiveStrategy()
|
||||
elif args.strategy == 'ddp':
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == 'colossalai_gemini':
|
||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
|
||||
elif args.strategy == 'colossalai_zero2':
|
||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
||||
elif args.strategy == 'colossalai_zero2_cpu':
|
||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
|
||||
with strategy.model_init_context():
|
||||
if args.model == 'gpt2':
|
||||
actor = GPTActor(pretrained=args.model_path).to(
|
||||
torch.cuda.current_device())
|
||||
elif args.model == 'bloom':
|
||||
actor = BLOOMActor(pretrained=args.model_path).to(
|
||||
torch.cuda.current_device())
|
||||
elif args.model == 'opt':
|
||||
actor = OPTActor(pretrained=args.model_path).to(
|
||||
torch.cuda.current_device())
|
||||
elif args.model == 'roberta':
|
||||
actor = RoBERTaActor(pretrained=args.model_path).to(
|
||||
torch.cuda.current_device())
|
||||
elif args.model == 'llama':
|
||||
actor = LlamaActor(pretrained=args.model_path).to(
|
||||
torch.float16).to(torch.cuda.current_device())
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
|
||||
elif args.model == 'roberta':
|
||||
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
|
||||
elif args.model == 'llama':
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_path,
|
||||
padding_side="right",
|
||||
use_fast=False,
|
||||
)
|
||||
tokenizer.eos_token = '</s>'
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
# load the full dataset, then optionally subsample it
questions = jload(args.dataset)
if args.max_datasets_size is not None:
    questions = random.sample(questions, args.max_datasets_size)
    if is_rank_0():
        logger.info(f"Limiting dataset to {args.max_datasets_size} examples.")
# shard the questions across ranks so each process answers its own slice
questions = questions[rank:args.max_datasets_size:world_size]
|
||||
|
||||
answers = copy.deepcopy(questions)
|
||||
|
||||
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
|
||||
sources = [
|
||||
prompt_input.format_map(example) if example.get(
|
||||
"input", "") != "" else prompt_no_input.format_map(example)
|
||||
for example in questions
|
||||
]
|
||||
|
||||
if is_rank_0():
|
||||
logger.info("Tokenizing inputs... This may take some time...")
|
||||
|
||||
input_ids_list = []
|
||||
|
||||
for string in sources:
|
||||
input_ids = tokenizer.encode(string, return_tensors='pt').squeeze(0)
|
||||
input_ids_list.append(input_ids)
|
||||
|
||||
bar = tqdm(range(math.ceil(len(input_ids_list)/args.batch_size)),
|
||||
desc=f'steps', disable=not is_rank_0())
|
||||
|
||||
actor.eval()
|
||||
with torch.no_grad():
|
||||
for i in range(0, len(input_ids_list), args.batch_size):
|
||||
batch = input_ids_list[i:i+args.batch_size]
|
||||
batch = [i.flip(dims=[0]) for i in batch]
|
||||
batch = torch.nn.utils.rnn.pad_sequence(batch,
|
||||
batch_first=True,
|
||||
padding_value=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0).to(torch.cuda.current_device())
|
||||
batch = batch.flip(dims=[1])
|
||||
attention_mask = batch.ne(tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0)
|
||||
|
||||
outputs = actor.model.generate(batch, attention_mask=attention_mask,
|
||||
max_length=args.max_length,
|
||||
do_sample=True,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
num_return_sequences=1)
|
||||
|
||||
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
for j in range(batch.size(0)):
|
||||
answers[i +
|
||||
j]['output'] = outputs[j].split("### Response:")[1].strip()
|
||||
|
||||
bar.update()
|
||||
|
||||
jdump(answers, os.path.join(args.answer_path,
|
||||
f'{args.model_name}_answers_rank{rank}.json'))
|
||||
|
||||
if is_rank_0():
|
||||
logger.info(
|
||||
f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--strategy',
|
||||
choices=['naive', 'ddp', 'colossalai_gemini',
|
||||
'colossalai_zero2', 'colossalai_zero2_cpu'],
|
||||
default='naive')
|
||||
parser.add_argument('--model', default='gpt2',
|
||||
choices=['gpt2', 'bloom', 'opt', 'roberta', 'llama'])
|
||||
parser.add_argument('--model_path', type=str, default=None)
|
||||
parser.add_argument('--model_name', type=str, default='model')
|
||||
parser.add_argument('--dataset', type=str, default=None)
|
||||
parser.add_argument('--batch_size', type=int, default=1)
|
||||
parser.add_argument('--max_datasets_size', type=int, default=None)
|
||||
parser.add_argument('--answer_path', type=str, default="answer")
|
||||
parser.add_argument('--max_length', type=int, default=1024)
|
||||
args = parser.parse_args()
|
||||
generate(args)
|
|
@ -0,0 +1,25 @@
|
|||
device_number=number of your devices
|
||||
model_name="name of your model"
|
||||
model_path="path to your model"
|
||||
dataset="path to the question dataset"
|
||||
answer_path="path to save the model answers"
|
||||
|
||||
torchrun --standalone --nproc_per_node=$device_number generate_answers.py \
|
||||
--model 'llama' \
|
||||
--strategy ddp \
|
||||
--model_path $model_path \
|
||||
--model_name $model_name \
|
||||
--dataset $dataset \
|
||||
--batch_size 8 \
|
||||
--max_datasets_size 80 \
|
||||
--answer_path $answer_path \
|
||||
--max_length 512
|
||||
|
||||
python merge.py \
|
||||
--model_name $model_name \
|
||||
--shards $device_number \
|
||||
--answer_path $answer_path
|
||||
|
||||
for (( i=0; i<device_number; i++ )) do
|
||||
rm -rf "${answer_path}/${model_name}_answers_rank${i}.json"
|
||||
done
|
|
@ -0,0 +1,98 @@
|
|||
# Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/eval/qa_baseline_gpt35.py
|
||||
# Copyright 2023 LM-SYS@FastChat
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import concurrent.futures
|
||||
|
||||
import openai
|
||||
import tqdm
|
||||
import shortuuid
|
||||
import logging
|
||||
|
||||
from utils import jload, jdump
|
||||
|
||||
MODEL = 'gpt-3.5-turbo'
|
||||
MAX_API_RETRY = 3
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_answer(question: str, max_tokens: int):
|
||||
answer = question
|
||||
prompt = question['instruction'] if question['input'] == "" else question['instruction'] + \
|
||||
" " + question['input']
|
||||
for _ in range(MAX_API_RETRY):
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
model='gpt-3.5-turbo',
|
||||
messages=[{
|
||||
'role': 'system',
|
||||
'content': 'You are a helpful assistant.'
|
||||
}, {
|
||||
'role': 'user',
|
||||
'content': prompt,
|
||||
}],
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
answer['output'] = response['choices'][0]['message']['content']
|
||||
return answer
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
time.sleep(1)
|
||||
logger.error(f' Answer {question["id"]} failed after {MAX_API_RETRY} retries.')
|
||||
return answer
|
||||
|
||||
def evaluate_gpt35(args):
|
||||
questions=jload(args.dataset)
|
||||
|
||||
logger.info(
|
||||
f' Total number of answers: {len(questions)}.')
|
||||
logger.info(f' Generating answers with {MODEL} using {args.num_workers} workers.')
|
||||
|
||||
answers = []
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_workers) as executor:
|
||||
futures = []
|
||||
for question in questions:
|
||||
future = executor.submit(get_answer, question, args.max_tokens)
|
||||
futures.append(future)
|
||||
|
||||
for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
|
||||
answers.append(future.result())
|
||||
|
||||
answers.sort(key=lambda x: x['id'])
|
||||
|
||||
jdump(answers, os.path.join(args.answer_path,
|
||||
f'gpt35_answers.json'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Evaluate GPT 3.5.')
|
||||
parser.add_argument('--dataset', type=str, default="questions.json")
|
||||
parser.add_argument('--answer_path', type=str, default="answer")
|
||||
parser.add_argument('--num_workers', type=int, default=4)
|
||||
parser.add_argument('--openai_key', type=str, default=None)
|
||||
parser.add_argument('--max_tokens', type=int, default=1024)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.openai_key is not None:
|
||||
os.environ["OPENAI_API_KEY"] = args.openai_key
|
||||
openai.api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
evaluate_gpt35(args)
|
|
@ -0,0 +1,6 @@
|
|||
python generate_gpt35_answers.py \
|
||||
--dataset "path to the question dataset" \
|
||||
--answer_path "path to answer folder" \
|
||||
--num_workers 4 \
|
||||
--openai_key "your openai key" \
|
||||
--max_tokens 512
|
|
@ -0,0 +1,25 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
from utils import jload, jdump
|
||||
|
||||
|
||||
def generate(args):
|
||||
dataset = []
|
||||
for i in range(args.shards):
|
||||
shard = jload(os.path.join(args.answer_path,
|
||||
f'{args.model_name}_answers_rank{i}.json'))
|
||||
dataset.extend(shard)
|
||||
|
||||
dataset.sort(key=lambda x: x['id'])
|
||||
jdump(dataset, os.path.join(args.answer_path,
|
||||
f'{args.model_name}_answers.json'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model_name', type=str, default='model')
|
||||
parser.add_argument('--shards', type=int, default=4)
|
||||
parser.add_argument('--answer_path', type=str, default="answer")
|
||||
args = parser.parse_args()
|
||||
generate(args)
|
|
@ -0,0 +1,53 @@
|
|||
import io
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
def is_rank_0() -> bool:
|
||||
return not dist.is_initialized() or dist.get_rank() == 0
|
||||
|
||||
def _make_w_io_base(f, mode: str):
|
||||
if not isinstance(f, io.IOBase):
|
||||
f_dirname = os.path.dirname(f)
|
||||
if f_dirname != "":
|
||||
os.makedirs(f_dirname, exist_ok=True)
|
||||
f = open(f, mode=mode)
|
||||
return f
|
||||
|
||||
def _make_r_io_base(f, mode: str):
|
||||
if not isinstance(f, io.IOBase):
|
||||
f = open(f, mode=mode)
|
||||
return f
|
||||
|
||||
def jdump(obj, f, mode="w", indent=4, default=str):
|
||||
"""Dump a str or dictionary to a file in json format.
|
||||
Args:
|
||||
obj: An object to be written.
|
||||
f: A string path to the location on disk.
|
||||
mode: Mode for opening the file.
|
||||
indent: Indent for storing json dictionaries.
|
||||
default: A function to handle non-serializable entries; defaults to `str`.
|
||||
"""
|
||||
f = _make_w_io_base(f, mode)
|
||||
if isinstance(obj, (dict, list)):
|
||||
json.dump(obj, f, indent=indent, default=default)
|
||||
elif isinstance(obj, str):
|
||||
f.write(obj)
|
||||
else:
|
||||
raise ValueError(f"Unexpected type: {type(obj)}")
|
||||
f.close()
|
||||
|
||||
def jload(f, mode="r"):
|
||||
"""Load a .json file into a dictionary."""
|
||||
f = _make_r_io_base(f, mode)
|
||||
jdict = json.load(f)
|
||||
f.close()
|
||||
return jdict
|
||||
|
||||
def get_json_list(file_path):
|
||||
with open(file_path, 'r') as f:
|
||||
json_list = []
|
||||
for line in f:
|
||||
json_list.append(json.loads(line))
|
||||
return json_list
|