from .blendable_dataset import BlendableDataset
from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_
from .bert_dataset import BertDataset
from colossalai.logging import get_dist_logger

DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'

DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]


def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup,
                                     binary_head, dataset_type='standard_bert'):

    if dataset_type not in DSET_TYPES:
        raise ValueError("Invalid dataset_type: {}".format(dataset_type))

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)

    # Get start and end indices of train/valid/test into doc-idx.
    # Note that doc-idx is designed to be num-docs + 1 so we can
    # easily iterate over it.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    logger = get_dist_logger()
    # Print stats about the splits.
    logger.info('\n > dataset split:', ranks=[0])

    def print_split_stats(name, index):
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        logger.info('\n    {}:'.format(name) +
                    '\n     document indices in [{}, {}) total of {} documents'.format(
                        splits[index], splits[index + 1],
                        splits[index + 1] - splits[index]) +
                    '\n     sentence indices in [{}, {}) total of {} sentences'.format(
                        start_index, end_index, end_index - start_index),
                    ranks=[0])

    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
            doc_idx_ptr = indexed_dataset.get_doc_idx()
            # Slice the doc-idx.
            start_index = splits[index]
            # Add +1 so we can index into the dataset to get the upper bound.
            end_index = splits[index + 1] + 1
            # New doc_idx view.
            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
            # Build the dataset accordingly.
            kwargs = dict(
                name=name,
                data_prefix=data_prefix,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
                seed=seed,
            )

            if dataset_type != DSET_TYPE_BERT:
                raise NotImplementedError("Only BERT dataset is supported")
            else:
                dataset = BertDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    binary_head=binary_head,
                    **kwargs
                )

            # Set the original pointer so dataset remains the main dataset.
            indexed_dataset.set_doc_idx(doc_idx_ptr)
            # Checks.
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == \
                (total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)


def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length, masked_lm_prob,
                                    short_seq_prob, seed, skip_warmup,
                                    binary_head, dataset_type='standard_bert'):

    # Single dataset: no blending needed.
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0], data_impl,
                                                splits_string,
                                                train_valid_test_num_samples,
                                                max_seq_length, masked_lm_prob,
                                                short_seq_prob, seed,
                                                skip_warmup, binary_head,
                                                dataset_type=dataset_type)

    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i], data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length, masked_lm_prob, short_seq_prob,
            seed, skip_warmup, binary_head, dataset_type=dataset_type)
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)
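# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept commented out so importing this module
# has no side effects). The data prefix, split string, and sample counts below
# are hypothetical placeholders; a real prefix must point to a preprocessed
# indexed dataset (the .bin/.idx pair produced by the data preprocessing step).
# Passing more than one prefix returns BlendableDataset wrappers instead of
# plain BertDataset instances.
#
# train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
#     data_prefix=['/path/to/my-bert_text_sentence'],   # hypothetical path
#     data_impl='mmap',
#     splits_string='949,50,1',
#     train_valid_test_num_samples=[1000000, 10000, 1000],
#     max_seq_length=512,
#     masked_lm_prob=0.15,
#     short_seq_prob=0.1,
#     seed=1234,
#     skip_warmup=True,
#     binary_head=True,
#     dataset_type=DSET_TYPE_BERT)
# ---------------------------------------------------------------------------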