# referenced from Megatron and used to test communication

import os
import os.path as osp
from functools import partial
from pathlib import Path

import colossalai
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.initialize import launch
from colossalai.utils import free_port, get_dataloader, print_rank_0
from colossalai.testing import rerun_on_exception
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

BATCH_SIZE = 8

# 2-stage pipeline parallelism, no tensor parallelism, 2 micro-batches per step
CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=dict(size=2), tensor=dict(size=1, mode=None)))


def run_schedule(rank, world_size, port):
    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build model and split it into two pipeline stages by local pipeline rank
    model = resnet18(num_classes=10)

    if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
        model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2)
    elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:

        class Flatten(nn.Module):

            def forward(self, x):
                return torch.flatten(x, 1)

        model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc)

    print_rank_0('model is created')

    # build dataset and dataloader
    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
                            download=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
                            ]))

    train_dataloader = get_dataloader(
        dataset=train_dataset,
        shuffle=True,
        add_sampler=True,
        batch_size=BATCH_SIZE,
        pin_memory=True,
    )

    # build criterion
    criterion = torch.nn.CrossEntropyLoss()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)

    # initialize
    engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)

    # build pipeline schedule
    schedule = engine.schedule

    # run schedule for a single step
    data_iter = iter(train_dataloader)
    schedule.forward_backward_step(engine, data_iter)

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_pipeline_schedule():
    world_size = 2
    run_func = partial(run_schedule, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_pipeline_schedule()
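
# Example invocation (a minimal sketch; the dataset path and file path below are
# placeholders, not from the original test). The test spawns world_size=2 processes
# over NCCL, so it assumes at least 2 CUDA devices, and the CIFAR10 root must be
# exported via the DATA environment variable:
#
#   DATA=/path/to/cifar10 python <path to this file>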