InternLM/internlm/data/dummy_dataset.py

45 lines
1.2 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import numpy as np
from torch.utils.data import Dataset
class RandomDataset(Dataset):
    """
    RandomDataset for generating a deterministic synthetic token dataset.

    Each sample is built from two random integers drawn per sample,
    ``n`` in ``[1, 30)`` and ``r`` in ``[10, 200)``. The sample is
    ``[n, r]`` followed by ``range(n)`` repeated ``r`` times, truncated to
    at most ``max_len`` tokens. All tokens are plain Python ints.

    Args:
        num_samples (int): The number of samples to generate.
        max_len (int): The maximum length of each sample.
        seed (int): Seed for the random generator. The default (1999)
            reproduces the data generated by earlier versions of this class.
    """

    def __init__(self, num_samples=10000, max_len=1024, seed=1999) -> None:
        super().__init__()
        rng = np.random.RandomState(seed)
        # Draw order (max_num first, then rep_num) is kept so that a given
        # seed always yields the same samples.
        max_num = rng.randint(1, 30, size=(num_samples,))
        rep_num = rng.randint(10, 200, size=(num_samples,))

        data = []
        # .tolist() converts numpy scalars to Python ints, so the header
        # tokens [n, r] have the same type as the range(n) body tokens.
        for n, r in zip(max_num.tolist(), rep_num.tolist()):
            sample = [n, r] + list(range(n)) * r
            data.append(sample[:max_len])

        self.data = data
        self.max_len = max_len
        # Per-sample token counts (after truncation).
        self.lengths = np.array([len(sample) for sample in data], dtype=int)

    def __getitem__(self, index):
        """Return sample ``index`` as ``{"tokens": [...], "type_id": 0}``.

        ``tokens`` is a fresh list of Python ints. Callers may mutate it
        without affecting the stored data.
        """
        return {"tokens": list(self.data[index]), "type_id": 0}

    def get_dataset_name(self):
        """Return the fixed placeholder path used to name this dataset."""
        return "dummy_path/dummy_lang/dummy_ds/train.bin"

    def __len__(self):
        """Return the number of generated samples."""
        return len(self.data)