mirror of https://github.com/hpcaitech/ColossalAI
add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
parent 3defa32aee · commit dbe62c67b8

# Overview

Here is an example of training ViT-B/16 on ImageNet-1K with 8x A100 GPUs. For simplicity and speed, we apply only `Mixup` and skip `RandAug`. With the `LAMB` optimizer, the batch size can be scaled to 32K with only a small loss in accuracy.
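
With the settings in `vit-b16.py` below (`BATCH_SIZE = 256` per GPU, `gradient_accumulation = 16`, 8 data-parallel GPUs), the effective global batch size works out to 256 × 16 × 8 = 32,768, i.e. the 32K mentioned above.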

# How to run

Using Slurm:

```shell
srun python train_dali.py --local_rank=$SLURM_PROCID --world_size=$SLURM_NPROCS --host=$HOST --port=29500 --config=vit-b16.py
```
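
Here `$HOST` is assumed to be set by you to the address of the rank-0 node; `$SLURM_PROCID` and `$SLURM_NPROCS` are provided by Slurm and give each process its rank and the total number of launched processes.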

# Results

![Loss Curve](./loss.jpeg)

![Accuracy](./acc.jpeg)

`dataloader/imagenet_dali_dataloader.py`:

```python
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import nvidia.dali.tfrecord as tfrec
import torch
import numpy as np


class DaliDataloader(DALIClassificationIterator):
    """DALI pipeline over ImageNet TFRecords with optional GPU-side augmentation and Mixup."""

    def __init__(self,
                 tfrec_filenames,
                 tfrec_idx_filenames,
                 shard_id=0,
                 num_shards=1,
                 batch_size=128,
                 num_threads=4,
                 resize=256,
                 crop=224,
                 prefetch=2,
                 training=True,
                 gpu_aug=False,
                 cuda=True,
                 mixup_alpha=0.0):
        self.mixup_alpha = mixup_alpha
        self.training = training
        pipe = Pipeline(batch_size=batch_size,
                        num_threads=num_threads,
                        device_id=torch.cuda.current_device() if cuda else None,
                        seed=1024)
        with pipe:
            inputs = fn.readers.tfrecord(
                path=tfrec_filenames,
                index_path=tfrec_idx_filenames,
                random_shuffle=training,
                shard_id=shard_id,
                num_shards=num_shards,
                initial_fill=10000,
                read_ahead=True,
                prefetch_queue_depth=prefetch,
                name='Reader',
                features={
                    'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""),
                    'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1),
                })
            images = inputs["image/encoded"]

            if training:
                # decode jpeg, random resized crop and random horizontal flip
                images = fn.decoders.image(images,
                                           device='mixed' if gpu_aug else 'cpu',
                                           output_type=types.RGB)
                images = fn.random_resized_crop(images,
                                                size=crop,
                                                device='gpu' if gpu_aug else 'cpu')
                flip_lr = fn.random.coin_flip(probability=0.5)
            else:
                # decode jpeg and resize
                images = fn.decoders.image(images,
                                           device='mixed' if gpu_aug else 'cpu',
                                           output_type=types.RGB)
                images = fn.resize(images,
                                   device='gpu' if gpu_aug else 'cpu',
                                   resize_x=resize,
                                   resize_y=resize,
                                   dtype=types.FLOAT,
                                   interp_type=types.INTERP_TRIANGULAR)
                flip_lr = False

            # center crop and normalise
            images = fn.crop_mirror_normalize(images,
                                              dtype=types.FLOAT,
                                              crop=(crop, crop),
                                              mean=[127.5],
                                              std=[127.5],
                                              mirror=flip_lr)
            label = inputs["image/class/label"] - 1  # 0-999
            # LSG: element_extract will raise exception, let's flatten outside
            # label = fn.element_extract(label, element_map=0)  # Flatten
            if cuda:  # transfer data to gpu
                pipe.set_outputs(images.gpu(), label.gpu())
            else:
                pipe.set_outputs(images, label)

        pipe.build()
        last_batch_policy = LastBatchPolicy.DROP if training else LastBatchPolicy.PARTIAL
        super().__init__(pipe, reader_name="Reader",
                         auto_reset=True,
                         last_batch_policy=last_batch_policy)

    def __iter__(self):
        # if not reset (after an epoch), reset; if just initialized, ignore
        if self._counter >= self._size or self._size < 0:
            self.reset()
        return self

    def __next__(self):
        data = super().__next__()
        img, label = data[0]['data'], data[0]['label']
        label = label.squeeze()
        if self.mixup_alpha > 0.0:
            if self.training:
                # mix two samples: lam * x_i + (1 - lam) * x_j, and keep both labels
                lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
                idx = torch.randperm(img.size(0)).to(img.device)
                img = lam * img + (1 - lam) * img[idx, :]
                label_a, label_b = label, label[idx]
                lam = torch.tensor([lam], device=img.device, dtype=img.dtype)
                label = (label_a, label_b, lam)
            else:
                # at evaluation time, use the same label twice with lam = 1
                label = (label, label, torch.ones(
                    1, device=img.device, dtype=img.dtype))
            return (img,), label
        return (img,), (label,)
```
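
A minimal sketch (not part of the example) of how the loader's output is structured, with placeholder TFRecord shard and index paths: when `training=True` and `mixup_alpha > 0`, the label is the `(label_a, label_b, lam)` tuple consumed by `MixupLoss` below.

```python
from dataloader.imagenet_dali_dataloader import DaliDataloader

# placeholder paths, for illustration only
loader = DaliDataloader(["train-00000-of-01024"],       # TFRecord shards
                        ["train-00000-of-01024.idx"],   # matching DALI index files
                        batch_size=8, training=True, gpu_aug=False,
                        cuda=True, mixup_alpha=0.2)

(imgs,), (label_a, label_b, lam) = next(iter(loader))
print(imgs.shape)           # e.g. torch.Size([8, 3, 224, 224])
print(label_a.shape, lam)   # labels of the two mixed samples and the mixing weight
```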

`hooks.py`:

```python
from colossalai.registry import HOOKS
from colossalai.trainer import BaseHook
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode


@HOOKS.register_module
class TotalBatchsizeHook(BaseHook):
    def __init__(self, trainer, priority: int = 2) -> None:
        super().__init__(trainer, priority)

    def before_train(self):
        # effective global batch size = per-GPU batch * gradient accumulation * data-parallel world size
        total_batch_size = gpc.config.BATCH_SIZE * \
            gpc.config.engine.gradient_accumulation * gpc.get_world_size(ParallelMode.DATA)
        self.logger.info(f'Total batch size = {total_batch_size}', ranks=[0])
```

`mixup.py`:

```python
import torch.nn as nn
from colossalai.registry import LOSSES


@LOSSES.register_module
class MixupLoss(nn.Module):
    def __init__(self, loss_fn_cls):
        super().__init__()
        self.loss_fn = loss_fn_cls()

    def forward(self, inputs, *args):
        # the dataloader provides (targets_a, targets_b, lam) when Mixup is enabled
        targets_a, targets_b, lam = args
        return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b)
```
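
A small sketch of `MixupLoss` applied to such a tuple, using made-up logits and targets purely for illustration:

```python
import torch
from torch.nn import CrossEntropyLoss
from mixup import MixupLoss

criterion = MixupLoss(CrossEntropyLoss)    # wraps the base loss class, not an instance
logits = torch.randn(8, 1000)              # fake model output
targets_a = torch.randint(0, 1000, (8,))
targets_b = torch.randint(0, 1000, (8,))
lam = torch.tensor([0.7])
loss = criterion(logits, targets_a, targets_b, lam)
```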

`train_dali.py`:

```python
import glob
import os

import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer
from colossalai.utils import set_global_multitimer_status
from dataloader.imagenet_dali_dataloader import DaliDataloader


def build_dali_train():
    root = gpc.config.dali.root
    train_pat = os.path.join(root, 'train/*')
    train_idx_pat = os.path.join(root, 'idx_files/train/*')
    return DaliDataloader(
        sorted(glob.glob(train_pat)),
        sorted(glob.glob(train_idx_pat)),
        batch_size=gpc.config.BATCH_SIZE,
        shard_id=gpc.get_local_rank(ParallelMode.DATA),
        num_shards=gpc.get_world_size(ParallelMode.DATA),
        training=True,
        gpu_aug=gpc.config.dali.gpu_aug,
        cuda=True,
        mixup_alpha=gpc.config.dali.mixup_alpha
    )


def build_dali_test():
    root = gpc.config.dali.root
    val_pat = os.path.join(root, 'validation/*')
    val_idx_pat = os.path.join(root, 'idx_files/validation/*')
    return DaliDataloader(
        sorted(glob.glob(val_pat)),
        sorted(glob.glob(val_idx_pat)),
        batch_size=gpc.config.BATCH_SIZE,
        shard_id=gpc.get_local_rank(ParallelMode.DATA),
        num_shards=gpc.get_world_size(ParallelMode.DATA),
        training=False,
        # gpu_aug=gpc.config.dali.gpu_aug,
        gpu_aug=False,  # keep validation decoding/resizing on the CPU
        cuda=True,
        mixup_alpha=gpc.config.dali.mixup_alpha
    )


def main():
    engine, train_dataloader, test_dataloader = colossalai.initialize(
        train_dataloader=build_dali_train,
        test_dataloader=build_dali_test
    )
    logger = get_global_dist_logger()
    set_global_multitimer_status(True)
    timer = colossalai.utils.get_global_multitimer()
    trainer = Trainer(engine=engine,
                      verbose=True,
                      timer=timer)

    trainer.fit(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        epochs=gpc.config.NUM_EPOCHS,
        hooks_cfg=gpc.config.hooks,
        display_progress=True,
        test_interval=1
    )


if __name__ == '__main__':
    main()
```
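
Note that `build_dali_train`/`build_dali_test` expect the ImageNet TFRecords under `dali.root` as `train/*` and `validation/*`, with the corresponding DALI index files under `idx_files/train/*` and `idx_files/validation/*`.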

`vit-b16.py`:

```python
from colossalai.engine import AMP_TYPE
from torch.nn import CrossEntropyLoss
# importing these registers the custom loss and hook with the colossalai registries
from mixup import MixupLoss
from hooks import TotalBatchsizeHook
from colossalai.registry import MODELS
from timm.models import vit_base_patch16_224

MODELS.register_module(vit_base_patch16_224)

LOG_NAME = 'vit-b16-1k-32k-mixup-light2'

# ViT Base
BATCH_SIZE = 256
DROP_RATE = 0.1
NUM_EPOCHS = 300

parallel = dict(
    pipeline=dict(size=1),
    tensor=dict(size=1, mode=None),
)

optimizer = dict(
    type='Lamb',
    lr=1.8e-2,
    weight_decay=0.1,
)

loss = dict(
    type='MixupLoss',
    loss_fn_cls=CrossEntropyLoss
)

model = dict(
    type='vit_base_patch16_224',
    drop_rate=DROP_RATE,
)

hooks = [
    dict(type='LogMetricByEpochHook'),
    dict(type='AccuracyHook'),
    dict(type='LossHook'),
    dict(type='TotalBatchsizeHook'),
    dict(type='TensorboardHook', log_dir=f'./tb_logs/{LOG_NAME}'),
    dict(type='SaveCheckpointHook', interval=1,
         checkpoint_dir=f'./ckpt/{LOG_NAME}'),
    # dict(type='LoadCheckpointHook', epoch=10,
    #      checkpoint_dir=f'./ckpt/{LOG_NAME}'),
    dict(
        type='LRSchedulerHook',
        by_epoch=True,
        lr_scheduler_cfg=dict(
            type='LinearWarmupLR',
            warmup_steps=150
        )
    ),
]

fp16 = dict(
    mode=AMP_TYPE.TORCH,
)

logging = dict(
    root_path=f"./logs/{LOG_NAME}"
)

dali = dict(
    root='./dataset/ILSVRC2012_1k',
    gpu_aug=True,
    mixup_alpha=0.2
)

engine = dict(
    schedule=None,
    gradient_handlers=None,
    gradient_accumulation=16,  # 256 per GPU x 16 accumulation steps x 8 GPUs = 32K effective batch size
    gradient_clipping=1.0,
)
```