ColossalAI/examples/tutorial/sequence_parallel/data/datasets/builder.py

from colossalai.logging import get_dist_logger

from .bert_dataset import BertDataset
from .blendable_dataset import BlendableDataset
from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_

DSET_TYPE_BERT = "standard_bert"
DSET_TYPE_ICT = "ict"
DSET_TYPE_T5 = "t5"

DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]

def _build_train_valid_test_datasets(
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    max_seq_length,
    masked_lm_prob,
    short_seq_prob,
    seed,
    skip_warmup,
    binary_head,
    dataset_type="standard_bert",
):
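    """Build train/validation/test ``BertDataset`` splits from a single indexed dataset.

    The doc-idx of ``indexed_dataset`` is temporarily sliced per split and restored
    afterwards. Only ``dataset_type="standard_bert"`` is implemented here; the ICT
    and T5 types listed in ``DSET_TYPES`` pass the validity check but raise
    ``NotImplementedError`` when the dataset is built.
    """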
    if dataset_type not in DSET_TYPES:
        raise ValueError(f"Invalid dataset_type: {dataset_type}")

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)

    # Get start and end indices of train/valid/test into doc-idx
    # Note that doc-idx is designed to be num-docs + 1 so we can
    # easily iterate over it.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    logger = get_dist_logger()
    # Print stats about the splits.
    logger.info("\n > dataset split:", ranks=[0])

    def print_split_stats(name, index):
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        logger.info(
            "\n {}:".format(name)
            + "\n document indices in [{}, {}) total of {} documents".format(
                splits[index], splits[index + 1], splits[index + 1] - splits[index]
            )
            + "\n sentence indices in [{}, {}) total of {} sentences".format(
                start_index, end_index, end_index - start_index
            ),
            ranks=[0],
        )

    print_split_stats("train", 0)
    print_split_stats("validation", 1)
    print_split_stats("test", 2)
    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
            doc_idx_ptr = indexed_dataset.get_doc_idx()
            # Slice the doc-idx.
            start_index = splits[index]
            # Add +1 so we can index into the dataset to get the upper bound.
            end_index = splits[index + 1] + 1
            # New doc_idx view.
            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
            # Build the dataset accordingly.
            kwargs = dict(
                name=name,
                data_prefix=data_prefix,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
                seed=seed,
            )

            if dataset_type != DSET_TYPE_BERT:
                raise NotImplementedError("Only BERT dataset is supported")
            else:
                dataset = BertDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    binary_head=binary_head,
                    **kwargs,
                )

            # Set the original pointer so dataset remains the main dataset.
            indexed_dataset.set_doc_idx(doc_idx_ptr)

            # Checks.
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1)

        return dataset

    train_dataset = build_dataset(0, "train")
    valid_dataset = build_dataset(1, "valid")
    test_dataset = build_dataset(2, "test")

    return (train_dataset, valid_dataset, test_dataset)

def build_train_valid_test_datasets(
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    max_seq_length,
    masked_lm_prob,
    short_seq_prob,
    seed,
    skip_warmup,
    binary_head,
    dataset_type="standard_bert",
):
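    """Build train/validation/test datasets from one or more data prefixes.

    A single prefix delegates to ``_build_train_valid_test_datasets``; multiple
    prefixes are each built separately and then combined into ``BlendableDataset``
    objects using the weights parsed from ``data_prefix``.
    """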
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(
            data_prefix[0],
            data_impl,
            splits_string,
            train_valid_test_num_samples,
            max_seq_length,
            masked_lm_prob,
            short_seq_prob,
            seed,
            skip_warmup,
            binary_head,
            dataset_type=dataset_type,
        )

    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i],
            data_impl,
            splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length,
            masked_lm_prob,
            short_seq_prob,
            seed,
            skip_warmup,
            binary_head,
            dataset_type=dataset_type,
        )
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)
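

# Illustrative usage sketch (not part of the original module). The prefix, split
# string, and sample counts below are placeholders; they assume a Megatron-style
# preprocessed corpus ("<prefix>.bin"/"<prefix>.idx") already exists on disk.
#
# train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
#     data_prefix=["my-corpus_text_sentence"],   # placeholder prefix
#     data_impl="mmap",
#     splits_string="949,50,1",
#     train_valid_test_num_samples=[10000, 1000, 100],
#     max_seq_length=512,
#     masked_lm_prob=0.15,
#     short_seq_prob=0.1,
#     seed=1234,
#     skip_warmup=True,
#     binary_head=True,
#     dataset_type="standard_bert",
# )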