You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py

274 lines
12 KiB

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Prepare dataset scripts
Usage:
- For SFT dataset preparation (SFT)
python prepare_dataset.py --type sft \
--data_input_dirs /PATH/TO/SFT/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir "" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
- For prompt dataset preparation (PPO)
python prepare_dataset.py --type prompt \
--data_input_dirs /PATH/TO/PROMPT/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir /PATH/TO/TOKENIZER \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow
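- For Preference dataset preparation (DPO and Reward model training)
python prepare_dataset.py --type preference \
--data_input_dirs /PATH/TO/PREFERENCE/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir /PATH/TO/TOKENIZER \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow
- For KTO dataset preparation
python prepare_dataset.py --type kto \
--data_input_dirs /PATH/TO/KTO/DATASET \
--conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \
--tokenizer_dir /PATH/TO/TOKENIZER \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow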
"""
import argparse
import json
import math
import os
import random
import time
from multiprocessing import cpu_count
from coati.dataset import setup_conversation_template, tokenize_kto, tokenize_prompt, tokenize_rlhf, tokenize_sft
from datasets import dataset_dict, load_dataset
from transformers import AutoTokenizer
from colossalai.logging import get_dist_logger
# Module-level logger from ColossalAI's distributed logging utilities, shared by the
# functions in this script for progress and warning messages.
logger = get_dist_logger()
def _parse_args() -> argparse.Namespace:
    """Define the command-line interface of this script and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--type",
        type=str,
        required=True,
        default=None,
        choices=["sft", "prompt", "preference", "kto"],
        help="Type of dataset, choose from 'sft', 'prompt', 'preference', 'kto'.",
    )
    parser.add_argument(
        "--data_input_dirs",
        type=str,
        required=True,
        default=None,
        help="Comma(i.e., ',') separated list of all data directories containing `.jsonl` data files.",
    )
    parser.add_argument(
        "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
    )
    parser.add_argument(
        "--conversation_template_config",
        type=str,
        default="conversation_template_config",
        help="Path to the conversation template config (JSON). It is loaded if it exists and is also "
        "used as the save path for the resolved conversation template config.",
    )
    parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
    parser.add_argument(
        "--data_jsonl_output_dir",
        type=str,
        default="jsonl_output",
        help="Output directory of spliced dataset with jsonl format",
    )
    parser.add_argument(
        "--data_arrow_output_dir",
        type=str,
        default="arrow_output",
        help="Output directory of spliced dataset with arrow format",
    )
    parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
    parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
    parser.add_argument(
        "--num_samples_per_datafile",
        type=int,
        default=-1,
        help="Maximum number of samples randomly drawn from each spliced part. "
        "A non-positive value (default -1) keeps all samples.",
    )
    return parser.parse_args()


def _collect_input_data_paths(data_input_dirs: str) -> list:
    """Return the `.jsonl` files found directly inside each comma-separated directory.

    Files are sorted by name within each directory; directories keep the order given.

    Raises:
        FileNotFoundError: if a listed directory does not exist, or no `.jsonl` file is found at all.
    """
    input_data_paths = []
    for ds_dir in data_input_dirs.split(","):
        ds_dir = ds_dir.strip()
        if not ds_dir:
            # An empty entry (e.g. from a trailing comma) would otherwise resolve to the current
            # working directory via os.path.abspath("") and silently pull in unrelated files.
            continue
        ds_dir = os.path.abspath(ds_dir)
        if not os.path.isdir(ds_dir):
            raise FileNotFoundError(f"Not find data dir {ds_dir}")
        # os.listdir returns entries in arbitrary order; sort so the part contents are reproducible.
        ds_files = sorted(name for name in os.listdir(ds_dir) if name.endswith(".jsonl"))
        input_data_paths.extend(os.path.join(ds_dir, name) for name in ds_files)
    if not input_data_paths:
        raise FileNotFoundError(f"No `.jsonl` data files found in {data_input_dirs!r}")
    return input_data_paths


def _build_train_splits(num_bins: int) -> list:
    """Return percent-slice split specs (e.g. "train[0%:10%]") that partition 0-100% into contiguous parts.

    Each part spans ceil(100 / num_bins) percent, so the actual number of parts can be smaller than
    `num_bins` (e.g. 30 bins -> 4% parts -> 25 parts) and never exceeds 100.
    """
    split_interval = math.ceil(100 / num_bins)
    return [f"train[{start}%:{min(start + split_interval, 100)}%]" for start in range(0, 100, split_interval)]


def _ask_stop_ids(tokenizer) -> list:
    """Interactively ask the user for stop_ids, showing a sample chat prompt and its token-id mapping.

    Returns `[tokenizer.eos_token_id]` when the user enters nothing (or only whitespace).

    Raises:
        ValueError: if the input is not a space-separated list of integers.
    """
    # Ask the user to define stop_ids for PPO training
    dummy_messages = [
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
        {"role": "user", "content": "Who made you?"},
        {"role": "assistant", "content": "I am a chatbot trained by Colossal-AI."},
    ]
    dummy_prompt = tokenizer.apply_chat_template(dummy_messages, tokenize=False)
    tokenized = tokenizer(dummy_prompt, add_special_tokens=False)["input_ids"]
    tokens = tokenizer.convert_ids_to_tokens(tokenized, skip_special_tokens=False)
    corresponding_str = [tokenizer.convert_tokens_to_string([token]) for token in tokens]
    token_id_mapping = [{"token": s, "id": token_id} for s, token_id in zip(corresponding_str, tokenized)]
    stop_ids = input(
        "For PPO, we recommend to provide stop_ids for the properly stop the generation during roll out stage. "
        "stop_ids are the ids of repetitive pattern that indicate the end of the assistant's response. "
        "Here is an example of formatted prompt and token-id mapping, you can set stop_ids by entering a list "
        "of integers, separate by space, press `Enter` to end. Or you can press `Enter` without input if you are "
        "not using PPO or you prefer to not set the stop_ids, in that case, stop_ids will be set to tokenizer.eos_token_id. "
        f"\nPrompt:\n{dummy_prompt}\nToken-id Mapping:\n{token_id_mapping}\nstop_ids:"
    )
    # Whitespace-only input would otherwise split into an empty (useless) stop_ids list.
    if not stop_ids.strip():
        return [tokenizer.eos_token_id]
    try:
        return [int(s) for s in stop_ids.split()]
    except ValueError as err:
        raise ValueError("Invalid input, please provide a list of integers.") from err


def _load_chat_template_config(config_path: str, tokenizer, dataset_type: str) -> dict:
    """Load the chat template config from `config_path` (or build a default one) and ensure it has `stop_ids`.

    Missing `stop_ids` are asked for interactively for the "preference" type, and default to
    `[tokenizer.eos_token_id]` for every other type.
    """
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf8") as config_file:
            chat_template_config = json.load(config_file)
    else:
        chat_template_config = {
            "system_message": "A chat between a curious human and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
        }  # Use default system message
    if "stop_ids" not in chat_template_config:
        if dataset_type == "preference":
            chat_template_config["stop_ids"] = _ask_stop_ids(tokenizer)
        else:
            chat_template_config["stop_ids"] = [tokenizer.eos_token_id]
    return chat_template_config


def _ensure_pad_token(tokenizer) -> None:
    """Set the tokenizer's pad token to its eos token when possible; warn if no pad token remains."""
    if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
        try:
            # Some tokenizers don't allow setting pad_token manually, e.g., Qwen
            tokenizer.pad_token = tokenizer.eos_token
        except AttributeError as e:
            logger.warning(f"Unable to set pad token to eos token, {str(e)}")
    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
        logger.warning(
            "The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
        )


def _write_jsonl(dataset, output_jsonl_path: str) -> None:
    """Write every row of `dataset` as one JSON line to `output_jsonl_path`, logging progress every 500 rows."""
    st = time.time()
    with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer:
        for count, data_point in enumerate(dataset):
            if count % 500 == 0:
                logger.info(f"processing {count} spliced data points for {fp_writer.name}")
            fp_writer.write(json.dumps(data_point, ensure_ascii=False) + "\n")
        logger.info(
            f"Current file {fp_writer.name}; "
            f"Data size: {len(dataset)}; "
            f"Time cost: {round((time.time() - st) / 60, 6)} minutes."
        )


def main():
    """Tokenize raw `.jsonl` conversation data and save it as spliced jsonl and arrow parts.

    Steps:
      1. Parse and validate arguments and collect the input files.
      2. Load the tokenizer and conversation template.
      3. Split the raw data into percentage-based parts.
      4. Tokenize and filter each part with the function matching `--type`.
      5. Write each non-empty part to `<data_jsonl_output_dir>/part-XXXXX.jsonl` and
         `<data_arrow_output_dir>/part-XXXXX`.

    Raises:
        ValueError: on a non-positive or too large `--num_spliced_dataset_bins`, an unknown `--type`,
            or malformed interactive stop_ids input.
        FileExistsError: if the cache directory or an output directory already exists.
        FileNotFoundError: if an input directory is missing or no `.jsonl` file is found.
    """
    args = _parse_args()
    if args.num_spliced_dataset_bins <= 0:
        # Guards the division in _build_train_splits.
        raise ValueError("num_spliced_dataset_bins must be a positive integer")
    if args.num_spliced_dataset_bins >= 100000:
        raise ValueError("Too many spliced divisions, must be smaller than 100000")
    # Refuse to reuse or overwrite artifacts of a previous run. Explicit raises (instead of `assert`)
    # keep these checks active under `python -O`.
    if os.path.exists(args.data_cache_dir):
        raise FileExistsError(f"Find existed data cache dir {args.data_cache_dir}")
    if os.path.exists(args.data_jsonl_output_dir):
        raise FileExistsError(f"Find existed jsonl data output dir {args.data_jsonl_output_dir}")
    if os.path.exists(args.data_arrow_output_dir):
        raise FileExistsError(f"Find existed arrow data output dir {args.data_arrow_output_dir}")

    # Validate the inputs before creating output dirs, so a bad input path does not leave behind
    # directories that make the next run fail the checks above.
    input_data_paths = _collect_input_data_paths(args.data_input_dirs)
    os.makedirs(args.data_jsonl_output_dir)
    os.makedirs(args.data_arrow_output_dir)
    train_splits = _build_train_splits(args.num_spliced_dataset_bins)

    # Prepare the tokenizer and conversation template.
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir, use_fast=False, trust_remote_code=True)
    chat_template_config = _load_chat_template_config(args.conversation_template_config, tokenizer, args.type)
    conversation_template = setup_conversation_template(
        tokenizer, chat_template_config=chat_template_config, save_path=args.conversation_template_config
    )
    _ensure_pad_token(tokenizer)

    preparation_functions = {
        "sft": tokenize_sft,
        "prompt": tokenize_prompt,
        "preference": tokenize_rlhf,
        "kto": tokenize_kto,
    }
    if args.type not in preparation_functions:
        raise ValueError("Unknown dataset type. Please choose one from ['sft', 'prompt', 'preference', 'kto']")
    preparation_function = preparation_functions[args.type]
    # Rows whose column below is None after tokenization are dropped; presumably the tokenize
    # functions return None for samples they cannot use (e.g. too long) -- TODO confirm.
    filter_by = {"kto": "completion", "preference": "chosen_input_ids"}.get(args.type, "input_ids")

    list_dataset = load_dataset(
        path="json",
        data_files=input_data_paths,
        cache_dir=os.path.join(args.data_cache_dir, "raw"),
        keep_in_memory=False,
        split=train_splits,
        num_proc=cpu_count(),
    )

    for index, dataset in enumerate(list_dataset):
        assert isinstance(dataset, dataset_dict.Dataset)
        if len(dataset) == 0:
            # Hack: Skip empty dataset. If dataset contains less than num_of_rank samples, some rank may have empty dataset and leads to error
            continue
        if args.num_samples_per_datafile > 0:
            # Limit the number of samples in this part (unseeded, so the subset varies between runs).
            dataset = dataset.select(
                random.sample(range(len(dataset)), min(args.num_samples_per_datafile, len(dataset)))
            )
        logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
        dataset = dataset.map(
            function=preparation_function,
            fn_kwargs={
                "tokenizer": tokenizer,
                "conversation_template": conversation_template,
                "max_length": args.max_length,
            },
            keep_in_memory=False,
            num_proc=min(len(dataset), cpu_count()),
        )
        dataset = dataset.filter(lambda data: data[filter_by] is not None)
        if len(dataset) == 0:
            # Same reasoning as the skip above: never emit an empty part (it would also be
            # re-loaded from an empty jsonl and saved with num_proc=0).
            logger.warning(f"Part-{index} has no valid samples after tokenization and filtering, skip it.")
            continue

        # Save each jsonl spliced dataset.
        output_name = f"part-{index:05d}"
        output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl")
        _write_jsonl(dataset, output_jsonl_path)

        # Save each arrow spliced dataset, re-loaded from the jsonl just written.
        output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name)
        logger.info(f"Start to save {output_arrow_path}")
        dataset = load_dataset(
            path="json",
            data_files=[output_jsonl_path],
            cache_dir=os.path.join(args.data_cache_dir, "tokenized"),
            keep_in_memory=False,
            num_proc=cpu_count(),
            split="train",
        )
        dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(dataset), cpu_count()))
if __name__ == "__main__":
main()