ColossalAI/benchmark/imagenet100/train.py

208 lines
8.1 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import glob
import os
import colossalai
import nvidia.dali.fn as fn
import nvidia.dali.tfrecord as tfrec
import torch
from colossalai.builder import *
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer
from model_zoo.vit import vit_small_patch16_224
from nvidia.dali import types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
# Root directory of the TFRecord dataset, read from the DATA environment
# variable at import time (raises KeyError when DATA is not set).
DATASET_PATH = os.environ['DATA']

# Glob patterns for the record shards and their DALI index files.
TRAIN_RECS = f'{DATASET_PATH}/train/*'
VAL_RECS = f'{DATASET_PATH}/validation/*'
TRAIN_IDX = f'{DATASET_PATH}/idx_files/train/*'
VAL_IDX = f'{DATASET_PATH}/idx_files/validation/*'
class DaliDataloader(DALIClassificationIterator):
    """Classification iterator backed by a DALI TFRecord pipeline.

    The pipeline reads the ``image/encoded`` and ``image/class/label``
    features, decodes the images to RGB, applies either training
    augmentation (random resized crop + random horizontal flip) or
    evaluation preprocessing (fixed resize), then crops to
    ``crop x crop`` and normalises with mean 127.5 / std 127.5.

    Iterating yields ``((img,), (label,))`` tuples.

    Args:
        tfrec_filenames: TFRecord file paths, passed to the reader as ``path``.
        tfrec_idx_filenames: Index file paths, passed as ``index_path``.
        shard_id: Index of the shard this reader consumes.
        num_shards: Total number of shards the dataset is split into.
        batch_size: Number of samples per batch.
        num_threads: Number of threads given to the DALI pipeline.
        resize: Side length used by the evaluation resize.
        crop: Side length of the final (square) crop.
        prefetch: ``prefetch_queue_depth`` of the TFRecord reader.
        training: Enables shuffling and random augmentation; also selects
            the last-batch policy (DROP when training, PARTIAL otherwise).
        gpu_aug: Decode with the ``mixed`` backend and run resize/crop
            operators on ``gpu`` instead of ``cpu``.
        cuda: Bind the pipeline to the current CUDA device and return
            GPU tensors.

    NOTE(review): ``gpu_aug=True`` presumably requires ``cuda=True``, since
    the pipeline gets ``device_id=None`` otherwise -- confirm.
    """

    def __init__(self,
                 tfrec_filenames,
                 tfrec_idx_filenames,
                 shard_id=0,
                 num_shards=1,
                 batch_size=128,
                 num_threads=4,
                 resize=256,
                 crop=224,
                 prefetch=2,
                 training=True,
                 gpu_aug=False,
                 cuda=True):
        # Local import: the iterator compares the policy against this enum,
        # so plain strings such as 'DROP' would never match.
        from nvidia.dali.plugin.base_iterator import LastBatchPolicy

        decode_device = 'mixed' if gpu_aug else 'cpu'
        aug_device = 'gpu' if gpu_aug else 'cpu'

        pipe = Pipeline(batch_size=batch_size,
                        num_threads=num_threads,
                        device_id=torch.cuda.current_device() if cuda else None,
                        seed=1024)
        with pipe:
            inputs = fn.readers.tfrecord(path=tfrec_filenames,
                                         index_path=tfrec_idx_filenames,
                                         random_shuffle=training,
                                         shard_id=shard_id,
                                         num_shards=num_shards,
                                         initial_fill=10000,
                                         read_ahead=True,
                                         prefetch_queue_depth=prefetch,
                                         name='Reader',
                                         features={
                                             'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""),
                                             'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1),
                                         })

            # Decode JPEG bytes to RGB; identical for training and evaluation.
            images = fn.decoders.image(inputs["image/encoded"], device=decode_device, output_type=types.RGB)
            if training:
                images = fn.random_resized_crop(images, size=crop, device=aug_device)
                flip_lr = fn.random.coin_flip(probability=0.5)
            else:
                images = fn.resize(images,
                                   device=aug_device,
                                   resize_x=resize,
                                   resize_y=resize,
                                   dtype=types.FLOAT,
                                   interp_type=types.INTERP_TRIANGULAR)
                flip_lr = False

            # Center crop, optional horizontal mirror, and normalisation.
            images = fn.crop_mirror_normalize(images,
                                              dtype=types.FLOAT,
                                              crop=(crop, crop),
                                              mean=[127.5],
                                              std=[127.5],
                                              mirror=flip_lr)

            # Shift the stored class ids down by one (presumably the records
            # are 1-based, giving 0-based targets for the loss).
            label = inputs["image/class/label"] - 1
            # LSG: element_extract will raise exception, let's flatten outside
            # label = fn.element_extract(label, element_map=0) # Flatten

            if cuda:    # transfer data to gpu
                pipe.set_outputs(images.gpu(), label.gpu())
            else:
                pipe.set_outputs(images, label)

        pipe.build()
        last_batch_policy = LastBatchPolicy.DROP if training else LastBatchPolicy.PARTIAL
        super().__init__(pipe, reader_name="Reader", auto_reset=True, last_batch_policy=last_batch_policy)

    def __iter__(self):
        """Return ``self``, resetting the underlying iterator first if needed."""
        # if not reset (after an epoch), reset; if just initialize, ignore
        if self._counter >= self._size or self._size < 0:
            self.reset()
        return self

    def __next__(self):
        """Return the next batch as ``((img,), (label,))``."""
        data = super().__next__()
        img, label = data[0]['data'], data[0]['label']
        # Labels are read as a length-1 feature, i.e. (batch, 1). Drop only
        # the trailing axis so a single-sample batch stays 1-D instead of
        # collapsing to a 0-dim tensor.
        label = label.squeeze(-1)
        return (img, ), (label, )
def build_dali_train(batch_size):
    """Build the training ``DaliDataloader`` for the current data-parallel rank.

    Record and index files are collected from the ``TRAIN_RECS`` and
    ``TRAIN_IDX`` glob patterns and sorted independently.

    Args:
        batch_size: Per-loader batch size.

    Returns:
        A ``DaliDataloader`` in training mode with GPU augmentation.
    """
    record_files = sorted(glob.glob(TRAIN_RECS))
    index_files = sorted(glob.glob(TRAIN_IDX))
    # Shard the dataset across the data-parallel group.
    rank = gpc.get_local_rank(ParallelMode.DATA)
    world = gpc.get_world_size(ParallelMode.DATA)
    return DaliDataloader(record_files,
                          index_files,
                          batch_size=batch_size,
                          shard_id=rank,
                          num_shards=world,
                          training=True,
                          gpu_aug=True,
                          cuda=True)
def build_dali_test(batch_size):
    """Build the validation ``DaliDataloader`` for the current data-parallel rank.

    Record and index files are collected from the ``VAL_RECS`` and
    ``VAL_IDX`` glob patterns and sorted independently.

    Args:
        batch_size: Per-loader batch size.

    Returns:
        A ``DaliDataloader`` in evaluation mode (``training=False``) that
        still runs decoding and resizing on the GPU.
    """
    record_files = sorted(glob.glob(VAL_RECS))
    index_files = sorted(glob.glob(VAL_IDX))
    # Shard the dataset across the data-parallel group.
    rank = gpc.get_local_rank(ParallelMode.DATA)
    world = gpc.get_world_size(ParallelMode.DATA)
    return DaliDataloader(record_files,
                          index_files,
                          batch_size=batch_size,
                          shard_id=rank,
                          num_shards=world,
                          training=False,
                          gpu_aug=True,
                          cuda=True)
def train_imagenet():
    """Train ``vit_small_patch16_224`` (100 classes) with ColossalAI.

    Steps:
      1. Parse the default ColossalAI arguments and launch from torchrun
         using ``args.config``.
      2. On global rank 0, create ``LOG_PATH`` (if configured) and log to it.
      3. Build the model, DALI loaders, label-smoothed cross-entropy loss,
         AdamW optimizer and a cosine schedule with warmup stepped per epoch.
      4. Wrap everything with ``colossalai.initialize`` and run
         ``Trainer.fit`` for ``NUM_EPOCHS``, evaluating every epoch.

    Configuration values read from ``gpc.config``: ``BATCH_SIZE``,
    ``LEARNING_RATE``, ``WEIGHT_DECAY``, ``NUM_EPOCHS``, ``WARMUP_EPOCHS``
    and, optionally, ``LOG_PATH``.
    """
    args = colossalai.get_default_parser().parse_args()

    # standard launch
    # colossalai.launch(config=args.config,
    #                   rank=args.rank,
    #                   world_size=args.world_size,
    #                   local_rank=args.local_rank,
    #                   host=args.host,
    #                   port=args.port)

    # launch from torchrun
    colossalai.launch_from_torch(config=args.config)

    logger = get_dist_logger()
    if hasattr(gpc.config, 'LOG_PATH') and gpc.get_global_rank() == 0:
        log_path = gpc.config.LOG_PATH
        # makedirs(exist_ok=True) handles nested paths and avoids the
        # check-then-create race of os.path.exists + os.mkdir.
        os.makedirs(log_path, exist_ok=True)
        logger.log_to_file(log_path)

    model = vit_small_patch16_224(num_classes=100, init_method='jax')

    # BATCH_SIZE is the global batch; each data-parallel rank loads its share.
    per_rank_batch_size = gpc.config.BATCH_SIZE // gpc.data_parallel_size
    train_dataloader = build_dali_train(per_rank_batch_size)
    test_dataloader = build_dali_test(per_rank_batch_size)

    criterion = CrossEntropyLoss(label_smoothing=0.1)

    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=gpc.config.LEARNING_RATE,
                                  weight_decay=gpc.config.WEIGHT_DECAY)

    # The scheduler is stepped once per epoch (see LRSchedulerHook below),
    # so its step counts are expressed in epochs.
    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                           total_steps=gpc.config.NUM_EPOCHS,
                                           warmup_steps=gpc.config.WARMUP_EPOCHS)

    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
                                                                         optimizer=optimizer,
                                                                         criterion=criterion,
                                                                         train_dataloader=train_dataloader,
                                                                         test_dataloader=test_dataloader)
    logger.info("Engine is built", ranks=[0])

    timer = MultiTimer()

    trainer = Trainer(engine=engine, logger=logger, timer=timer)
    logger.info("Trainer is built", ranks=[0])

    hook_list = [
        hooks.LogMetricByEpochHook(logger=logger),
        hooks.LogMetricByStepHook(),
        # hooks.LogTimingByEpochHook(timer=timer, logger=logger),
        # hooks.LogMemoryByEpochHook(logger=logger),
        hooks.AccuracyHook(accuracy_func=Accuracy()),
        hooks.LossHook(),
        hooks.ThroughputHook(),
        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
    ]

    logger.info("Train start", ranks=[0])
    trainer.fit(train_dataloader=train_dataloader,
                test_dataloader=test_dataloader,
                epochs=gpc.config.NUM_EPOCHS,
                hooks=hook_list,
                display_progress=True,
                test_interval=1)
# Script entry point: run the training job only when executed directly.
if __name__ == "__main__":
    train_imagenet()