diff --git a/tests/test_data/test_deterministic_dataloader.py b/tests/test_data/test_deterministic_dataloader.py
deleted file mode 100644
index 283b5cc35..000000000
--- a/tests/test_data/test_deterministic_dataloader.py
+++ /dev/null
@@ -1,73 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import os
-from pathlib import Path
-
-import pytest
-import torch
-import torch.distributed as dist
-from torchvision import datasets, transforms
-
-import colossalai
-from colossalai.context import Config, ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_dataloader
-
-CONFIG = Config(
-    dict(
-        train_data=dict(
-            dataset=dict(
-                type='CIFAR10',
-                root=Path(os.environ['DATA']),
-                train=True,
-                download=True,
-            ),
-            dataloader=dict(num_workers=2, batch_size=2, shuffle=True),
-        ),
-        parallel=dict(
-            pipeline=dict(size=1),
-            tensor=dict(size=1, mode=None),
-        ),
-        seed=1024,
-    ))
-
-
-def run_data_sampler(rank, world_size, port):
-    dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
-    colossalai.launch(**dist_args)
-
-    # build dataset
-    transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)]
-    transform_pipeline = transforms.Compose(transform_pipeline)
-    dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline)
-
-    # build dataloader
-    dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False)
-
-    data_iter = iter(dataloader)
-    img, label = data_iter.next()
-    img = img[0]
-
-    if gpc.get_local_rank(ParallelMode.DATA) != 0:
-        img_to_compare = img.clone()
-    else:
-        img_to_compare = img
-    dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))
-
-    if gpc.get_local_rank(ParallelMode.DATA) != 0:
-        # this is without sampler
-        # this should be false if data parallel sampler to given to the dataloader
-        assert torch.equal(img,
-                           img_to_compare), 'Same image was distributed across ranks and expected it to be the same'
-    torch.cuda.empty_cache()
-
-
-@rerun_if_address_is_in_use()
-def test_data_sampler():
-    spawn(run_data_sampler, 4)
-
-
-if __name__ == '__main__':
-    test_data_sampler()
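
The deleted test launched four gloo processes, built a CIFAR10 dataloader without a distributed sampler, broadcast rank 0's first image, and asserted that every data-parallel rank had drawn the identical batch. For reference, a minimal sketch of that core check using only plain PyTorch primitives follows; the toy dataset, port, two-process spawn, and seed value are assumptions for illustration, not part of the removed file or of ColossalAI's API.

# Minimal sketch (assumed names/values): without a DistributedSampler and with a
# shared shuffle seed, every rank should draw the same batch as rank 0.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset


def check_same_batch(rank, world_size):
    # gloo backend and a fixed local port are illustrative choices only
    dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29501', rank=rank, world_size=world_size)

    # deterministic toy dataset; the shared generator seed keeps shuffling identical on all ranks
    dataset = TensorDataset(torch.arange(64, dtype=torch.float32).reshape(16, 4))
    generator = torch.Generator().manual_seed(1024)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, generator=generator)

    (batch,) = next(iter(dataloader))
    reference = batch.clone()
    dist.broadcast(reference, src=0)

    # with no distributed sampler, every rank should see exactly rank 0's batch
    assert torch.equal(batch, reference), 'expected identical batches on all ranks'
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(check_same_batch, args=(2,), nprocs=2)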