#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import numpy as np
from torch.utils.data import Dataset


class RandomDataset(Dataset):
    """
    RandomDataset for generating a random synthetic dataset of token sequences.

    Args:
        num_samples (int): The number of samples to generate.
        max_len (int): The maximum length of each sample.
    """

    def __init__(self, num_samples=10000, max_len=1024) -> None:
        super().__init__()
        # Fixed seed so the generated samples are reproducible across runs.
        rng = np.random.RandomState(1999)
        # Per sample: the size of the token range and how often it is repeated.
        max_num = rng.randint(1, 30, size=(num_samples,))
        rep_num = rng.randint(10, 200, size=(num_samples,))
        data = []
        lengths = []
        for n, r in zip(max_num, rep_num):
            # Each sample is [n, r, 0, 1, ..., n-1, 0, 1, ...] truncated to max_len.
            d = list(range(n)) * r
            d = [n, r] + d
            d = d[:max_len]
            data.append(d)
            lengths.append(len(d))
        self.data = data
        self.max_len = max_len
        self.lengths = np.array(lengths, dtype=int)

    def __getitem__(self, index):
        d = self.data[index]
        input_ids = np.array(d, dtype=int)
        # Samples have variable length; batching/padding is left to the caller.
        return {"tokens": list(input_ids), "type_id": 0}

    def get_dataset_name(self):
        return "dummy_path/dummy_lang/dummy_ds/train.bin"

    def __len__(self):
        return len(self.data)
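

if __name__ == "__main__":
    # Usage sketch (an illustrative addition, not part of the upstream file):
    # wrap RandomDataset in a standard torch DataLoader with a simple
    # right-padding collate function. The pad id of 0, the batch size, and
    # pad_collate itself are assumptions made only for this demo.
    import torch
    from torch.utils.data import DataLoader

    def pad_collate(batch, pad_id=0):
        # Right-pad every sample's "tokens" to the longest sequence in the batch.
        longest = max(len(sample["tokens"]) for sample in batch)
        tokens = torch.full((len(batch), longest), pad_id, dtype=torch.long)
        for i, sample in enumerate(batch):
            seq = torch.tensor(sample["tokens"], dtype=torch.long)
            tokens[i, : seq.numel()] = seq
        type_ids = torch.tensor([sample["type_id"] for sample in batch], dtype=torch.long)
        return {"tokens": tokens, "type_id": type_ids}

    dataset = RandomDataset(num_samples=100, max_len=128)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=pad_collate)
    batch = next(iter(loader))
    # Prints e.g. torch.Size([4, 128]); the second dim is the longest sample in the batch.
    print(batch["tokens"].shape)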