# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Most of the code here has been copied from:
#   https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.

import math
import time
import collections

import numpy as np

from colossalai.logging import get_dist_logger

from .blendable_dataset import BlendableDataset
from .indexed_dataset import make_dataset as make_indexed_dataset

DSET_TYPE_STD = 'standard_bert'
DSET_TYPE_ICT = 'ict'

DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]


def get_datasets_weights_and_num_samples(data_prefix,
                                         train_valid_test_num_samples):

    # The data prefix should be in the format of:
    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ..
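    # For example (hypothetical paths, for illustration only):
    #   data_prefix = ['0.3', '/data/wiki_text_sentence',
    #                  '0.7', '/data/books_text_sentence']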
    assert len(data_prefix) % 2 == 0
    num_datasets = len(data_prefix) // 2
    weights = [0] * num_datasets
    prefixes = [0] * num_datasets
    for i in range(num_datasets):
        weights[i] = float(data_prefix[2 * i])
        prefixes[i] = (data_prefix[2 * i + 1]).strip()
    # Normalize weights
    weight_sum = 0.0
    for weight in weights:
        weight_sum += weight
    assert weight_sum > 0.0
    weights = [weight / weight_sum for weight in weights]

    # Add 0.5% (the 1.005 factor) so in case the blending dataset does
    # not uniformly distribute the number of samples, we still have
    # samples left to feed to the network.
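    # For example, a weight of 0.3 with 1000 requested samples yields
    # int(math.ceil(1000 * 0.3 * 1.005)) = 302 samples from that dataset.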
    datasets_train_valid_test_num_samples = []
    for weight in weights:
        datasets_train_valid_test_num_samples.append(
            [int(math.ceil(val * weight * 1.005))
             for val in train_valid_test_num_samples])

    return prefixes, weights, datasets_train_valid_test_num_samples


def compile_helper():
    """Compile helper function at runtime. Make sure this
    is invoked on a single process."""
    import os
    import subprocess
    path = os.path.abspath(os.path.dirname(__file__))
    ret = subprocess.run(['make', '-C', path])
    if ret.returncode != 0:
        print("Making C++ dataset helpers module failed, exiting.")
        import sys
        sys.exit(1)


def get_a_and_b_segments(sample, np_rng):
    """Divide sample into a and b segments."""

    # Number of sentences in the sample.
    n_sentences = len(sample)
    # Make sure we always have two sentences.
    assert n_sentences > 1, 'make sure each sample has at least two sentences.'

    # First part:
    # `a_end` is how many sentences go into the `A`.
    a_end = 1
    if n_sentences >= 3:
        # Note that randint in numpy is exclusive of the upper bound.
        a_end = np_rng.randint(1, n_sentences)
    tokens_a = []
    for j in range(a_end):
        tokens_a.extend(sample[j])

    # Second part:
    tokens_b = []
    for j in range(a_end, n_sentences):
        tokens_b.extend(sample[j])

    # Random next:
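    # With 50% probability swap A and B, so that B is not the true
    # continuation of A (the negative case for the next-sentence objective).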
    is_next_random = False
    if np_rng.random() < 0.5:
        is_next_random = True
        tokens_a, tokens_b = tokens_b, tokens_a

    return tokens_a, tokens_b, is_next_random


def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
    """Truncates a pair of sequences to a maximum sequence length."""
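    # Tokens are removed one at a time from the longer of the two segments,
    # randomly from its front or its back, until the pair fits.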
    # print(len_a, len_b, max_num_tokens)
    assert len_a > 0
    if len_a + len_b <= max_num_tokens:
        return False
    while len_a + len_b > max_num_tokens:
        if len_a > len_b:
            len_a -= 1
            tokens = tokens_a
        else:
            len_b -= 1
            tokens = tokens_b
        if np_rng.random() < 0.5:
            del tokens[0]
        else:
            tokens.pop()
    return True


def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""

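    # For example, with tokens_a = [a1, a2] and tokens_b = [b1], the output is
    #   tokens:     [CLS] a1 a2 [SEP] b1 [SEP]
    #   tokentypes:   0   0  0    0   1    1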
    tokens = []
    tokentypes = []
    # [CLS].
    tokens.append(cls_id)
    tokentypes.append(0)
    # Segment A.
    for token in tokens_a:
        tokens.append(token)
        tokentypes.append(0)
    # [SEP].
    tokens.append(sep_id)
    tokentypes.append(0)
    # Segment B.
    for token in tokens_b:
        tokens.append(token)
        tokentypes.append(1)
    if tokens_b:
        # [SEP].
        tokens.append(sep_id)
        tokentypes.append(1)

    return tokens, tokentypes


MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])


def is_start_piece(piece):
    """Check if the current word piece is the starting piece (BERT)."""
    # When a word has been split into
    # WordPieces, the first token does not have any marker and any subsequent
    # tokens are prefixed with ##. So whenever we see the ## token, we
    # append it to the previous set of word indexes.
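    # e.g. "unaffable" may be tokenized as ["un", "##aff", "##able"];
    # only "un" is a start piece in that case.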
    return not piece.startswith("##")


def create_masked_lm_predictions(tokens,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""

    cand_indexes = []
    # Note(mingdachen): We create a list for recording if the piece is
    # the starting piece of the current token, where 1 means true, so that
    # on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)

    for (i, token) in enumerate(tokens):
        if token == cls_id or token == sep_id:
            token_boundary[i] = 1
            continue
        # Whole Word Masking means that we mask all of the wordpieces
        # corresponding to an original word.
        #
        # Note that Whole Word Masking does *not* change the training code
        # at all -- we still predict each WordPiece independently, softmaxed
        # over the entire vocabulary.
        if (do_whole_word_mask and len(cand_indexes) >= 1 and
                not is_start_piece(vocab_id_to_token_dict[token])):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
            if is_start_piece(vocab_id_to_token_dict[token]):
                token_boundary[i] = 1

    output_tokens = list(tokens)

    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))
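    # For example, with masked_lm_prob = 0.15 and a 512-token sequence,
    # at most min(max_predictions_per_seq, 77) positions are masked.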

    # Note(mingdachen):
    # By default, we set the probabilities to favor shorter ngram sequences.
    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    pvals = 1. / np.arange(1, max_ngrams + 1)
    pvals /= pvals.sum(keepdims=True)
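    # With max_ngrams = 3 this gives pvals proportional to [1, 1/2, 1/3],
    # i.e. roughly [0.545, 0.273, 0.182] after normalization.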

    if favor_longer_ngram:
        pvals = pvals[::-1]

    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx:idx + n])
        ngram_indexes.append(ngram_index)
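    # ngram_indexes[idx] holds, for each n in 1..max_ngrams, the run of n
    # consecutive whole-word groups starting at cand_indexes[idx].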

    np_rng.shuffle(ngram_indexes)

    masked_lms = []
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Note(mingdachen):
        # Skip the current piece if it is covered by LM masking or previous ngrams.
        for index_set in cand_index_set[0]:
            for index in index_set:
                if index in covered_indexes:
                    continue

        n = np_rng.choice(ngrams[:len(cand_index_set)],
                          p=pvals[:len(cand_index_set)] /
                          pvals[:len(cand_index_set)].sum(keepdims=True))
        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
        # Repeatedly look for a candidate that does not exceed the
        # maximum number of predictions by trying shorter ngrams.
        while len(masked_lms) + len(index_set) > num_to_predict:
            if n == 0:
                break
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)

            masked_token = None
            # 80% of the time, replace with [MASK]
            if np_rng.random() < 0.8:
                masked_token = mask_id
            else:
                # 10% of the time, keep original
                if np_rng.random() < 0.5:
                    masked_token = tokens[index]
                # 10% of the time, replace with random word
                else:
                    masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]

            output_tokens[index] = masked_token

            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
    assert len(masked_lms) <= num_to_predict

    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Note(mingdachen):
            # Skip the current piece if it is covered by LM masking or previous ngrams.
            for index_set in cand_index_set[0]:
                for index in index_set:
                    if index in covered_indexes or index in select_indexes:
                        continue

            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

            while len(select_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes or index in select_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)

    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""
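    # Returns five int64 arrays of length max_seq_length; labels hold -1 at
    # unmasked positions and loss_mask is 1 only at the masked positions.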

    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Labels and loss mask.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for i in range(len(masked_positions)):
        assert masked_positions[i] < num_tokens
        labels[masked_positions[i]] = masked_labels[i]
        loss_mask[masked_positions[i]] = 1
    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np


def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length, masked_lm_prob,
                                    short_seq_prob, seed, skip_warmup,
                                    binary_head,
                                    dataset_type='standard_bert'):
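    """Build train, valid, and test datasets.

    With a single data prefix, one dataset per split is built directly;
    with multiple (weight, prefix) pairs, the per-prefix datasets are
    combined with BlendableDataset according to the normalized weights.
    """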

    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                data_impl, splits_string,
                                                train_valid_test_num_samples,
                                                max_seq_length, masked_lm_prob,
                                                short_seq_prob, seed,
                                                skip_warmup,
                                                binary_head,
                                                dataset_type=dataset_type)
    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i], data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length, masked_lm_prob, short_seq_prob,
            seed, skip_warmup, binary_head, dataset_type=dataset_type)
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)


def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup,
                                     binary_head,
                                     dataset_type='standard_bert'):
    logger = get_dist_logger()

    if dataset_type not in DSET_TYPES:
        raise ValueError("Invalid dataset_type: ", dataset_type)

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)

    if dataset_type == DSET_TYPE_ICT:
        args = get_args()
        title_dataset = get_indexed_dataset_(args.titles_data_path,
                                             data_impl,
                                             skip_warmup)

    # Get start and end indices of train/valid/test into doc-idx
    # Note that doc-idx is designed to be num-docs + 1 so we can
    # easily iterate over it.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    logger.info('\n > dataset split:')

    def print_split_stats(name, index):
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        logger.info('\n    {}:'.format(name) +
                    '\n     document indices in [{}, {}) total of {} documents'.format(
                        splits[index],
                        splits[index + 1],
                        splits[index + 1] - splits[index]) +
                    '\n     sentence indices in [{}, {}) total of {} sentences'.format(
                        start_index,
                        end_index,
                        end_index - start_index),
                    ranks=[0])
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        from .bert_dataset import BertDataset
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
            doc_idx_ptr = indexed_dataset.get_doc_idx()
            # Slice the doc-idx
            start_index = splits[index]
            # Add +1 so we can index into the dataset to get the upper bound.
            end_index = splits[index + 1] + 1
            # New doc_idx view.
            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
            # Build the dataset accordingly.
            kwargs = dict(
                name=name,
                data_prefix=data_prefix,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
                seed=seed,
                binary_head=binary_head
            )

            if dataset_type == DSET_TYPE_ICT:
                args = get_args()
                dataset = ICTDataset(
                    block_dataset=indexed_dataset,
                    title_dataset=title_dataset,
                    query_in_block_prob=args.query_in_block_prob,
                    use_one_sent_docs=args.use_one_sent_docs,
                    **kwargs
                )
            else:
                dataset = BertDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    **kwargs
                )

            # Set the original pointer so dataset remains the main dataset.
            indexed_dataset.set_doc_idx(doc_idx_ptr)
            # Checks.
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == \
                (total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)


def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    logger = get_dist_logger()
    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix,
                                           data_impl,
                                           skip_warmup)
    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
    logger.info('\n > building dataset index ...', ranks=[0])
    logger.info('\n > finished creating indexed dataset in {:4f} '
                'seconds'.format(time.time() - start_time), ranks=[0])
    logger.info('\n > indexed dataset stats:' +
                '\n    number of documents: {}'.format(
                    indexed_dataset.doc_idx.shape[0] - 1) +
                '\n    number of sentences: {}'.format(
                    indexed_dataset.sizes.shape[0]),
                ranks=[0])

    return indexed_dataset


def get_train_valid_test_split_(splits_string, size):
    """Get dataset splits from comma or '/' separated string list."""
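    # For example, splits_string '949,50,1' with size = 1000 gives
    # [0, 949, 999, 1000]: train covers docs [0, 949), valid [949, 999),
    # and test [999, 1000).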

    splits = []
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +
                            int(round(split * float(size))))
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index