diff --git a/examples/tutorial/.gitignore b/examples/tutorial/.gitignore
index 8fce60300..f873b6a4a 100644
--- a/examples/tutorial/.gitignore
+++ b/examples/tutorial/.gitignore
@@ -1 +1 @@
-data/
+./data/
diff --git a/examples/tutorial/sequence_parallel/data/dummy_dataloader.py b/examples/tutorial/sequence_parallel/data/dummy_dataloader.py
new file mode 100644
index 000000000..faa90175c
--- /dev/null
+++ b/examples/tutorial/sequence_parallel/data/dummy_dataloader.py
@@ -0,0 +1,39 @@
+import torch
+
+
+class DummyDataloader():
+
+    def __init__(self, batch_size, vocab_size, seq_length):
+        self.batch_size = batch_size
+        self.vocab_size = vocab_size
+        self.seq_length = seq_length
+        self.step = 0
+
+    def generate(self):
+        tokens = torch.randint(low=0, high=self.vocab_size, size=(
+            self.batch_size,
+            self.seq_length,
+        ))
+        types = torch.randint(low=0, high=3, size=(
+            self.batch_size,
+            self.seq_length,
+        ))
+        sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,))
+        loss_mask = torch.randint(low=0, high=2, size=(
+            self.batch_size,
+            self.seq_length,
+        ))
+        lm_labels = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length))
+        padding_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length))
+        return dict(text=tokens,
+                    types=types,
+                    is_random=sentence_order,
+                    loss_mask=loss_mask,
+                    labels=lm_labels,
+                    padding_mask=padding_mask)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return self.generate()
\ No newline at end of file
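
For reference, a minimal sketch of how this dummy loader could be consumed. The import path and the `batch_size`/`vocab_size`/`seq_length` values below are illustrative assumptions, not taken from the tutorial's training script:

```python
# Hypothetical usage sketch; values are illustrative, not from the tutorial.
from dummy_dataloader import DummyDataloader

loader = DummyDataloader(batch_size=4, vocab_size=30522, seq_length=512)

for step, batch in enumerate(loader):
    # Each batch is a dict of random tensors shaped like BERT-style inputs:
    # text, types, loss_mask, labels, padding_mask -> (batch_size, seq_length);
    # is_random -> (batch_size,).
    assert batch["text"].shape == (4, 512)
    if step == 2:
        # __next__ generates batches forever, so the consumer must break.
        break
```

Because `__next__` never raises `StopIteration`, the loader behaves as an infinite stream; the training loop, not the loader, decides how many steps to run.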