You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/tests/test_trainer/test_pipeline/test_pipeline_schedule.py

97 lines
2.8 KiB

# referenced from Megatron and used to testify communication
import os
import os.path as osp
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.initialize import launch
from colossalai.utils import free_port, get_dataloader, print_rank_0
from colossalai.testing import rerun_on_exception
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
BATCH_SIZE = 8
CONFIG=dict(
NUM_MICRO_BATCHES=2,
parallel = dict(
pipeline=dict(size=2),
tensor=dict(size=1, mode=None)
)
)
def run_schedule(rank, world_size, port):
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# build model
model = resnet18(num_classes=10)
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2)
elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:
class Flatten(nn.Module):
def forward(self, x):
return torch.flatten(x, 1)
model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc)
print_rank_0('model is created')
train_dataset = CIFAR10(root=Path(os.environ['DATA']),
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
]))
train_dataloader = get_dataloader(
dataset=train_dataset,
shuffle=True,
add_sampler=True,
batch_size=BATCH_SIZE,
pin_memory=True,
)
# build criterion
criterion = torch.nn.CrossEntropyLoss()
# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
# initialize
engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)
# build pipeline schedule
schedule = engine.schedule
# run schedule
data_iter = iter(train_dataloader)
schedule.forward_backward_step(engine, data_iter)
gpc.destroy()
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_pipeline_schedule():
world_size = 2
run_func = partial(run_schedule, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_pipeline_schedule()