You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_data/test_cifar10_dataset.py

55 lines
1.3 KiB

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from pathlib import Path
import pytest
from torchvision import transforms
from torch.utils.data import DataLoader
from colossalai.builder import build_dataset, build_transform
from colossalai.context import Config
# Configuration for the CIFAR10 dataset test below. It is split into three
# sections: the dataset spec, the DataLoader keyword arguments, and the
# ordered list of transform specs.
TRAIN_DATA = {
    # The data root is read from the DATA environment variable at import
    # time; a missing variable raises KeyError before any test runs.
    'dataset': {
        'type': 'CIFAR10',
        'root': Path(os.environ['DATA']),
        'train': True,
        'download': True,
    },
    # Unpacked as keyword arguments into torch.utils.data.DataLoader.
    'dataloader': {
        'batch_size': 4,
        'shuffle': True,
        'num_workers': 2,
    },
    # Each entry is built into a transform; the results are composed in order.
    'transform_pipeline': [
        {'type': 'ToTensor'},
        {
            'type': 'Normalize',
            'mean': (0.5, 0.5, 0.5),
            'std': (0.5, 0.5, 0.5),
        },
    ],
}
@pytest.mark.cpu
def test_cifar10_dataset():
    """Smoke-test building a CIFAR10 dataset and reading one batch from it.

    Builds the transform pipeline and the dataset from ``TRAIN_DATA`` using
    the colossalai builders, wraps the dataset in a ``DataLoader`` and
    fetches a single batch. The test passes if none of these steps raises.
    """
    config = Config(TRAIN_DATA)
    dataset_cfg = config.dataset
    dataloader_cfg = config.dataloader
    transform_cfg = config.transform_pipeline

    # build transform
    transform_pipeline = transforms.Compose(
        [build_transform(cfg) for cfg in transform_cfg]
    )
    # The composed transform is injected into the dataset config so that
    # build_dataset receives it together with the other dataset arguments.
    dataset_cfg['transform'] = transform_pipeline

    # build dataset
    dataset = build_dataset(dataset_cfg)

    # build dataloader
    dataloader = DataLoader(dataset=dataset, **dataloader_cfg)
    data_iter = iter(dataloader)

    # Use the built-in next(): the Python 2 style ``.next()`` method is not
    # part of the Python 3 iterator protocol and is not guaranteed to exist
    # on DataLoader iterators. Unpacking also checks that a batch is an
    # (image, label) pair.
    img, label = next(data_iter)
if __name__ == '__main__':
    # Allow running this test directly as a script, without pytest.
    test_cifar10_dataset()