# ColossalAI/tests/test_trainer/test_pipeline/test_pipeline_schedule.py

# referenced from Megatron and used to test pipeline communication and scheduling
import os
import os.path as osp
from functools import partial
from pathlib import Path

import pytest
import torch
import torch.multiprocessing as mp
from torchvision import transforms
from torchvision.datasets import CIFAR10

import colossalai
from colossalai.builder import PipelineModelInitializer
from colossalai.communication import p2p as p2p_communication
from colossalai.communication.utils import send_tensor_meta, recv_tensor_meta
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.schedule import PipelineSchedule
from colossalai.initialize import launch
from colossalai.utils import print_rank_0, get_current_device, get_dataloader

import model  # local module defining the pipeline model referenced by the config

BATCH_SIZE = 32
NUM_MICRO = 8
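
# Each global batch of BATCH_SIZE samples is split into NUM_MICRO microbatches
# (4 samples each here) that are pipelined through the stages; the split must
# be even. Guard added for illustration (not in the original test):
assert BATCH_SIZE % NUM_MICRO == 0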
DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, 'resnet_config.py')


def run_schedule(rank, world_size):
    launch(config=CONFIG_PATH,
           rank=rank,
           world_size=world_size,
           host='localhost',
           port=29934,
           backend='nccl')
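
    # The distributed environment is now initialized; the parallel topology
    # (e.g. pipeline parallel size) is read from resnet_config.py.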

    # build model
    model = PipelineModelInitializer(gpc.config.model, 1).initialize()
    print_rank_0('model is created')
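    # PipelineModelInitializer partitions the layers described in
    # gpc.config.model, so each rank holds only its own pipeline stage.
    # A quick sanity check (illustrative addition, not in the original test):
    #
    #   n_params = sum(p.numel() for p in model.parameters())
    #   print(f'rank {gpc.get_global_rank()}: {n_params} parameters in this stage')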

    train_dataset = CIFAR10(
        root=Path(os.environ['DATA']),
        download=True,
        transform=transforms.Compose(
            [
                transforms.RandomCrop(size=32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010]),
            ]
        )
    )

    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      add_sampler=True,
                                      batch_size=BATCH_SIZE,
                                      pin_memory=True)
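
    # With add_sampler=True, get_dataloader attaches a distributed sampler so
    # each data-parallel rank reads a distinct shard. A minimal hand-rolled
    # equivalent (illustrative sketch using the standard torch.utils.data API):
    #
    #   from torch.utils.data import DataLoader
    #   from torch.utils.data.distributed import DistributedSampler
    #   sampler = DistributedSampler(train_dataset, shuffle=True)
    #   loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
    #                       sampler=sampler, pin_memory=True)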

    # build criterion
    criterion = torch.nn.CrossEntropyLoss()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)

    # initialize
    engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)

    # build pipeline schedule
    schedule = PipelineSchedule(num_microbatches=NUM_MICRO)

    # run schedule: a single forward/backward pass over NUM_MICRO microbatches
    data_iter = iter(train_dataloader)
    schedule.forward_backward_step(engine, data_iter)
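
    # In a full training loop, gradients would be zeroed before the schedule
    # step and the optimizer applied afterwards. A minimal sketch, assuming
    # the Engine API of this ColossalAI version exposes zero_grad()/step():
    #
    #   for _ in range(num_steps):
    #       engine.zero_grad()
    #       schedule.forward_backward_step(engine, data_iter)
    #       engine.step()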

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
def test_pipeline_schedule():
    world_size = 4
    run_func = partial(run_schedule, world_size=world_size)
    # mp.spawn supplies the process rank as the first positional argument
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_pipeline_schedule()