ColossalAI/tests/test_trainer/test_pipeline/test_pipeline_schedule.py


# adapted from Megatron and used to test pipeline communication
import os
from pathlib import Path
import pytest
import torch
import torch.nn as nn

from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_dataloader, print_rank_0

BATCH_SIZE = 8

CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=dict(size=2), tensor=dict(size=1, mode=None)))
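
# NOTE: pipeline=dict(size=2) splits the model into two stages, one per rank, and
# NUM_MICRO_BATCHES=2 makes the schedule cut each global batch of BATCH_SIZE=8
# into two micro-batches of 4 that are pipelined through the stages (BATCH_SIZE
# must therefore stay divisible by NUM_MICRO_BATCHES).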


def run_schedule(rank, world_size, port):
    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build model and split ResNet-18 into two pipeline stages by pipeline-local rank
    model = resnet18(num_classes=10)
    if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
        model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2)
    elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:

        class Flatten(nn.Module):

            def forward(self, x):
                return torch.flatten(x, 1)

        model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc)
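
    # Stage 0 covers the CNN stem through layer2; stage 1 covers layer3 through the
    # classifier head. For 32x32 CIFAR10 inputs, stage 0 hands activations of shape
    # [micro_batch_size, 128, 4, 4] across ranks to stage 1.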
    print_rank_0('model is created')

    # build dataset; the CIFAR10 root is read from the DATA environment variable
    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
                            download=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
                            ]))

    # build dataloader
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      add_sampler=True,
                                      batch_size=BATCH_SIZE,
                                      pin_memory=True)
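
    # add_sampler=True attaches a distributed sampler to the loader; since both
    # ranks here belong to the pipeline group (data-parallel size 1), each stage
    # iterates the same batch stream, the first stage consuming the inputs and
    # the last stage the labels for the loss.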
    # build criterion
    criterion = torch.nn.CrossEntropyLoss()

    # build optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)

    # initialize: wraps model, optimizer, criterion and dataloader into an engine
    # carrying the pipeline schedule configured above
    engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)

    # build pipeline schedule
    schedule = engine.schedule

    # run schedule for one forward-backward pass
    data_iter = iter(train_dataloader)
    schedule.forward_backward_step(engine, data_iter)
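
    # A minimal sketch, assuming the legacy PipelineSchedule signature that
    # returns (output, label, loss), of how the loss could be fetched for
    # assertions (the loss is only materialized on the last pipeline stage):
    #     output, label, loss = schedule.forward_backward_step(
    #         engine, data_iter, forward_only=False, return_loss=True)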

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pipeline_schedule():
    world_size = 2
    spawn(run_schedule, world_size)


if __name__ == '__main__':
    test_pipeline_schedule()
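
# A typical direct invocation (the example paths are hypothetical): point the
# DATA environment variable at a CIFAR10 root and expose two CUDA devices for
# the 2-stage NCCL pipeline, e.g.
#     DATA=./data CUDA_VISIBLE_DEVICES=0,1 python test_pipeline_schedule.py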