from abc import abstractmethod

from colossal_eval.utils import jdump
from torch.utils.data import Dataset

from colossalai.logging import DistributedLogger


class BaseDataset:
    """
    Base class for dataset wrappers.

    Args:
        path: The path to the original dataset.
        logger: Logger for the dataset.
    """

    def __init__(self, path, logger, *args, **kwargs):
        # Subclasses supply `load`; the converted dataset is stored on the instance.
        self.dataset = self.load(path, logger, *args, **kwargs)

    def save(self, save_path):
        """Save the converted dataset."""
        jdump(self.dataset, save_path)

    # `abc.abstractstaticmethod` is deprecated; stacking @staticmethod and
    # @abstractmethod is the modern equivalent. Note that without an ABCMeta
    # metaclass this is documentation only and is not enforced at instantiation.
    @staticmethod
    @abstractmethod
    def load(path, logger: DistributedLogger, *args, **kwargs):
        """Load the original dataset and convert it into the inference dataset."""
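

# Illustrative only: a minimal sketch (not part of the original file) of a
# concrete subclass, showing how `load` is meant to be used: read the raw
# data, convert it, and return the converted records, which __init__ then
# stores as `self.dataset`. The JSON layout and the "question"/"answer"
# field names below are hypothetical.
class JsonQADataset(BaseDataset):
    @staticmethod
    def load(path, logger: DistributedLogger, *args, **kwargs):
        import json  # local import to keep the sketch self-contained

        logger.info(f"Loading dataset from {path}")
        with open(path, encoding="utf-8") as f:
            raw = json.load(f)
        # Convert each raw record into an inference-ready schema.
        return [{"input": r["question"], "target": r["answer"]} for r in raw]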


class DistributedDataset(Dataset):
    """Simple map-style dataset over an in-memory list of samples."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
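

# Usage sketch (illustrative, not part of the original file): DistributedDataset
# is a plain map-style dataset, so it works with a standard torch DataLoader;
# in multi-process evaluation it would typically be paired with a
# DistributedSampler. Single-process demo:
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    ds = DistributedDataset([{"input": "q1"}, {"input": "q2"}])
    loader = DataLoader(ds, batch_size=2)
    for batch in loader:
        print(batch)  # default collate yields {'input': ['q1', 'q2']}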