# ColossalAI/tests/test_pipeline/test_cuda_rpc_performance.py
# (web-view header: 91 lines, 2.8 KiB, Python)

import os
from typing import Callable, List, Optional, Type, Union
import time
import pytest
import torch
import torch.nn as nn
from titans.dataloader.cifar10 import build_cifar
from torchvision.models import resnet50
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
from tqdm import tqdm
from rpc_test_utils import rpc_run, parse_args
import colossalai
import colossalai.nn as col_nn
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from colossalai.context import ParallelMode
from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel
from colossalai.pipeline.rpc import OneFOneBPipelineEngine, ChimeraPipelineEngine
from colossalai.pipeline.pipeline_process_group import ppg
def flatten(x):
    """Flatten every dimension of ``x`` from dim 1 onward into one dimension.

    This callable is placed in the pipeline execution sequence between
    ``'avgpool'`` and ``'fc'``.

    Args:
        x: A tensor with at least two dimensions; dim 0 is left untouched.

    Returns:
        The result of ``torch.flatten`` with ``start_dim=1``.
    """
    return torch.flatten(x, start_dim=1)
def partition(pp_rank: int, chunk: int, stage_num: int):
    """Build the ResNet-50 partition for one pipeline rank.

    A ``resnet50`` is constructed inside a ``PipelinableContext``, the context
    is switched to the ``"customized"`` policy, and an explicit layer order is
    handed to ``to_layer_list`` before asking the context for a partition.

    Args:
        pp_rank: Pipeline rank whose partition is requested.
        chunk: Chunk count forwarded to ``PipelinableContext.partition``.
        stage_num: Number of pipeline stages forwarded to
            ``PipelinableContext.partition``.

    Returns:
        Whatever ``pipelinable.partition(chunk, stage_num, pp_rank)`` returns
        (presumably the sub-model for this rank; confirm against the
        ``PipelinableContext`` API).
    """
    pipelinable = PipelinableContext()

    # build model partitions
    with pipelinable:
        # input : [B, 3, 32, 32]
        # The instance itself is discarded; it is only built under the context.
        _ = resnet50()

    pipelinable.policy = "customized"

    # Explicit layer order; ``flatten`` is inserted between 'avgpool' and 'fc'.
    exec_seq = [
        'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc'
    ]
    pipelinable.to_layer_list(exec_seq)

    # Return directly instead of binding a local that shadows this function's name.
    return pipelinable.partition(chunk, stage_num, pp_rank)
def run_master(args):
    """Time one pass of 1F1B pipeline training over the CIFAR training set.

    Builds the CIFAR dataloaders, creates a ``OneFOneBPipelineEngine`` with one
    stage per rank (``stage_num = args.world_size``), attaches an Adam
    optimizer (``lr=1e-3``), runs ``forward_backward`` on every training batch
    and prints the total and per-batch elapsed time.

    Args:
        args: Parsed arguments providing ``batch_size``, ``chunk``, ``device``,
            ``world_size`` and ``num_microbatches``.
    """
    batch_size = args.batch_size
    chunk = args.chunk
    device = args.device
    world_size = args.world_size
    stage_num = world_size
    num_microbatches = args.num_microbatches

    # build dataloader
    # Dataset root comes from the DATA environment variable, else './data'.
    root = os.environ.get('DATA', './data')
    # Only the training loader is used; the second loader is discarded.
    train_dataloader, _ = build_cifar(batch_size, root, padding=4, crop=32, resize=32)
    criterion = nn.CrossEntropyLoss()

    pp_engine = OneFOneBPipelineEngine(partition_fn=partition,
                                       stage_num=stage_num,
                                       num_microbatches=num_microbatches,
                                       device=device,
                                       chunk=chunk,
                                       criterion=criterion,
                                       checkpoint=False)

    pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)

    # perf_counter is monotonic, so the interval is immune to wall-clock jumps.
    start = time.perf_counter()

    for bx, by in tqdm(train_dataloader):
        pp_engine.forward_backward(bx, labels=by, forward_only=False)

    cost_time = time.perf_counter() - start

    print("total cost time :", cost_time)
    print("cost time per batch:", cost_time / len(train_dataloader))
@pytest.mark.skip(reason="Test for performance, no need for CI")
def main():
    """Parse arguments, pin the pipeline layout and launch the benchmark.

    ``world_size`` is forced to 2 and ``chunk`` to 1 before ``rpc_run`` is
    called with ``run_master`` as the master routine.
    """
    args = parse_args()
    # this is due to limitation of partition function
    args.world_size = 2
    args.chunk = 1
    rpc_run(args, run_master)
# Run the benchmark only when this file is executed as a script.
if __name__ == '__main__':
    main()