mirror of https://github.com/hpcaitech/ColossalAI
import torch
from torch.utils.data import Dataset

from colossalai.accelerator import get_accelerator


class RandomDataset(Dataset):
    """Synthetic dataset of random token ids, useful for benchmarking language-model training."""

    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
        self.num_samples = num_samples
        self.max_length = max_length
        # Pre-generate every sample up front, directly on the current accelerator device,
        # so no host-to-device copies are needed during iteration.
        self.input_ids = torch.randint(
            0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
        )
        # All positions are real tokens, so the attention mask is all ones.
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # For causal language modeling, the labels are the input ids themselves;
        # the model shifts them internally to predict the next token.
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }
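# --- Example usage (a minimal sketch, not part of the original file) ---
# Shows how RandomDataset plugs into a standard torch DataLoader; the small
# num_samples/max_length values and batch_size here are illustrative only.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = RandomDataset(num_samples=8, max_length=128, vocab_size=32000)
    # Tensors already live on the accelerator device, so no .to(device) is needed.
    loader = DataLoader(dataset, batch_size=4)
    batch = next(iter(loader))
    print(batch["input_ids"].shape)  # torch.Size([4, 128])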