ColossalAI/benchmark/cifar/train.py

120 lines
4.7 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import colossalai
import torch
import torchvision
from colossalai.builder import *
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from model_zoo.vit import vit_lite_depth7_patch4_32
from torchvision import transforms
# Root directory of the CIFAR-10 dataset, read from the DATA environment variable
# when this module is imported.
try:
    DATASET_PATH = os.environ['DATA']
except KeyError as err:
    # Re-raise with an actionable message instead of a bare `KeyError: 'DATA'`.
    raise KeyError("environment variable DATA must be set to the CIFAR-10 dataset directory") from err
def build_cifar(batch_size, num_workers=4, pin_memory=True, root=None):
    """Build the CIFAR-10 train and test dataloaders.

    Args:
        batch_size: Batch size used by both dataloaders. ``train_cifar`` passes
            the configured ``BATCH_SIZE`` divided by the data-parallel size.
        num_workers: Loader worker processes per dataloader (default 4, as before).
        pin_memory: Whether both dataloaders pin host memory (default True, as before).
        root: Dataset directory; ``None`` means ``DATASET_PATH``.

    Returns:
        ``(train_dataloader, test_dataloader)``. The train split uses random crop +
        AutoAugment and is built with ``shuffle=True``; the test split only resizes
        and normalizes and is built without ``shuffle``.
    """
    if root is None:
        root = DATASET_PATH

    # Channel mean/std shared by both splits (previously duplicated).
    # NOTE(review): this std triple is the commonly copied CIFAR-10 constant, which is
    # reported to differ from the per-pixel dataset std (~0.247, 0.243, 0.261); kept
    # unchanged so benchmark numbers remain comparable.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        normalize,
    ])

    # NOTE(review): train_cifar calls this on every process with download=True; concurrent
    # downloads into a shared root may race -- consider pre-downloading the dataset.
    train_dataset = torchvision.datasets.CIFAR10(root=root,
                                                 train=True,
                                                 download=True,
                                                 transform=transform_train)
    # The test split is not downloaded explicitly; it relies on the archive fetched for
    # the train split above (torchvision's CIFAR-10 archive contains both splits).
    test_dataset = torchvision.datasets.CIFAR10(root=root, train=False, transform=transform_test)

    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=batch_size,
                                      num_workers=num_workers,
                                      pin_memory=pin_memory)
    test_dataloader = get_dataloader(dataset=test_dataset,
                                     batch_size=batch_size,
                                     num_workers=num_workers,
                                     pin_memory=pin_memory)
    return train_dataloader, test_dataloader
def _setup_file_logging(logger):
    """On global rank 0, also write logs to ``gpc.config.LOG_PATH`` if the config defines it."""
    if not hasattr(gpc.config, 'LOG_PATH'):
        return
    if gpc.get_global_rank() != 0:
        return
    log_path = gpc.config.LOG_PATH
    # makedirs(exist_ok=True) also creates missing parent directories, which the
    # previous exists()/mkdir() pair could not.
    os.makedirs(log_path, exist_ok=True)
    logger.log_to_file(log_path)


def train_cifar():
    """Train the model from ``vit_lite_depth7_patch4_32`` on CIFAR-10 with ColossalAI.

    Launched from torchrun. The config file (``args.config``) must define
    BATCH_SIZE, LEARNING_RATE, WEIGHT_DECAY, NUM_EPOCHS and WARMUP_EPOCHS;
    LOG_PATH is optional.

    Raises:
        ValueError: if BATCH_SIZE is smaller than the data-parallel size, which would
            otherwise yield a per-rank batch size of 0.
    """
    args = colossalai.get_default_parser().parse_args()
    # Launch from torchrun. For a manual launch, use instead:
    # colossalai.launch(config=args.config, rank=args.rank, world_size=args.world_size,
    #                   local_rank=args.local_rank, host=args.host, port=args.port)
    colossalai.launch_from_torch(config=args.config)

    logger = get_dist_logger()
    _setup_file_logging(logger)

    model = vit_lite_depth7_patch4_32()

    # BATCH_SIZE in the config is split evenly across data-parallel ranks.
    per_rank_batch_size = gpc.config.BATCH_SIZE // gpc.data_parallel_size
    if per_rank_batch_size < 1:
        raise ValueError(f"BATCH_SIZE ({gpc.config.BATCH_SIZE}) must be at least the "
                         f"data-parallel size ({gpc.data_parallel_size})")
    train_dataloader, test_dataloader = build_cifar(per_rank_batch_size)

    criterion = CrossEntropyLoss(label_smoothing=0.1)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=gpc.config.LEARNING_RATE,
                                  weight_decay=gpc.config.WEIGHT_DECAY)

    # LRSchedulerHook below uses by_epoch=False, so warmup and total length are
    # expressed in iterations rather than epochs.
    steps_per_epoch = len(train_dataloader)
    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                           total_steps=gpc.config.NUM_EPOCHS * steps_per_epoch,
                                           warmup_steps=gpc.config.WARMUP_EPOCHS * steps_per_epoch)

    engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        lr_scheduler=lr_scheduler,
    )
    logger.info("Engine is built", ranks=[0])

    timer = MultiTimer()
    trainer = Trainer(engine=engine, logger=logger, timer=timer)
    logger.info("Trainer is built", ranks=[0])

    hook_list = [
        hooks.LogMetricByEpochHook(logger=logger),
        hooks.LogMetricByStepHook(),
        # hooks.LogTimingByEpochHook(timer=timer, logger=logger),
        # hooks.LogMemoryByEpochHook(logger=logger),
        hooks.AccuracyHook(accuracy_func=Accuracy()),
        hooks.LossHook(),
        hooks.ThroughputHook(),
        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
    ]

    logger.info("Train start", ranks=[0])
    trainer.fit(train_dataloader=train_dataloader,
                test_dataloader=test_dataloader,
                epochs=gpc.config.NUM_EPOCHS,
                hooks=hook_list,
                display_progress=True,
                test_interval=1)
if __name__ == '__main__':
    # Script entry point: run the full CIFAR-10 training job.
    train_cifar()