ColossalAI/tests/test_legacy/test_data/test_deterministic_dataload...

76 lines
2.2 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from pathlib import Path
import torch
import torch.distributed as dist
from torchvision import datasets, transforms
import colossalai
from colossalai.context import Config
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils import get_dataloader
from colossalai.testing import rerun_if_address_is_in_use, spawn
# Launch configuration handed to ``colossalai.legacy.launch`` by every spawned
# process. The CIFAR10 root is read from the ``DATA`` environment variable, so
# importing this module raises ``KeyError`` when that variable is not set.
# Pipeline and tensor parallel sizes are both 1, and a fixed seed is configured.
CONFIG = Config(
    {
        "train_data": {
            "dataset": {
                "type": "CIFAR10",
                "root": Path(os.environ["DATA"]),
                "train": True,
                "download": True,
            },
            "dataloader": {"num_workers": 2, "batch_size": 2, "shuffle": True},
        },
        "parallel": {
            "pipeline": {"size": 1},
            "tensor": {"size": 1, "mode": None},
        },
        "seed": 1024,
    }
)
def run_data_sampler(rank, world_size, port):
    """Check that every data-parallel rank loads the same first image.

    Launches the legacy ColossalAI context for this process, builds a CIFAR10
    dataloader *without* a data-parallel sampler, takes the first image of the
    first batch and compares it with the image broadcast from rank 0.

    Args:
        rank: Rank of this process, forwarded to ``colossalai.legacy.launch``.
        world_size: Total number of processes, forwarded to the launcher.
        port: Port forwarded to the launcher (the host is ``localhost``).

    Raises:
        AssertionError: If a non-zero data-parallel rank loaded an image that
            differs from the one received from rank 0.
        KeyError: If the ``DATA`` environment variable is not set.
    """
    colossalai.legacy.launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        backend="gloo",
        port=port,
        host="localhost",
    )

    # build dataset
    transform_pipeline = transforms.Compose(
        [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)]
    )
    dataset = datasets.CIFAR10(
        root=Path(os.environ["DATA"]),
        train=True,
        download=True,
        transform=transform_pipeline,
    )

    # build dataloader
    dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False)

    # Use the builtin next(): Python 3 iterators only guarantee __next__, and
    # the Python 2 style ``.next()`` alias is not available on DataLoader
    # iterators in newer PyTorch releases. The label is not needed here.
    data_iter = iter(dataloader)
    img, _ = next(data_iter)
    img = img[0]

    data_parallel_rank = gpc.get_local_rank(ParallelMode.DATA)

    # Non-zero ranks receive into a copy so their own image stays intact for
    # the comparison below; rank 0 broadcasts its image directly.
    if data_parallel_rank != 0:
        img_to_compare = img.clone()
    else:
        img_to_compare = img
    dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))

    if data_parallel_rank != 0:
        # this is without sampler
        # this should be false if a data parallel sampler is given to the dataloader
        assert torch.equal(
            img, img_to_compare
        ), "Same image was distributed across ranks and expected it to be the same"
    torch.cuda.empty_cache()
@rerun_if_address_is_in_use()
def test_data_sampler():
    """Run ``run_data_sampler`` via ``spawn``, passing 4 as its second argument.

    NOTE(review): the 4 is presumably the number of processes to start;
    confirm against ``colossalai.testing.spawn``.
    """
    num_procs = 4
    spawn(run_data_sampler, num_procs)
# Allow running this test module directly as a script, without pytest.
if __name__ == "__main__":
    test_data_sampler()