mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
53 lines
1.4 KiB
53 lines
1.4 KiB
import torch
|
|
|
|
|
|
class DummyDataloader:
    """Endless iterator yielding dicts of random integer tensors.

    Each batch produced by :meth:`generate` contains:

    * ``text`` and ``labels``: shape ``(batch_size, seq_length)``, values in ``[0, vocab_size)``.
    * ``types``: shape ``(batch_size, seq_length)``, values in ``[0, 3)``.
    * ``loss_mask`` and ``padding_mask``: shape ``(batch_size, seq_length)``, values in ``{0, 1}``.
    * ``is_random``: shape ``(batch_size,)``, values in ``{0, 1}``.

    All values are uniform random draws from ``torch.randint``, so they carry no learnable
    signal. NOTE(review): the key names look like a BERT/Megatron-style pre-training batch;
    confirm against the consumer.

    ``__next__`` never raises ``StopIteration``, so callers must bound their own loops.
    """

    def __init__(self, batch_size: int, vocab_size: int, seq_length: int):
        """
        Args:
            batch_size: Number of samples per batch (first dimension of every tensor).
            vocab_size: Exclusive upper bound for the token ids in ``text`` and ``labels``.
            seq_length: Number of positions per sample (second dimension of the 2-D tensors).
        """
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        # Number of batches handed out through ``__next__`` so far.
        self.step = 0

    def generate(self) -> dict:
        """Draw and return one random batch.

        Calling this directly does not advance ``self.step``; only ``__next__`` does.

        Returns:
            A dict with keys ``text``, ``types``, ``is_random``, ``loss_mask``, ``labels``
            and ``padding_mask`` mapping to integer tensors (see the class docstring).
        """
        shape = (self.batch_size, self.seq_length)
        # The draw order is fixed on purpose: reordering these calls would change the
        # batches produced under a fixed torch seed.
        tokens = torch.randint(low=0, high=self.vocab_size, size=shape)
        types = torch.randint(low=0, high=3, size=shape)
        sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,))
        loss_mask = torch.randint(low=0, high=2, size=shape)
        lm_labels = torch.randint(low=0, high=self.vocab_size, size=shape)
        padding_mask = torch.randint(low=0, high=2, size=shape)
        return dict(
            text=tokens,
            types=types,
            is_random=sentence_order,
            loss_mask=loss_mask,
            labels=lm_labels,
            padding_mask=padding_mask,
        )

    def __iter__(self):
        """Return ``self``; the loader is its own (endless) iterator."""
        return self

    def __next__(self) -> dict:
        """Generate the next batch and count it in ``self.step``."""
        batch = self.generate()
        self.step += 1
        return batch
|