import torch
from torch.utils.data import Dataset

from colossalai.accelerator import get_accelerator


class RandomDataset(Dataset):
    """Map-style dataset of pre-generated random token ids.

    All samples are drawn once in ``__init__`` with ``torch.randint`` and kept
    as a single ``(num_samples, max_length)`` tensor on the device reported by
    ``get_accelerator().get_current_device()``. The attention mask is all
    ones, so every position of every sample is marked as attended.

    Args:
        num_samples: Number of rows (samples) to generate.
        max_length: Number of token ids per sample.
        vocab_size: Exclusive upper bound for the generated token ids; values
            are drawn from ``[0, vocab_size)``.
    """

    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
        self.num_samples = num_samples
        self.max_length = max_length

        # Allocate directly on the current accelerator device so that
        # indexing later yields tensors that already live there.
        target_device = get_accelerator().get_current_device()
        shape = (num_samples, max_length)
        self.input_ids = torch.randint(0, vocab_size, shape, device=target_device)

        # Same shape, dtype and device as ``input_ids``; filled with ones.
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        """Return the number of samples requested at construction time."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict of tensors.

        The dict has the keys ``"input_ids"``, ``"attention_mask"`` and
        ``"labels"``; ``"labels"`` is taken from the same row of
        ``input_ids`` as ``"input_ids"`` itself.
        """
        sample = {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
        }
        # Labels are read from ``input_ids`` again rather than stored separately.
        sample["labels"] = self.input_ids[idx]
        return sample