mirror of https://github.com/hpcaitech/ColossalAI
Pengtai Xu
1 year ago
1 changed file with 0 additions and 73 deletions
@ -1,73 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
from pathlib import Path

import pytest
import torch
import torch.distributed as dist
from torchvision import datasets, transforms

import colossalai
from colossalai.context import Config, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_dataloader

CONFIG = Config(
    dict(
        train_data=dict(
            dataset=dict(
                type='CIFAR10',
                root=Path(os.environ['DATA']),
                train=True,
                download=True,
            ),
            dataloader=dict(num_workers=2, batch_size=2, shuffle=True),
        ),
        parallel=dict(
            pipeline=dict(size=1),
            tensor=dict(size=1, mode=None),
        ),
        seed=1024,
    ))


def run_data_sampler(rank, world_size, port):
    dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
    colossalai.launch(**dist_args)

    # build dataset
    transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)]
    transform_pipeline = transforms.Compose(transform_pipeline)
    dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline)

    # build dataloader without a data-parallel sampler
    dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False)

    data_iter = iter(dataloader)
    img, label = next(data_iter)
    img = img[0]

    # broadcast rank 0's first image so every rank can compare against it
    if gpc.get_local_rank(ParallelMode.DATA) != 0:
        img_to_compare = img.clone()
    else:
        img_to_compare = img
    dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))

    if gpc.get_local_rank(ParallelMode.DATA) != 0:
        # the dataloader has no sampler, so every rank should load the same data;
        # torch.equal would be False if a data-parallel sampler were given to the dataloader
        assert torch.equal(img, img_to_compare), \
            'Expected the same image to be distributed across ranks since no sampler is used'
    torch.cuda.empty_cache()


@rerun_if_address_is_in_use()
def test_data_sampler():
    spawn(run_data_sampler, 4)


if __name__ == '__main__':
    test_data_sampler()
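
For contrast with the inline comment in the deleted test, below is a minimal, hedged sketch of the opposite scenario: building the dataloader with add_sampler=True, so each data-parallel rank receives its own shard and rank 0's first image should generally differ from the others. This is an illustration only, not the repository's own sampler test; it assumes the same imports and colossalai.launch context as the file above, and the function name run_with_sampler_sketch is hypothetical.

# Hedged sketch (not from the original file): same comparison, but with the
# data-parallel sampler attached, so ranks are expected to see different images.
def run_with_sampler_sketch():
    transform_pipeline = transforms.Compose([transforms.ToTensor()])
    dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True,
                               transform=transform_pipeline)

    # add_sampler=True lets get_dataloader attach a data-parallel sampler
    dataloader = get_dataloader(dataset, batch_size=8, add_sampler=True)
    img, _ = next(iter(dataloader))
    img = img[0]

    # broadcast rank 0's first image and compare on the other ranks
    img_to_compare = img.clone()
    dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))

    if gpc.get_local_rank(ParallelMode.DATA) != 0:
        # with the sampler in place, non-zero ranks should not see rank 0's image
        assert not torch.equal(img, img_to_compare), \
            'Expected different images across ranks when a data-parallel sampler is used'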