mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
83 lines
2.7 KiB
83 lines
2.7 KiB
import argparse
|
|
import json
|
|
|
|
from datasets import load_dataset
|
|
|
|
|
|
def generate_alpaca():
    """Convert the Alpaca training split into one-round conversations.

    Any dataset with the same ("instruction", "input", "output") layout as
    Alpaca can be converted this way: each record becomes one human turn
    followed by one gpt turn.

    Returns:
        list: one dict per record with the keys "type", "language",
        "dataset" and "conversations" (a two-element list of turns).
    """
    alpaca = load_dataset("tatsu-lab/alpaca", split="train")

    prompts = alpaca["instruction"]
    contexts = alpaca["input"]
    answers = alpaca["output"]

    # Sanity check: the three columns are walked in lockstep below.
    assert len(prompts) == len(contexts) == len(answers)

    records = []
    for prompt, context, answer in zip(prompts, contexts, answers):
        # A non-empty "input" is appended to the instruction after a blank line.
        if context:
            question = prompt + "\n\n" + context
        else:
            question = prompt

        turns = [
            {"from": "human", "value": question},
            {"from": "gpt", "value": answer},
        ]
        records.append(
            {"type": "instruction", "language": "English", "dataset": "Alpaca", "conversations": turns}
        )

    return records
|
|
|
|
|
|
def generate_sharegpt():
    """Convert the ShareGPT training split into multi-round conversations.

    ShareGPT data requires less processing than Alpaca: each conversation is
    kept as is, except that the "markdown" and "text" fields of every turn
    are dropped.

    Returns:
        list: one dict per conversation with the keys "type", "language",
        "dataset" and "conversations" (the list of cleaned turns).
    """
    dataset = load_dataset(
        "anon8231489123/ShareGPT_Vicuna_unfiltered",
        data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
        split="train",
    )

    # We don't need markdown and text value.
    unwanted_keys = ("markdown", "text")

    conversation_dataset = []
    for turns in dataset["conversations"]:
        # Build filtered copies of the turns instead of deleting keys in place
        # and indexing the column a second time: the result then does not
        # depend on repeated lookups returning the same mutable object, and a
        # turn that lacks one of the unwanted keys no longer raises KeyError.
        cleaned_turns = [
            {key: value for key, value in turn.items() if key not in unwanted_keys} for turn in turns
        ]
        conversation = dict(
            type="conversation", language="Multilingual", dataset="ShareGPT", conversations=cleaned_turns
        )
        conversation_dataset.append(conversation)

    return conversation_dataset
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: convert the selected dataset(s) and save the
    # result as a single JSON file.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="All",
        choices=["Alpaca", "ShareGPT", "All"],
        help="which dataset to convert, All will combine Alpaca and ShareGPT",
    )
    parser.add_argument("--save_path", type=str, default="dataset.json", help="path to save the converted dataset")
    args = parser.parse_args()

    conversation_dataset = []

    if args.dataset == "Alpaca":
        conversation_dataset.extend(generate_alpaca())
    elif args.dataset == "ShareGPT":
        conversation_dataset.extend(generate_sharegpt())
    else:
        # "All": Alpaca samples first, then ShareGPT samples.
        conversation_dataset.extend(generate_alpaca())
        conversation_dataset.extend(generate_sharegpt())

    # Number the samples consecutively, starting at 1, across the combined list.
    for idx, sample in enumerate(conversation_dataset, start=1):
        sample["id"] = idx

    # ensure_ascii=False writes non-ASCII characters verbatim, so the file
    # encoding is fixed to UTF-8 instead of depending on the platform locale.
    with open(args.save_path, mode="w", encoding="utf-8") as f:
        # default=str: any value json cannot serialize natively is written as str(value).
        json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)
|