mirror of https://github.com/hpcaitech/ColossalAI
aibig-modeldata-parallelismdeep-learningdistributed-computingfoundation-modelsheterogeneous-traininghpcinferencelarge-scalemodel-parallelismpipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
176 lines
5.8 KiB
176 lines
5.8 KiB
from colossalai.logging import get_dist_logger |
|
|
|
from .bert_dataset import BertDataset |
|
from .blendable_dataset import BlendableDataset |
|
from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_ |
|
|
|
DSET_TYPE_BERT = "standard_bert" |
|
DSET_TYPE_ICT = "ict" |
|
DSET_TYPE_T5 = "t5" |
|
|
|
DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5] |
|
|
|
|
|
def _build_train_valid_test_datasets( |
|
data_prefix, |
|
data_impl, |
|
splits_string, |
|
train_valid_test_num_samples, |
|
max_seq_length, |
|
masked_lm_prob, |
|
short_seq_prob, |
|
seed, |
|
skip_warmup, |
|
binary_head, |
|
dataset_type="standard_bert", |
|
): |
|
if dataset_type not in DSET_TYPES: |
|
raise ValueError("Invalid dataset_type: ", dataset_type) |
|
|
|
# Indexed dataset. |
|
indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) |
|
|
|
# Get start and end indices of train/valid/train into doc-idx |
|
# Note that doc-idx is designed to be num-docs + 1 so we can |
|
# easily iterate over it. |
|
total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1 |
|
splits = get_train_valid_test_split_(splits_string, total_num_of_documents) |
|
|
|
logger = get_dist_logger() |
|
|
|
# Print stats about the splits. |
|
logger.info("\n > dataset split:", ranks=[0]) |
|
|
|
def print_split_stats(name, index): |
|
start_index = indexed_dataset.doc_idx[splits[index]] |
|
end_index = indexed_dataset.doc_idx[splits[index + 1]] |
|
logger.info( |
|
"\n {}:".format(name) |
|
+ "\n document indices in [{}, {}) total of {} documents".format( |
|
splits[index], splits[index + 1], splits[index + 1] - splits[index] |
|
) |
|
+ "\n sentence indices in [{}, {}) total of {} sentences".format( |
|
start_index, end_index, end_index - start_index |
|
), |
|
ranks=[0], |
|
) |
|
|
|
print_split_stats("train", 0) |
|
print_split_stats("validation", 1) |
|
print_split_stats("test", 2) |
|
|
|
def build_dataset(index, name): |
|
dataset = None |
|
if splits[index + 1] > splits[index]: |
|
# Get the pointer to the original doc-idx so we can set it later. |
|
doc_idx_ptr = indexed_dataset.get_doc_idx() |
|
# Slice the doc-idx |
|
start_index = splits[index] |
|
# Add +1 so we can index into the dataset to get the upper bound. |
|
end_index = splits[index + 1] + 1 |
|
# New doc_idx view. |
|
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) |
|
# Build the dataset accordingly. |
|
kwargs = dict( |
|
name=name, |
|
data_prefix=data_prefix, |
|
num_epochs=None, |
|
max_num_samples=train_valid_test_num_samples[index], |
|
max_seq_length=max_seq_length, |
|
seed=seed, |
|
) |
|
|
|
if dataset_type != DSET_TYPE_BERT: |
|
raise NotImplementedError("Only BERT dataset is supported") |
|
else: |
|
dataset = BertDataset( |
|
indexed_dataset=indexed_dataset, |
|
masked_lm_prob=masked_lm_prob, |
|
short_seq_prob=short_seq_prob, |
|
binary_head=binary_head, |
|
**kwargs, |
|
) |
|
|
|
# Set the original pointer so dataset remains the main dataset. |
|
indexed_dataset.set_doc_idx(doc_idx_ptr) |
|
# Checks. |
|
assert indexed_dataset.doc_idx[0] == 0 |
|
assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1) |
|
return dataset |
|
|
|
train_dataset = build_dataset(0, "train") |
|
valid_dataset = build_dataset(1, "valid") |
|
test_dataset = build_dataset(2, "test") |
|
|
|
return (train_dataset, valid_dataset, test_dataset) |
|
|
|
|
|
def build_train_valid_test_datasets( |
|
data_prefix, |
|
data_impl, |
|
splits_string, |
|
train_valid_test_num_samples, |
|
max_seq_length, |
|
masked_lm_prob, |
|
short_seq_prob, |
|
seed, |
|
skip_warmup, |
|
binary_head, |
|
dataset_type="standard_bert", |
|
): |
|
if len(data_prefix) == 1: |
|
return _build_train_valid_test_datasets( |
|
data_prefix[0], |
|
data_impl, |
|
splits_string, |
|
train_valid_test_num_samples, |
|
max_seq_length, |
|
masked_lm_prob, |
|
short_seq_prob, |
|
seed, |
|
skip_warmup, |
|
binary_head, |
|
dataset_type=dataset_type, |
|
) |
|
# Blending dataset. |
|
# Parse the values. |
|
output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) |
|
prefixes, weights, datasets_train_valid_test_num_samples = output |
|
|
|
# Build individual datasets. |
|
train_datasets = [] |
|
valid_datasets = [] |
|
test_datasets = [] |
|
for i in range(len(prefixes)): |
|
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( |
|
prefixes[i], |
|
data_impl, |
|
splits_string, |
|
datasets_train_valid_test_num_samples[i], |
|
max_seq_length, |
|
masked_lm_prob, |
|
short_seq_prob, |
|
seed, |
|
skip_warmup, |
|
binary_head, |
|
dataset_type=dataset_type, |
|
) |
|
if train_ds: |
|
train_datasets.append(train_ds) |
|
if valid_ds: |
|
valid_datasets.append(valid_ds) |
|
if test_ds: |
|
test_datasets.append(test_ds) |
|
|
|
# Blend. |
|
blending_train_dataset = None |
|
if train_datasets: |
|
blending_train_dataset = BlendableDataset(train_datasets, weights) |
|
blending_valid_dataset = None |
|
if valid_datasets: |
|
blending_valid_dataset = BlendableDataset(valid_datasets, weights) |
|
blending_test_dataset = None |
|
if test_datasets: |
|
blending_test_dataset = BlendableDataset(test_datasets, weights) |
|
|
|
return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)
|
|
|