mirror of https://github.com/hpcaitech/ColossalAI
add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
parent 3defa32aee
commit dbe62c67b8
README.md
@ -0,0 +1,14 @@

# Overview

Here is an example of training ViT-B/16 on ImageNet-1K. We use 8x A100 GPUs in this example. For simplicity and speed, we do not apply `RandAug` and use only `Mixup`. With the `LAMB` optimizer, we can scale the batch size to 32K with only a small loss in accuracy.
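
With the configuration shipped here (`BATCH_SIZE = 256`, `gradient_accumulation = 16`) on 8 GPUs, the effective global batch size works out to 256 × 16 × 8 = 32768 ≈ 32K; this is the same product that `TotalBatchsizeHook` logs at the start of training.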

# How to run

Using slurm:

```shell
srun python train_dali.py --local_rank=$SLURM_PROCID --world_size=$SLURM_NPROCS --host=$HOST --port=29500 --config=vit-b16.py
```
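
Here `--local_rank` and `--world_size` come from the slurm environment (`SLURM_PROCID`, `SLURM_NPROCS`), while `$HOST` and `--port` give the address used to set up distributed communication. `$HOST` should resolve to the first node of the allocation; one common way to derive it (an assumption, not part of this example) is `scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1`.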

# Results

![Loss Curve](./loss.jpeg)
![Accuracy](./acc.jpeg)

dataloader/imagenet_dali_dataloader.py
@ -0,0 +1,112 @@

from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import nvidia.dali.tfrecord as tfrec
import torch
import numpy as np


class DaliDataloader(DALIClassificationIterator):
    def __init__(self,
                 tfrec_filenames,
                 tfrec_idx_filenames,
                 shard_id=0,
                 num_shards=1,
                 batch_size=128,
                 num_threads=4,
                 resize=256,
                 crop=224,
                 prefetch=2,
                 training=True,
                 gpu_aug=False,
                 cuda=True,
                 mixup_alpha=0.0):
        self.mixup_alpha = mixup_alpha
        self.training = training
        pipe = Pipeline(batch_size=batch_size,
                        num_threads=num_threads,
                        device_id=torch.cuda.current_device() if cuda else None,
                        seed=1024)
        with pipe:
            inputs = fn.readers.tfrecord(
                path=tfrec_filenames,
                index_path=tfrec_idx_filenames,
                random_shuffle=training,
                shard_id=shard_id,
                num_shards=num_shards,
                initial_fill=10000,
                read_ahead=True,
                prefetch_queue_depth=prefetch,
                name='Reader',
                features={
                    'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""),
                    'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1),
                })
            images = inputs["image/encoded"]

            if training:
                # decode jpeg, random resized crop, random horizontal flip
                images = fn.decoders.image(images,
                                           device='mixed' if gpu_aug else 'cpu',
                                           output_type=types.RGB)
                images = fn.random_resized_crop(images,
                                                size=crop,
                                                device='gpu' if gpu_aug else 'cpu')
                flip_lr = fn.random.coin_flip(probability=0.5)
            else:
                # decode jpeg and resize
                images = fn.decoders.image(images,
                                           device='mixed' if gpu_aug else 'cpu',
                                           output_type=types.RGB)
                images = fn.resize(images,
                                   device='gpu' if gpu_aug else 'cpu',
                                   resize_x=resize,
                                   resize_y=resize,
                                   dtype=types.FLOAT,
                                   interp_type=types.INTERP_TRIANGULAR)
                flip_lr = False

            # center crop and normalise to roughly [-1, 1]
            images = fn.crop_mirror_normalize(images,
                                              dtype=types.FLOAT,
                                              crop=(crop, crop),
                                              mean=[127.5],
                                              std=[127.5],
                                              mirror=flip_lr)
            label = inputs["image/class/label"] - 1  # shift labels to 0-999
            # LSG: element_extract will raise exception, let's flatten outside
            # label = fn.element_extract(label, element_map=0)  # Flatten
            if cuda:  # transfer data to gpu
                pipe.set_outputs(images.gpu(), label.gpu())
            else:
                pipe.set_outputs(images, label)

        pipe.build()
        # drop the last incomplete batch while training, keep it (partially
        # filled) at evaluation time
        last_batch_policy = LastBatchPolicy.DROP if training else LastBatchPolicy.PARTIAL
        super().__init__(pipe, reader_name="Reader",
                         auto_reset=True,
                         last_batch_policy=last_batch_policy)

    def __iter__(self):
        # if not reset (after an epoch), reset; if just initialized, ignore
        if self._counter >= self._size or self._size < 0:
            self.reset()
        return self

    def __next__(self):
        data = super().__next__()
        img, label = data[0]['data'], data[0]['label']
        label = label.squeeze()
        if self.mixup_alpha > 0.0:
            if self.training:
                # mix each image with a randomly chosen partner from the batch
                lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
                idx = torch.randperm(img.size(0)).to(img.device)
                img = lam * img + (1 - lam) * img[idx, :]
                label_a, label_b = label, label[idx]
                lam = torch.tensor([lam], device=img.device, dtype=img.dtype)
                label = (label_a, label_b, lam)
            else:
                # evaluation: degenerate mixup with lam = 1
                label = (label, label, torch.ones(
                    1, device=img.device, dtype=img.dtype))
            return (img,), label
        return (img,), (label,)
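
To make the iterator's contract concrete, here is a minimal standalone sketch; the shard and index paths are hypothetical placeholders. With `mixup_alpha > 0` and `training=True`, the label is the `(label_a, label_b, lam)` triple that `MixupLoss` below consumes:

```python
# Hypothetical shard/index paths; real names come from your TFRecord dump.
loader = DaliDataloader(
    ['./dataset/ILSVRC2012_1k/train/shard-00000'],
    ['./dataset/ILSVRC2012_1k/idx_files/train/shard-00000.idx'],
    batch_size=32,
    training=True,
    gpu_aug=False,
    cuda=True,
    mixup_alpha=0.2,
)
for (img,), (label_a, label_b, lam) in loader:
    # img is already a convex mix of two images; the loss must be
    # mixed with the same lam, which is exactly what MixupLoss does.
    print(img.shape, label_a.shape, float(lam))
    break
```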

hooks.py
@ -0,0 +1,15 @@

from colossalai.registry import HOOKS
from colossalai.trainer import BaseHook
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode


@HOOKS.register_module
class TotalBatchsizeHook(BaseHook):
    def __init__(self, trainer, priority: int = 2) -> None:
        super().__init__(trainer, priority)

    def before_train(self):
        # effective global batch = per-GPU batch * gradient accumulation * data-parallel world size
        total_batch_size = gpc.config.BATCH_SIZE * \
            gpc.config.engine.gradient_accumulation * gpc.get_world_size(ParallelMode.DATA)
        self.logger.info(f'Total batch size = {total_batch_size}', ranks=[0])

mixup.py
@ -0,0 +1,12 @@

import torch.nn as nn
from colossalai.registry import LOSSES


@LOSSES.register_module
class MixupLoss(nn.Module):
    def __init__(self, loss_fn_cls):
        super().__init__()
        self.loss_fn = loss_fn_cls()

    def forward(self, inputs, *args):
        # targets arrive from DaliDataloader as (targets_a, targets_b, lam)
        targets_a, targets_b, lam = args
        return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b)
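
As a quick sanity check of the reduction, a sketch with dummy tensors (all values below are illustrative, not part of the training pipeline):

```python
import torch
from torch.nn import CrossEntropyLoss

loss_fn = MixupLoss(CrossEntropyLoss)
logits = torch.randn(4, 1000)             # dummy ViT-B/16 outputs
targets_a = torch.randint(0, 1000, (4,))  # labels of the original images
targets_b = torch.randint(0, 1000, (4,))  # labels of the mixed-in partners
lam = torch.tensor(0.3)                   # mixing coefficient ~ Beta(alpha, alpha)
print(loss_fn(logits, targets_a, targets_b, lam))
```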

train_dali.py
@ -0,0 +1,70 @@

import glob
import os
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer
from colossalai.utils import set_global_multitimer_status
from dataloader.imagenet_dali_dataloader import DaliDataloader


def build_dali_train():
    root = gpc.config.dali.root
    train_pat = os.path.join(root, 'train/*')
    train_idx_pat = os.path.join(root, 'idx_files/train/*')
    return DaliDataloader(
        sorted(glob.glob(train_pat)),
        sorted(glob.glob(train_idx_pat)),
        batch_size=gpc.config.BATCH_SIZE,
        shard_id=gpc.get_local_rank(ParallelMode.DATA),
        num_shards=gpc.get_world_size(ParallelMode.DATA),
        training=True,
        gpu_aug=gpc.config.dali.gpu_aug,
        cuda=True,
        mixup_alpha=gpc.config.dali.mixup_alpha
    )


def build_dali_test():
    root = gpc.config.dali.root
    val_pat = os.path.join(root, 'validation/*')
    val_idx_pat = os.path.join(root, 'idx_files/validation/*')
    return DaliDataloader(
        sorted(glob.glob(val_pat)),
        sorted(glob.glob(val_idx_pat)),
        batch_size=gpc.config.BATCH_SIZE,
        shard_id=gpc.get_local_rank(ParallelMode.DATA),
        num_shards=gpc.get_world_size(ParallelMode.DATA),
        training=False,
        # gpu_aug=gpc.config.dali.gpu_aug,
        gpu_aug=False,
        cuda=True,
        mixup_alpha=gpc.config.dali.mixup_alpha
    )


def main():
    engine, train_dataloader, test_dataloader = colossalai.initialize(
        train_dataloader=build_dali_train,
        test_dataloader=build_dali_test
    )
    logger = get_global_dist_logger()
    set_global_multitimer_status(True)
    timer = colossalai.utils.get_global_multitimer()
    trainer = Trainer(engine=engine,
                      verbose=True,
                      timer=timer)

    trainer.fit(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        epochs=gpc.config.NUM_EPOCHS,
        hooks_cfg=gpc.config.hooks,
        display_progress=True,
        test_interval=1
    )


if __name__ == '__main__':
    main()
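
The two builder functions above assume a fixed dataset layout under `dali.root` (`./dataset/ILSVRC2012_1k` in the config below); the `.idx` files are DALI index files, which can be generated from the TFRecord shards with DALI's `tfrecord2idx` script. Roughly:

```
./dataset/ILSVRC2012_1k/
├── train/              # TFRecord shards
├── validation/         # TFRecord shards
└── idx_files/
    ├── train/          # one index file per training shard
    └── validation/     # one index file per validation shard
```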

vit-b16.py
@ -0,0 +1,78 @@

from colossalai.engine import AMP_TYPE
from torch.nn import CrossEntropyLoss
from mixup import MixupLoss
from hooks import TotalBatchsizeHook
from colossalai.registry import MODELS
from timm.models import vit_base_patch16_224

MODELS.register_module(vit_base_patch16_224)

LOG_NAME = 'vit-b16-1k-32k-mixup-light2'
# ViT Base
BATCH_SIZE = 256
DROP_RATE = 0.1
NUM_EPOCHS = 300

parallel = dict(
    pipeline=dict(size=1),
    tensor=dict(size=1, mode=None),
)

optimizer = dict(
    type='Lamb',
    lr=1.8e-2,
    weight_decay=0.1,
)


loss = dict(
    type='MixupLoss',
    loss_fn_cls=CrossEntropyLoss
)

model = dict(
    type='vit_base_patch16_224',
    drop_rate=DROP_RATE,
)

hooks = [
    dict(type='LogMetricByEpochHook'),
    dict(type='AccuracyHook'),
    dict(type='LossHook'),
    dict(type='TotalBatchsizeHook'),
    dict(type='TensorboardHook', log_dir=f'./tb_logs/{LOG_NAME}'),
    dict(type='SaveCheckpointHook', interval=1,
         checkpoint_dir=f'./ckpt/{LOG_NAME}'),
    # dict(type='LoadCheckpointHook', epoch=10,
    #      checkpoint_dir=f'./ckpt/{LOG_NAME}'),
    dict(
        type='LRSchedulerHook',
        by_epoch=True,
        lr_scheduler_cfg=dict(
            type='LinearWarmupLR',
            warmup_steps=150
        )
    ),
]

fp16 = dict(
    mode=AMP_TYPE.TORCH,
)


logging = dict(
    root_path=f"./logs/{LOG_NAME}"
)

dali = dict(
    root='./dataset/ILSVRC2012_1k',
    gpu_aug=True,
    mixup_alpha=0.2
)

engine = dict(
    schedule=None,
    gradient_handlers=None,
    gradient_accumulation=16,
    gradient_clipping=1.0,
)
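
A note on scale: with the effective global batch of 256 × 16 × 8 = 32768 and ImageNet-1K's roughly 1.28M training images, one epoch is only ⌊1281167 / 32768⌋ = 39 optimizer steps, since the dataloader drops the last partial batch during training.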