import os import random import time from concurrent.futures import ProcessPoolExecutor import h5py import numpy as np import torch import torch.distributed as dist from bert_dataset_provider import BertDatasetProviderInterface from torch.utils.data import DataLoader, Dataset from torch.utils.data.distributed import DistributedSampler # Workaround because python functions are not picklable class WorkerInitObj(object): def __init__(self, seed): self.seed = seed def __call__(self, id): np.random.seed(seed=self.seed + id) random.seed(self.seed + id) def create_pretraining_dataset( input_file, max_predictions_per_seq, num_workers, train_batch_size, worker_init, data_sampler ): train_data = pretraining_dataset(input_file=input_file, max_predictions_per_seq=max_predictions_per_seq) train_dataloader = DataLoader( train_data, sampler=data_sampler(train_data), batch_size=train_batch_size, num_workers=num_workers, worker_init_fn=worker_init, pin_memory=True, ) return train_dataloader, len(train_data) class pretraining_dataset(Dataset): def __init__(self, input_file, max_predictions_per_seq): self.input_file = input_file self.max_predictions_per_seq = max_predictions_per_seq f = h5py.File(input_file, "r") keys = ["input_ids", "input_mask", "segment_ids", "masked_lm_positions"] self.inputs = [np.asarray(f[key][:]) for key in keys] f.close() def __len__(self): "Denotes the total number of samples" return len(self.inputs[0]) def __getitem__(self, index): [input_ids, input_mask, segment_ids, masked_lm_labels] = [ torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else torch.from_numpy(np.asarray(input[index].astype(np.int64))) for indice, input in enumerate(self.inputs) ] return [input_ids, input_mask, segment_ids, masked_lm_labels] class NvidiaBertDatasetProvider(BertDatasetProviderInterface): def __init__(self, args, evaluate=False): self.num_workers = args.num_workers self.max_seq_length = args.max_seq_length self.max_predictions_per_seq = args.max_predictions_per_seq self.gradient_accumulation_steps = args.gradient_accumulation_steps if not evaluate: self.train_micro_batch_size_per_gpu = args.train_micro_batch_size_per_gpu else: self.train_micro_batch_size_per_gpu = args.eval_micro_batch_size_per_gpu self.logger = args.logger self.global_rank = dist.get_rank() self.world_size = dist.get_world_size() # Initialize dataset files if not evaluate: self.dataset_files = [ os.path.join(args.data_path_prefix, f) for f in os.listdir(args.data_path_prefix) if os.path.isfile(os.path.join(args.data_path_prefix, f)) and "h5" in f ] else: self.dataset_files = [ os.path.join(args.eval_data_path_prefix, f) for f in os.listdir(args.eval_data_path_prefix) if os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and "h5" in f ] self.dataset_files.sort() # random.shuffle(self.dataset_files) self.num_files = len(self.dataset_files) # self.data_sampler = RandomSampler self.data_sampler = DistributedSampler self.worker_init = WorkerInitObj(args.seed + args.local_rank) self.dataset_future = None self.pool = ProcessPoolExecutor(1) self.data_file = None self.shuffle = True if self.global_rank == 0: self.logger.info(f"NvidiaBertDatasetProvider - Initialization: num_files = {self.num_files}") def get_shard(self, index): start = time.time() if self.dataset_future is None: self.data_file = self._get_shard_file(index) self.train_dataloader, sample_count = create_pretraining_dataset( input_file=self.data_file, max_predictions_per_seq=self.max_predictions_per_seq, num_workers=self.num_workers, train_batch_size=self.train_micro_batch_size_per_gpu, worker_init=self.worker_init, data_sampler=self.data_sampler, ) else: self.train_dataloader, sample_count = self.dataset_future.result(timeout=None) self.logger.info( f"Data Loading Completed for Pretraining Data from {self.data_file} with {sample_count} samples took {time.time()-start:.2f}s." ) return self.train_dataloader, sample_count def release_shard(self): del self.train_dataloader self.pool.shutdown() def prefetch_shard(self, index): self.data_file = self._get_shard_file(index) self.dataset_future = self.pool.submit( create_pretraining_dataset, self.data_file, self.max_predictions_per_seq, self.num_workers, self.train_micro_batch_size_per_gpu, self.worker_init, self.data_sampler, ) def get_batch(self, batch_iter): return batch_iter def prefetch_batch(self): pass def _get_shard_file(self, shard_index): file_index = self._get_shard_file_index(shard_index, self.global_rank) return self.dataset_files[file_index] def _get_shard_file_index(self, shard_index, global_rank): # if dist.is_initialized() and self.world_size > self.num_files: # remainder = self.world_size % self.num_files # file_index = (shard_index * self.world_size) + global_rank + ( # remainder * shard_index) # else: # file_index = shard_index * self.world_size + global_rank return shard_index % self.num_files def shuffle_dataset(self, epoch): if self.shuffle: # deterministically shuffle based on epoch and seed g = torch.Generator() g.manual_seed(self.epoch) indices = torch.randperm(self.num_files, generator=g).tolist() new_dataset = [self.dataset_files[i] for i in indices] self.dataset_files = new_dataset