# Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/eval/qa_baseline_gpt35.py
# Copyright 2023 LM-SYS@FastChat
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import concurrent.futures
import logging
import os
import time

import openai
import tqdm

from utils import jload, jdump

MODEL = 'gpt-3.5-turbo'
MAX_API_RETRY = 3

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_answer(question: dict, max_tokens: int) -> dict:
    """Query the OpenAI API for one question and attach the reply as 'output'."""
    answer = question
    # Concatenate the instruction and its (optional) input into a single prompt.
    prompt = question['instruction'] if question['input'] == "" else \
        question['instruction'] + " " + question['input']
    for _ in range(MAX_API_RETRY):
        try:
            response = openai.ChatCompletion.create(
                model=MODEL,
                messages=[
                    {'role': 'system', 'content': 'You are a helpful assistant.'},
                    {'role': 'user', 'content': prompt},
                ],
                max_tokens=max_tokens,
            )
            answer['output'] = response['choices'][0]['message']['content']
            return answer
        except Exception as e:
            logger.error(e)
            time.sleep(1)  # back off briefly before retrying
    logger.error(f'Answer {question["id"]} failed after {MAX_API_RETRY} retries.')
    return answer


def evaluate_gpt35(args):
    questions = jload(args.dataset)
    logger.info(f'Total number of questions: {len(questions)}.')
    answers = []
    # Submit all questions to a thread pool and collect results as they finish.
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_workers) as executor:
        futures = [executor.submit(get_answer, question, args.max_tokens) for question in questions]
        for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            answers.append(future.result())
    # as_completed yields out of order, so restore the original ordering by id.
    answers.sort(key=lambda x: x['id'])
    jdump(answers, os.path.join(args.answer_path, 'gpt35_answers.json'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate GPT-3.5.')
    parser.add_argument('--dataset', type=str, default="questions.json")
    parser.add_argument('--answer_path', type=str, default="answer")
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--openai_key', type=str, default=None)
    parser.add_argument('--max_tokens', type=int, default=1024)
    args = parser.parse_args()

    if args.openai_key is not None:
        os.environ["OPENAI_API_KEY"] = args.openai_key
    openai.api_key = os.getenv("OPENAI_API_KEY")

    evaluate_gpt35(args)
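
# --- Example usage (illustrative sketch, not part of the upstream source) ---
# The dataset schema below is inferred from the field accesses in get_answer()
# ('id', 'instruction', 'input') and jload(); it is an assumption, not
# documented in the original. A minimal questions.json might look like:
#
#   [
#       {"id": 1, "instruction": "Summarize the following text.",
#        "input": "The quick brown fox jumps over the lazy dog."},
#       {"id": 2, "instruction": "What is the capital of France?", "input": ""}
#   ]
#
# The script could then be invoked as (filename assumed, after the upstream
# FastChat file it was adapted from):
#
#   python qa_baseline_gpt35.py --dataset questions.json --answer_path answer \
#       --num_workers 4 --max_tokens 1024 --openai_key "$OPENAI_API_KEY"
#
# which writes answer/gpt35_answers.json with each record extended by an
# 'output' field containing the model's reply.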