add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

pull/35/head
ver217 3 years ago committed by GitHub
parent 3defa32aee
commit dbe62c67b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -94,7 +94,7 @@ class Lamb(Optimizer):
# * math.sqrt(bias_correction2) / bias_correction1
step_size = group['lr']
weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
weight_norm = p.data.pow(2).sum().sqrt()
adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
if group['weight_decay'] != 0:

@ -0,0 +1,14 @@
# Overview
Here is an example of training ViT-B/16 on Imagenet-1K. We use 8x A100 in this example. For simplicity and speed, we didn't apply `RandAug` and we just used `Mixup`. With `LAMB` optimizer, we can scale the batch size to 32K with a little accuracy loss.
# How to run
Using slurm:
```shell
srun python train_dali.py --local_rank=$SLURM_PROCID --world_size=$SLURM_NPROCS --host=$HOST --port=29500 --config=vit-b16.py
```
# Results
![Loss Curve](./loss.jpeg)
![Accuracy](./acc.jpeg)

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

@ -0,0 +1,112 @@
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import nvidia.dali.tfrecord as tfrec
import torch
import numpy as np
class DaliDataloader(DALIClassificationIterator):
def __init__(self,
tfrec_filenames,
tfrec_idx_filenames,
shard_id=0,
num_shards=1,
batch_size=128,
num_threads=4,
resize=256,
crop=224,
prefetch=2,
training=True,
gpu_aug=False,
cuda=True,
mixup_alpha=0.0):
self.mixup_alpha = mixup_alpha
self.training = training
pipe = Pipeline(batch_size=batch_size,
num_threads=num_threads,
device_id=torch.cuda.current_device() if cuda else None,
seed=1024)
with pipe:
inputs = fn.readers.tfrecord(
path=tfrec_filenames,
index_path=tfrec_idx_filenames,
random_shuffle=training,
shard_id=shard_id,
num_shards=num_shards,
initial_fill=10000,
read_ahead=True,
prefetch_queue_depth=prefetch,
name='Reader',
features={
'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""),
'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1),
})
images = inputs["image/encoded"]
if training:
images = fn.decoders.image(images,
device='mixed' if gpu_aug else 'cpu',
output_type=types.RGB)
images = fn.random_resized_crop(images,
size=crop,
device='gpu' if gpu_aug else 'cpu')
flip_lr = fn.random.coin_flip(probability=0.5)
else:
# decode jpeg and resize
images = fn.decoders.image(images,
device='mixed' if gpu_aug else 'cpu',
output_type=types.RGB)
images = fn.resize(images,
device='gpu' if gpu_aug else 'cpu',
resize_x=resize,
resize_y=resize,
dtype=types.FLOAT,
interp_type=types.INTERP_TRIANGULAR)
flip_lr = False
# center crop and normalise
images = fn.crop_mirror_normalize(images,
dtype=types.FLOAT,
crop=(crop, crop),
mean=[127.5],
std=[127.5],
mirror=flip_lr)
label = inputs["image/class/label"] - 1 # 0-999
# LSG: element_extract will raise exception, let's flatten outside
# label = fn.element_extract(label, element_map=0) # Flatten
if cuda: # transfer data to gpu
pipe.set_outputs(images.gpu(), label.gpu())
else:
pipe.set_outputs(images, label)
pipe.build()
last_batch_policy = 'DROP' if training else 'PARTIAL'
super().__init__(pipe, reader_name="Reader",
auto_reset=True,
last_batch_policy=last_batch_policy)
def __iter__(self):
# if not reset (after an epoch), reset; if just initialize, ignore
if self._counter >= self._size or self._size < 0:
self.reset()
return self
def __next__(self):
data = super().__next__()
img, label = data[0]['data'], data[0]['label']
label = label.squeeze()
if self.mixup_alpha > 0.0:
if self.training:
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
idx = torch.randperm(img.size(0)).to(img.device)
img = lam * img + (1 - lam) * img[idx, :]
label_a, label_b = label, label[idx]
lam = torch.tensor([lam], device=img.device, dtype=img.dtype)
label = (label_a, label_b, lam)
else:
label = (label, label, torch.ones(
1, device=img.device, dtype=img.dtype))
return (img,), label
return (img,), (label,)

@ -0,0 +1,15 @@
from colossalai.registry import HOOKS
from colossalai.trainer import BaseHook
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
@HOOKS.register_module
class TotalBatchsizeHook(BaseHook):
def __init__(self, trainer, priority: int = 2) -> None:
super().__init__(trainer, priority)
def before_train(self):
total_batch_size = gpc.config.BATCH_SIZE * \
gpc.config.engine.gradient_accumulation * gpc.get_world_size(ParallelMode.DATA)
self.logger.info(f'Total batch size = {total_batch_size}', ranks=[0])

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

@ -0,0 +1,12 @@
import torch.nn as nn
from colossalai.registry import LOSSES
@LOSSES.register_module
class MixupLoss(nn.Module):
def __init__(self, loss_fn_cls):
super().__init__()
self.loss_fn = loss_fn_cls()
def forward(self, inputs, *args):
targets_a, targets_b, lam = args
return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b)

@ -0,0 +1,70 @@
import glob
import os
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer
from colossalai.utils import set_global_multitimer_status
from dataloader.imagenet_dali_dataloader import DaliDataloader
def build_dali_train():
root = gpc.config.dali.root
train_pat = os.path.join(root, 'train/*')
train_idx_pat = os.path.join(root, 'idx_files/train/*')
return DaliDataloader(
sorted(glob.glob(train_pat)),
sorted(glob.glob(train_idx_pat)),
batch_size=gpc.config.BATCH_SIZE,
shard_id=gpc.get_local_rank(ParallelMode.DATA),
num_shards=gpc.get_world_size(ParallelMode.DATA),
training=True,
gpu_aug=gpc.config.dali.gpu_aug,
cuda=True,
mixup_alpha=gpc.config.dali.mixup_alpha
)
def build_dali_test():
root = gpc.config.dali.root
val_pat = os.path.join(root, 'validation/*')
val_idx_pat = os.path.join(root, 'idx_files/validation/*')
return DaliDataloader(
sorted(glob.glob(val_pat)),
sorted(glob.glob(val_idx_pat)),
batch_size=gpc.config.BATCH_SIZE,
shard_id=gpc.get_local_rank(ParallelMode.DATA),
num_shards=gpc.get_world_size(ParallelMode.DATA),
training=False,
# gpu_aug=gpc.config.dali.gpu_aug,
gpu_aug=False,
cuda=True,
mixup_alpha=gpc.config.dali.mixup_alpha
)
def main():
engine, train_dataloader, test_dataloader = colossalai.initialize(
train_dataloader=build_dali_train,
test_dataloader=build_dali_test
)
logger = get_global_dist_logger()
set_global_multitimer_status(True)
timer = colossalai.utils.get_global_multitimer()
trainer = Trainer(engine=engine,
verbose=True,
timer=timer)
trainer.fit(
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
epochs=gpc.config.NUM_EPOCHS,
hooks_cfg=gpc.config.hooks,
display_progress=True,
test_interval=1
)
if __name__ == '__main__':
main()

@ -0,0 +1,78 @@
from colossalai.engine import AMP_TYPE
from torch.nn import CrossEntropyLoss
from mixup import MixupLoss
from hooks import TotalBatchsizeHook
from colossalai.registry import MODELS
from timm.models import vit_base_patch16_224
MODELS.register_module(vit_base_patch16_224)
LOG_NAME = 'vit-b16-1k-32k-mixup-light2'
# ViT Base
BATCH_SIZE = 256
DROP_RATE = 0.1
NUM_EPOCHS = 300
parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None),
)
optimizer = dict(
type='Lamb',
lr=1.8e-2,
weight_decay=0.1,
)
loss = dict(
type='MixupLoss',
loss_fn_cls=CrossEntropyLoss
)
model = dict(
type='vit_base_patch16_224',
drop_rate=DROP_RATE,
)
hooks = [
dict(type='LogMetricByEpochHook'),
dict(type='AccuracyHook'),
dict(type='LossHook'),
dict(type='TotalBatchsizeHook'),
dict(type='TensorboardHook', log_dir=f'./tb_logs/{LOG_NAME}'),
dict(type='SaveCheckpointHook', interval=1,
checkpoint_dir=f'./ckpt/{LOG_NAME}'),
# dict(type='LoadCheckpointHook', epoch=10,
# checkpoint_dir=f'./ckpt/{LOG_NAME}'),
dict(
type='LRSchedulerHook',
by_epoch=True,
lr_scheduler_cfg=dict(
type='LinearWarmupLR',
warmup_steps=150
)
),
]
fp16 = dict(
mode=AMP_TYPE.TORCH,
)
logging = dict(
root_path=f"./logs/{LOG_NAME}"
)
dali = dict(
root='./dataset/ILSVRC2012_1k',
gpu_aug=True,
mixup_alpha=0.2
)
engine = dict(
schedule=None,
gradient_handlers=None,
gradient_accumulation=16,
gradient_clipping=1.0,
)
Loading…
Cancel
Save