# ColossalAI/examples/vit_b16_imagenet_data_parallel/train.py
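"""Data-parallel training of ViT-B/16 (timm) on ImageNet with Colossal-AI.

A NVIDIA DALI pipeline feeds mixup-augmented, sharded batches to each
data-parallel rank; training runs through colossalai's Trainer with the
LAMB optimizer and linear-warmup LR scheduling.
"""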

import glob
import os

import torch
from timm.models import vit_base_patch16_224

import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import LinearWarmupLR
from colossalai.trainer import Trainer, hooks

from dataloader.imagenet_dali_dataloader import DaliDataloader
from mixup import MixupLoss, MixupAccuracy
from myhooks import TotalBatchsizeHook


def build_dali_train():
    """Build the sharded DALI training dataloader from gpc.config."""
    root = gpc.config.dali.root
    train_pat = os.path.join(root, 'train/*')
    train_idx_pat = os.path.join(root, 'idx_files/train/*')
    return DaliDataloader(
        sorted(glob.glob(train_pat)),
        sorted(glob.glob(train_idx_pat)),
        batch_size=gpc.config.BATCH_SIZE,
        # each data-parallel rank reads its own disjoint shard of the data
        shard_id=gpc.get_local_rank(ParallelMode.DATA),
        num_shards=gpc.get_world_size(ParallelMode.DATA),
        training=True,
        gpu_aug=gpc.config.dali.gpu_aug,
        cuda=True,
        mixup_alpha=gpc.config.dali.mixup_alpha,
    )
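
# The train and validation loaders expect the dataset as record shards plus
# matching DALI index files (inferred from the 'idx_files/...' layout, not
# verified here); glob results are sorted so every rank enumerates the
# shards in the same order.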


def build_dali_test():
    """Build the DALI validation dataloader."""
    root = gpc.config.dali.root
    val_pat = os.path.join(root, 'validation/*')
    val_idx_pat = os.path.join(root, 'idx_files/validation/*')
    return DaliDataloader(
        sorted(glob.glob(val_pat)),
        sorted(glob.glob(val_idx_pat)),
        batch_size=gpc.config.BATCH_SIZE,
        shard_id=gpc.get_local_rank(ParallelMode.DATA),
        num_shards=gpc.get_world_size(ParallelMode.DATA),
        training=False,
        gpu_aug=False,  # keep validation images unaugmented
        cuda=True,
        mixup_alpha=gpc.config.dali.mixup_alpha,
    )
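
# A minimal sketch of the config this script reads via gpc.config -- the
# field names are taken from the accesses above, the values are hypothetical;
# the config file shipped with the example may differ:
#
#   BATCH_SIZE = 256
#   NUM_EPOCHS = 300
#   dali = dict(
#       root='/path/to/imagenet',   # contains train/, validation/, idx_files/
#       gpu_aug=True,
#       mixup_alpha=0.2,
#   )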


def main():
    # initialize distributed setting
    parser = colossalai.get_default_parser()
    args = parser.parse_args()

    # launch from a slurm batch job
    colossalai.launch_from_slurm(config=args.config,
                                 host=args.host,
                                 port=args.port,
                                 backend=args.backend)
    # alternatively, launch with the torch launcher:
    # colossalai.launch_from_torch(config=args.config)
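    # Example invocations (a sketch; --config/--host/--port/--backend come
    # from colossalai.get_default_parser, concrete values depend on your
    # cluster):
    #   srun python train.py --config config.py --host <master-node> --port 29500
    #   torchrun --nproc_per_node=8 train.py --config config.py  # with launch_from_torch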

    # get logger
    logger = get_dist_logger()
    logger.info("initialized distributed environment", ranks=[0])

    # build model
    model = vit_base_patch16_224(drop_rate=0.1)

    # build dataloaders
    train_dataloader = build_dali_train()
    test_dataloader = build_dali_test()

    # build optimizer
    optimizer = colossalai.nn.Lamb(model.parameters(), lr=1.8e-2, weight_decay=0.1)

    # build loss: cross entropy wrapped for mixup-blended targets
    criterion = MixupLoss(loss_fn_cls=torch.nn.CrossEntropyLoss)
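    # MixupLoss is defined in mixup.py; presumably it computes the standard
    # mixup objective (a sketch, not the verified implementation):
    #   loss = lam * CE(pred, y_a) + (1 - lam) * CE(pred, y_b)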

    # build lr scheduler; LRSchedulerHook below steps it once per epoch,
    # so warmup_steps and total_steps are counted in epochs
    lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS)

    # wrap model, optimizer, criterion and dataloaders into a colossalai engine
    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
        model, optimizer, criterion, train_dataloader, test_dataloader
    )
    logger.info("initialized colossalai components", ranks=[0])

    # build trainer
    trainer = Trainer(engine=engine, logger=logger)

    # build hooks
    hook_list = [
        hooks.LossHook(),
        hooks.AccuracyHook(accuracy_func=MixupAccuracy()),
        hooks.LogMetricByEpochHook(logger),
        hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
        TotalBatchsizeHook(),
        # comment out the two hooks below if you do not need per-epoch
        # checkpoints or tensorboard logs
        hooks.SaveCheckpointHook(interval=1, checkpoint_dir='./ckpt'),
        hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
    ]

    # start training; evaluate on the validation set every epoch
    trainer.fit(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        epochs=gpc.config.NUM_EPOCHS,
        hooks=hook_list,
        display_progress=True,
        test_interval=1,
    )


if __name__ == '__main__':
    main()