mirror of https://github.com/hpcaitech/ColossalAI
[tutorial] added missing dummy dataloader (#1944)
parent c6ea65011f
commit de56b563b9
@@ -1 +1 @@
-data/
+./data/
@@ -0,0 +1,39 @@
+import torch
+
+
+class DummyDataloader():
+
+    def __init__(self, batch_size, vocab_size, seq_length):
+        self.batch_size = batch_size
+        self.vocab_size = vocab_size
+        self.seq_length = seq_length
+        self.step = 0
+
+    def generate(self):
+        tokens = torch.randint(low=0, high=self.vocab_size, size=(
+            self.batch_size,
+            self.seq_length,
+        ))
+        types = torch.randint(low=0, high=3, size=(
+            self.batch_size,
+            self.seq_length,
+        ))
+        sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,))
+        loss_mask = torch.randint(low=0, high=2, size=(
+            self.batch_size,
+            self.seq_length,
+        ))
+        lm_labels = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length))
+        padding_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length))
+        return dict(text=tokens,
+                    types=types,
+                    is_random=sentence_order,
+                    loss_mask=loss_mask,
+                    labels=lm_labels,
+                    padding_mask=padding_mask)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return self.generate()
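For context, a minimal usage sketch of the dataloader added above. The module name (dummy_dataloader.py), batch shape parameters, and the inspection code are illustrative assumptions, not part of the commit:

# Assumes the file above was saved as dummy_dataloader.py on the import path.
from dummy_dataloader import DummyDataloader

# Shape parameters are arbitrary; vocab_size and seq_length here loosely
# follow BERT-base pretraining defaults, but any positive integers work
# since every tensor is randomly generated.
dataloader = DummyDataloader(batch_size=8, vocab_size=30522, seq_length=128)

# The class is an infinite iterator: each next() call returns a fresh
# random batch shaped like a BERT-style pretraining sample.
batch = next(iter(dataloader))
print(batch["text"].shape)       # torch.Size([8, 128]) - random token ids
print(batch["is_random"].shape)  # torch.Size([8])      - sentence-order labels

Because __next__ never raises StopIteration, a training loop must bound its own step count rather than exhausting the iterator.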