mirror of https://github.com/hpcaitech/ColossalAI
Frank Lee committed 2 years ago via GitHub
2 changed files with 40 additions and 1 deletion
@@ -0,0 +1,39 @@
import torch


class DummyDataloader:
    """Infinite iterator yielding random batches shaped like BERT pretraining inputs."""

    def __init__(self, batch_size, vocab_size, seq_length):
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        self.step = 0

    def generate(self):
        # Random token ids in [0, vocab_size).
        tokens = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length))
        # Token type ids in {0, 1, 2}.
        types = torch.randint(low=0, high=3, size=(self.batch_size, self.seq_length))
        # Binary sentence-order label, one per sample.
        sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,))
        # Binary mask selecting which positions contribute to the loss.
        loss_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length))
        # Random language-model labels in [0, vocab_size).
        lm_labels = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length))
        # Binary padding mask.
        padding_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length))
        return dict(text=tokens,
                    types=types,
                    is_random=sentence_order,
                    loss_mask=loss_mask,
                    labels=lm_labels,
                    padding_mask=padding_mask)

    def __iter__(self):
        return self

    def __next__(self):
        return self.generate()
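Since the class implements __iter__ and __next__, it behaves as an infinite iterator: every next() call produces a fresh random batch. A minimal usage sketch follows; the batch size, vocab size, and sequence length below are illustrative assumptions, not values from the diff.

# Hypothetical example values for illustration only.
loader = DummyDataloader(batch_size=8, vocab_size=30522, seq_length=128)
batch = next(iter(loader))
print(batch["text"].shape)       # torch.Size([8, 128])
print(batch["is_random"].shape)  # torch.Size([8])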