from typing import Dict, Optional, Tuple

import torch


class DummyDataloader:
    """Infinite iterator that yields dicts of random integer tensors.

    Every batch has the keys below. Values are drawn uniformly with
    ``torch.randint``, so they are int64 tensors:

    - ``text``:         (batch_size, seq_length), values in [0, vocab_size)
    - ``types``:        (batch_size, seq_length), values in [0, 3)
    - ``is_random``:    (batch_size,),            values in {0, 1}
    - ``loss_mask``:    (batch_size, seq_length), values in {0, 1}
    - ``labels``:       (batch_size, seq_length), values in [0, vocab_size)
    - ``padding_mask``: (batch_size, seq_length), values in {0, 1}
    """

    def __init__(
        self,
        batch_size: int,
        vocab_size: int,
        seq_length: int,
        seed: Optional[int] = None,
    ):
        """Configure batch geometry and the random source.

        Args:
            batch_size: Leading dimension of every tensor in a batch.
            vocab_size: Exclusive upper bound for ``text`` and ``labels`` values.
            seq_length: Second dimension of the per-token tensors.
            seed: If given, batches come from a private ``torch.Generator``
                seeded with this value. The sequence of batches is then
                reproducible and does not consume the global RNG. If ``None``
                (the default), the global torch RNG is used.
        """
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        # Number of batches handed out through ``next()``.
        self.step = 0
        self._generator: Optional[torch.Generator] = None
        if seed is not None:
            self._generator = torch.Generator()
            self._generator.manual_seed(seed)

    def _randint(self, high: int, size: Tuple[int, ...]) -> torch.Tensor:
        """Draw integers uniformly from ``[0, high)`` using the configured RNG."""
        # generator=None makes torch fall back to its global default generator.
        return torch.randint(low=0, high=high, size=size, generator=self._generator)

    def generate(self) -> Dict[str, torch.Tensor]:
        """Return one freshly sampled batch (see the class docstring for keys)."""
        per_token = (self.batch_size, self.seq_length)
        # Keep this draw order fixed: changing it changes every batch
        # produced for a given seed or global RNG state.
        tokens = self._randint(self.vocab_size, per_token)
        types = self._randint(3, per_token)
        sentence_order = self._randint(2, (self.batch_size,))
        loss_mask = self._randint(2, per_token)
        lm_labels = self._randint(self.vocab_size, per_token)
        padding_mask = self._randint(2, per_token)
        return dict(
            text=tokens,
            types=types,
            is_random=sentence_order,
            loss_mask=loss_mask,
            labels=lm_labels,
            padding_mask=padding_mask,
        )

    def __iter__(self) -> "DummyDataloader":
        """Return self; the loader is its own (never-ending) iterator."""
        return self

    def __next__(self) -> Dict[str, torch.Tensor]:
        """Generate the next batch and advance ``step``. Never raises StopIteration."""
        batch = self.generate()
        self.step += 1
        return batch