mirror of https://github.com/hpcaitech/ColossalAI
38 lines
1.1 KiB
Python
38 lines
1.1 KiB
Python
import os.path as osp
|
|
|
|
import pytest
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
from colossalai.builder import build_dataset, ModelInitializer
|
|
from colossalai.core import global_context
|
|
from colossalai.initialize import init_dist
|
|
from colossalai.logging import get_global_dist_logger
|
|
|
|
DIR_PATH = osp.dirname(osp.realpath(__file__))
|
|
CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
|
|
|
|
|
|
@pytest.mark.skip("This test should be invoked using the test.sh provided")
|
|
@pytest.mark.dist
|
|
def test_partition():
|
|
init_dist(CONFIG_PATH)
|
|
logger = get_global_dist_logger()
|
|
logger.info('finished initialization')
|
|
|
|
# build model
|
|
model = ModelInitializer(global_context.config.model, 1, verbose=True).model_initialize()
|
|
logger.info('model is created')
|
|
|
|
dataset = build_dataset(global_context.config.train_data.dataset)
|
|
dataloader = DataLoader(dataset=dataset, **global_context.config.train_data.dataloader)
|
|
logger.info('train data is created')
|
|
|
|
global_context.destroy()
|
|
torch.cuda.synchronize()
|
|
logger.info('training finished')
|
|
|
|
|
|
if __name__ == '__main__':
|
|
test_partition()
|