ColossalAI/examples/tutorial/sequence_parallel/data/dummy_dataloader.py

import torch


class DummyDataloader:
    """Infinite iterator that yields randomly generated BERT-style batches."""

    def __init__(self, batch_size, vocab_size, seq_length):
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        self.step = 0

    def generate(self):
        # Random token ids drawn uniformly from the vocabulary.
        tokens = torch.randint(
            low=0,
            high=self.vocab_size,
            size=(
                self.batch_size,
                self.seq_length,
            ),
        )
        # Token type (segment) ids in {0, 1, 2}.
        types = torch.randint(
            low=0,
            high=3,
            size=(
                self.batch_size,
                self.seq_length,
            ),
        )
        # Binary sentence-order labels, one per sample.
        sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,))
        # Binary mask selecting the positions that contribute to the LM loss.
        loss_mask = torch.randint(
            low=0,
            high=2,
            size=(
                self.batch_size,
                self.seq_length,
            ),
        )
        # Random masked-LM labels and a binary padding mask.
        lm_labels = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length))
        padding_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length))
        return dict(
            text=tokens,
            types=types,
            is_random=sentence_order,
            loss_mask=loss_mask,
            labels=lm_labels,
            padding_mask=padding_mask,
        )

    def __iter__(self):
        return self

    def __next__(self):
        return self.generate()
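
A minimal usage sketch: the class acts as an endless iterator, so `next()` always returns a fresh random batch. The import path assumes you run from the `sequence_parallel` example directory, and the `batch_size`, `vocab_size`, and `seq_length` values below are illustrative, not taken from the tutorial config.

from data.dummy_dataloader import DummyDataloader

# Illustrative settings: standard BERT vocab size and a 512-token sequence length.
train_iter = DummyDataloader(batch_size=32, vocab_size=30522, seq_length=512)

batch = next(train_iter)
print(batch["text"].shape)          # torch.Size([32, 512])
print(batch["padding_mask"].shape)  # torch.Size([32, 512])
print(batch["is_random"].shape)     # torch.Size([32])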