Browse Source

[unittest] refactored unit tests for change in dependency (#838)

pull/842/head^2
Frank Lee 3 years ago committed by GitHub
parent
commit
943982d29a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 21
      tests/test_data/test_cifar10_dataset.py
  2. 28
      tests/test_data/test_data_parallel_sampler.py
  3. 22
      tests/test_data/test_deterministic_dataloader.py

21
tests/test_data/test_cifar10_dataset.py

@ -5,34 +5,21 @@ import os
from pathlib import Path from pathlib import Path
import pytest import pytest
from torchvision import transforms from torchvision import transforms, datasets
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from colossalai.builder import build_dataset, build_transform
from colossalai.context import Config
from torchvision.transforms import ToTensor
TRAIN_DATA = dict(dataset=dict(type='CIFAR10', root=Path(os.environ['DATA']), train=True, download=True),
dataloader=dict(batch_size=4, shuffle=True, num_workers=2))
@pytest.mark.cpu @pytest.mark.cpu
def test_cifar10_dataset(): def test_cifar10_dataset():
config = Config(TRAIN_DATA)
dataset_cfg = config.dataset
dataloader_cfg = config.dataloader
transform_cfg = config.transform_pipeline
# build transform # build transform
transform_pipeline = [ToTensor()] transform_pipeline = [transforms.ToTensor()]
transform_pipeline = transforms.Compose(transform_pipeline) transform_pipeline = transforms.Compose(transform_pipeline)
dataset_cfg['transform'] = transform_pipeline
# build dataset # build dataset
dataset = build_dataset(dataset_cfg) dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline)
# build dataloader # build dataloader
dataloader = DataLoader(dataset=dataset, **dataloader_cfg) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=2)
data_iter = iter(dataloader) data_iter = iter(dataloader)
img, label = data_iter.next() img, label = data_iter.next()

28
tests/test_data/test_data_parallel_sampler.py

@ -9,28 +9,15 @@ import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import colossalai import colossalai
from colossalai.builder import build_dataset from torchvision import transforms, datasets
from torchvision import transforms
from colossalai.context import ParallelMode, Config from colossalai.context import ParallelMode, Config
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.utils import get_dataloader, free_port from colossalai.utils import get_dataloader, free_port
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from torchvision.transforms import ToTensor
CONFIG = Config( CONFIG = Config(dict(
dict(
train_data=dict(
dataset=dict(
type='CIFAR10',
root=Path(os.environ['DATA']),
train=True,
download=True,
),
dataloader=dict(batch_size=8,),
),
parallel=dict( parallel=dict(
pipeline=dict(size=1), pipeline=dict(size=1),
tensor=dict(size=1, mode=None), tensor=dict(size=1, mode=None),
@ -44,11 +31,14 @@ def run_data_sampler(rank, world_size, port):
colossalai.launch(**dist_args) colossalai.launch(**dist_args)
print('finished initialization') print('finished initialization')
transform_pipeline = [ToTensor()] # build dataset
transform_pipeline = [transforms.ToTensor()]
transform_pipeline = transforms.Compose(transform_pipeline) transform_pipeline = transforms.Compose(transform_pipeline)
gpc.config.train_data.dataset['transform'] = transform_pipeline dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline)
dataset = build_dataset(gpc.config.train_data.dataset)
dataloader = get_dataloader(dataset, **gpc.config.train_data.dataloader) # build dataloader
dataloader = get_dataloader(dataset, batch_size=8, add_sampler=True)
data_iter = iter(dataloader) data_iter = iter(dataloader)
img, label = data_iter.next() img, label = data_iter.next()
img = img[0] img = img[0]

22
tests/test_data/test_deterministic_dataloader.py

@ -9,14 +9,12 @@ import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torchvision import transforms from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import colossalai import colossalai
from colossalai.builder import build_dataset
from colossalai.context import ParallelMode, Config from colossalai.context import ParallelMode, Config
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.utils import free_port from colossalai.utils import get_dataloader, free_port
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from torchvision import transforms from torchvision import transforms
@ -43,20 +41,13 @@ def run_data_sampler(rank, world_size, port):
dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost') dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
colossalai.launch(**dist_args) colossalai.launch(**dist_args)
dataset_cfg = gpc.config.train_data.dataset
dataloader_cfg = gpc.config.train_data.dataloader
transform_cfg = gpc.config.train_data.transform_pipeline
# build transform
transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32)]
transform_pipeline = transforms.Compose(transform_pipeline)
dataset_cfg['transform'] = transform_pipeline
# build dataset # build dataset
dataset = build_dataset(dataset_cfg) transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)]
transform_pipeline = transforms.Compose(transform_pipeline)
dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline)
# build dataloader # build dataloader
dataloader = DataLoader(dataset=dataset, **dataloader_cfg) dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False)
data_iter = iter(dataloader) data_iter = iter(dataloader)
img, label = data_iter.next() img, label = data_iter.next()
@ -76,7 +67,6 @@ def run_data_sampler(rank, world_size, port):
torch.cuda.empty_cache() torch.cuda.empty_cache()
@pytest.mark.skip
@pytest.mark.cpu @pytest.mark.cpu
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_data_sampler(): def test_data_sampler():

Loading…
Cancel
Save