mirror of https://github.com/hpcaitech/ColossalAI
26 lines
749 B
Python
26 lines
749 B
Python
import argparse
|
|
import os
|
|
|
|
from utils import jload, jdump
|
|
|
|
|
|
def generate(args):
|
|
dataset = []
|
|
for i in range(args.shards):
|
|
shard = jload(os.path.join(args.answer_path,
|
|
f'{args.model_name}_answers_rank{i}.json'))
|
|
dataset.extend(shard)
|
|
|
|
dataset.sort(key=lambda x: x['id'])
|
|
jdump(dataset, os.path.join(args.answer_path,
|
|
f'{args.model_name}_answers.json'))
|
|
|
|
|
|
if __name__ == '__main__':
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('--model_name', type=str, default='model')
|
|
parser.add_argument('--shards', type=int, default=4)
|
|
parser.add_argument('--answer_path', type=str, default="answer")
|
|
args = parser.parse_args()
|
|
generate(args)
|