[tutorial] added missing dummy dataloader (#1944)

pull/1946/head
Frank Lee 2022-11-14 18:09:03 +08:00 committed by GitHub
parent c6ea65011f
commit de56b563b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 1 deletions

View File

@ -1 +1 @@
data/
./data/

View File

@ -0,0 +1,39 @@
import torch
class DummyDataloader():
def __init__(self, batch_size, vocab_size, seq_length):
self.batch_size = batch_size
self.vocab_size = vocab_size
self.seq_length = seq_length
self.step = 0
def generate(self):
tokens = torch.randint(low=0, high=self.vocab_size, size=(
self.batch_size,
self.seq_length,
))
types = torch.randint(low=0, high=3, size=(
self.batch_size,
self.seq_length,
))
sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,))
loss_mask = torch.randint(low=0, high=2, size=(
self.batch_size,
self.seq_length,
))
lm_labels = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length))
padding_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length))
return dict(text=tokens,
types=types,
is_random=sentence_order,
loss_mask=loss_mask,
labels=lm_labels,
padding_mask=padding_mask)
def __iter__(self):
return self
def __next__(self):
return self.generate()