mirror of https://github.com/hpcaitech/ColossalAI
import argparse
import json
import random

random.seed(42)


def sample(args):
    with open(args.dataset_path, mode="r") as f:
        dataset_list = json.load(f)

    # Randomly pick `sample_size` entries (seeded above for reproducibility) and
    # keep only the instruction text, tagging each entry with a sequential id.
    sampled_dataset = [
        {"instruction": sample["instruction"], "id": idx}
        for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))
    ]

    with open(args.save_path, mode="w") as f:
        json.dump(sampled_dataset, f, indent=4, default=str, ensure_ascii=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default=None, required=True, help="path to the pretrain dataset")
    parser.add_argument("--save_path", type=str, default="prompt.json", help="path to save the prompt dataset")
    parser.add_argument("--sample_size", type=int, default=16384, help="size of the prompt dataset")
    args = parser.parse_args()
    sample(args)
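For context, a minimal usage sketch. It assumes the script above is saved as sample_prompt.py and builds a throwaway toy dataset first; the file names and field values here are placeholders, not from the source. The input must be a JSON list of objects that each contain an "instruction" key, since that is the only field the script reads.

    # Minimal sketch, assuming the script is saved as sample_prompt.py (placeholder name).
    import json
    import subprocess

    # Build a small toy dataset in the expected format: a list of dicts with "instruction".
    toy = [{"instruction": f"question {i}", "output": "..."} for i in range(100)]
    with open("toy_dataset.json", "w") as f:
        json.dump(toy, f)

    # Invoke the sampler via its CLI; sample_size must not exceed the dataset length.
    subprocess.run(
        ["python", "sample_prompt.py",
         "--dataset_path", "toy_dataset.json",
         "--save_path", "prompt.json",
         "--sample_size", "8"],
        check=True,
    )

Because the script seeds random with 42, repeated runs over the same input produce the same prompt.json, a list of {"instruction": ..., "id": ...} entries.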