@@ -9,14 +9,12 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from torchvision import transforms
-from torch.utils.data import DataLoader
+from torchvision import transforms, datasets
 
 import colossalai
-from colossalai.builder import build_dataset
 from colossalai.context import ParallelMode, Config
 from colossalai.core import global_context as gpc
-from colossalai.utils import free_port
+from colossalai.utils import get_dataloader, free_port
+from colossalai.testing import rerun_if_address_is_in_use
 from torchvision import transforms
 
@@ -43,20 +41,13 @@ def run_data_sampler(rank, world_size, port):
     dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
     colossalai.launch(**dist_args)
 
-    dataset_cfg = gpc.config.train_data.dataset
-    dataloader_cfg = gpc.config.train_data.dataloader
-    transform_cfg = gpc.config.train_data.transform_pipeline
-
-    # build transform
-    transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32)]
-    transform_pipeline = transforms.Compose(transform_pipeline)
-    dataset_cfg['transform'] = transform_pipeline
-
     # build dataset
-    dataset = build_dataset(dataset_cfg)
+    transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)]
+    transform_pipeline = transforms.Compose(transform_pipeline)
+    dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline)
 
     # build dataloader
-    dataloader = DataLoader(dataset=dataset, **dataloader_cfg)
+    dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False)
 
     data_iter = iter(dataloader)
     img, label = data_iter.next()
@@ -76,7 +67,6 @@ def run_data_sampler(rank, world_size, port):
     torch.cuda.empty_cache()
 
 
-@pytest.mark.skip
 @pytest.mark.cpu
+@rerun_if_address_is_in_use()
 def test_data_sampler():
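
For reference, the dataloading path this diff switches to can be sketched as a small standalone snippet. It is a sketch only, not part of the patch: it assumes torchvision is installed, the DATA environment variable points at a CIFAR-10 root directory, and colossalai.utils.get_dataloader behaves as used above (with add_sampler=False it builds a plain, non-distributed DataLoader).

import os
from pathlib import Path

from torchvision import transforms, datasets
from colossalai.utils import get_dataloader

# Same transform pipeline as the updated test.
transform_pipeline = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop(size=32, padding=4),
])

# CIFAR-10 is built directly from torchvision instead of via colossalai.builder.
dataset = datasets.CIFAR10(root=Path(os.environ['DATA']),
                           train=True,
                           download=True,
                           transform=transform_pipeline)

# add_sampler=False keeps the plain (non-distributed) sampling path, as in the test.
dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False)

img, label = next(iter(dataloader))
print(img.shape, label.shape)  # e.g. torch.Size([8, 3, 32, 32]) torch.Size([8])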