diff --git a/colossalai/nn/optimizer/lamb.py b/colossalai/nn/optimizer/lamb.py
index 68531e92a..f7248bd68 100644
--- a/colossalai/nn/optimizer/lamb.py
+++ b/colossalai/nn/optimizer/lamb.py
@@ -94,7 +94,7 @@ class Lamb(Optimizer):
                 # * math.sqrt(bias_correction2) / bias_correction1
                 step_size = group['lr']
 
-                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
+                weight_norm = p.data.pow(2).sum().sqrt()
 
                 adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
                 if group['weight_decay'] != 0:
diff --git a/examples/vit-b16/README.md b/examples/vit-b16/README.md
new file mode 100644
index 000000000..83b924c2e
--- /dev/null
+++ b/examples/vit-b16/README.md
@@ -0,0 +1,14 @@
+# Overview
+
+Here is an example of training ViT-B/16 on ImageNet-1K. We use 8x A100 GPUs in this example. For simplicity and speed, we did not apply `RandAug` and used only `Mixup`. With the `LAMB` optimizer, we can scale the batch size to 32K with only a small loss in accuracy.
+
+# How to run
+Using Slurm:
+```shell
+srun python train_dali.py --local_rank=$SLURM_PROCID --world_size=$SLURM_NPROCS --host=$HOST --port=29500 --config=vit-b16.py
+```
+
+# Results
+
+![Loss Curve](./loss.jpeg)
+![Accuracy](./acc.jpeg)
diff --git a/examples/vit-b16/acc.jpeg b/examples/vit-b16/acc.jpeg
new file mode 100755
index 000000000..43f67fd39
Binary files /dev/null and b/examples/vit-b16/acc.jpeg differ
diff --git a/examples/vit-b16/dataloader/__init__.py b/examples/vit-b16/dataloader/__init__.py
new file mode 100755
index 000000000..e69de29bb
diff --git a/examples/vit-b16/dataloader/imagenet_dali_dataloader.py b/examples/vit-b16/dataloader/imagenet_dali_dataloader.py
new file mode 100755
index 000000000..a39d73e26
--- /dev/null
+++ b/examples/vit-b16/dataloader/imagenet_dali_dataloader.py
@@ -0,0 +1,143 @@
+from nvidia.dali.pipeline import Pipeline
+from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
+import nvidia.dali.fn as fn
+import nvidia.dali.types as types
+import nvidia.dali.tfrecord as tfrec
+import torch
+import numpy as np
+
+
+class DaliDataloader(DALIClassificationIterator):
+    """DALI-backed ImageNet loader that reads sharded TFRecord files.
+
+    Each batch is returned as ``((img,), labels)``. With ``mixup_alpha > 0``
+    ``labels`` is ``(label_a, label_b, lam)``; otherwise it is ``(label,)``.
+
+    Args:
+        tfrec_filenames: TFRecord files, passed to the reader as ``path``.
+        tfrec_idx_filenames: Index files, passed to the reader as ``index_path``.
+        shard_id: Index of the shard this loader reads.
+        num_shards: Total number of shards.
+        batch_size: Batch size of the DALI pipeline.
+        num_threads: Number of threads used by the DALI pipeline.
+        resize: Side length images are resized to when not training.
+        crop: Side length of the crop produced for the model.
+        prefetch: Prefetch queue depth of the TFRecord reader.
+        training: Shuffle, apply random-resized-crop and random flip, and
+            drop the last incomplete batch.
+        gpu_aug: Decode and augment on the GPU instead of the CPU.
+        cuda: Build the pipeline on the current CUDA device and output GPU
+            tensors.
+        mixup_alpha: Beta-distribution parameter for Mixup; values <= 0
+            disable Mixup.
+    """
+
+    def __init__(self,
+                 tfrec_filenames,
+                 tfrec_idx_filenames,
+                 shard_id=0,
+                 num_shards=1,
+                 batch_size=128,
+                 num_threads=4,
+                 resize=256,
+                 crop=224,
+                 prefetch=2,
+                 training=True,
+                 gpu_aug=False,
+                 cuda=True,
+                 mixup_alpha=0.0):
+        self.mixup_alpha = mixup_alpha
+        self.training = training
+        pipe = Pipeline(batch_size=batch_size,
+                        num_threads=num_threads,
+                        device_id=torch.cuda.current_device() if cuda else None,
+                        seed=1024)
+        with pipe:
+            inputs = fn.readers.tfrecord(
+                path=tfrec_filenames,
+                index_path=tfrec_idx_filenames,
+                random_shuffle=training,
+                shard_id=shard_id,
+                num_shards=num_shards,
+                initial_fill=10000,
+                read_ahead=True,
+                prefetch_queue_depth=prefetch,
+                name='Reader',
+                features={
+                    'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""),
+                    'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1),
+                })
+            images = inputs["image/encoded"]
+
+            if training:
+                images = fn.decoders.image(images,
+                                           device='mixed' if gpu_aug else 'cpu',
+                                           output_type=types.RGB)
+                images = fn.random_resized_crop(images,
+                                                size=crop,
+                                                device='gpu' if gpu_aug else 'cpu')
+                flip_lr = fn.random.coin_flip(probability=0.5)
+            else:
+                # decode jpeg and resize
+                images = fn.decoders.image(images,
+                                           device='mixed' if gpu_aug else 'cpu',
+                                           output_type=types.RGB)
+                images = fn.resize(images,
+                                   device='gpu' if gpu_aug else 'cpu',
+                                   resize_x=resize,
+                                   resize_y=resize,
+                                   dtype=types.FLOAT,
+                                   interp_type=types.INTERP_TRIANGULAR)
+                flip_lr = False
+
+            # center crop and normalise
+            images = fn.crop_mirror_normalize(images,
+                                              dtype=types.FLOAT,
+                                              crop=(crop, crop),
+                                              mean=[127.5],
+                                              std=[127.5],
+                                              mirror=flip_lr)
+            label = inputs["image/class/label"] - 1  # 0-999
+            # LSG: element_extract will raise exception, let's flatten outside
+            # label = fn.element_extract(label, element_map=0)  # Flatten
+            if cuda:  # transfer data to gpu
+                pipe.set_outputs(images.gpu(), label.gpu())
+            else:
+                pipe.set_outputs(images, label)
+
+        pipe.build()
+        # DALI documents ``last_batch_policy`` as a LastBatchPolicy value, so
+        # pass the enum members rather than their names as plain strings.
+        last_batch_policy = LastBatchPolicy.DROP if training else LastBatchPolicy.PARTIAL
+        super().__init__(pipe, reader_name="Reader",
+                         auto_reset=True,
+                         last_batch_policy=last_batch_policy)
+
+    def __iter__(self):
+        """Return self, calling ``reset()`` first when the counter has reached the size."""
+        # if not reset (after an epoch), reset; if just initialize, ignore
+        if self._counter >= self._size or self._size < 0:
+            self.reset()
+        return self
+
+    def __next__(self):
+        """Return the next batch as ``((img,), labels)``; mixes samples when training with Mixup."""
+        data = super().__next__()
+        img, label = data[0]['data'], data[0]['label']
+        # Drop only the trailing singleton dimension so that a batch holding a
+        # single sample keeps its batch dimension.
+        label = label.squeeze(-1)
+        if self.mixup_alpha > 0.0:
+            if self.training:
+                lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+                idx = torch.randperm(img.size(0)).to(img.device)
+                img = lam * img + (1 - lam) * img[idx, :]
+                label_a, label_b = label, label[idx]
+                lam = torch.tensor([lam], device=img.device, dtype=img.dtype)
+                label = (label_a, label_b, lam)
+            else:
+                # Evaluation: same 3-tuple layout with lam = 1 (no mixing).
+                label = (label, label, torch.ones(
+                    1, device=img.device, dtype=img.dtype))
+            return (img,), label
+        return (img,), (label,)
diff --git a/examples/vit-b16/hooks.py b/examples/vit-b16/hooks.py
new file mode 100644
index 000000000..b6c306ed7
--- /dev/null
+++ b/examples/vit-b16/hooks.py
@@ -0,0 +1,18 @@
+from colossalai.registry import HOOKS
+from colossalai.trainer import BaseHook
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
+
+
+@HOOKS.register_module
+class TotalBatchsizeHook(BaseHook):
+    """Hook that logs the total batch size in ``before_train``."""
+
+    def __init__(self, trainer, priority: int = 2) -> None:
+        super().__init__(trainer, priority)
+
+    def before_train(self):
+        """Log BATCH_SIZE x gradient accumulation x data-parallel world size on rank 0."""
+        total_batch_size = gpc.config.BATCH_SIZE * \
+            gpc.config.engine.gradient_accumulation * gpc.get_world_size(ParallelMode.DATA)
+        self.logger.info(f'Total batch size = {total_batch_size}', ranks=[0])
diff --git a/examples/vit-b16/loss.jpeg b/examples/vit-b16/loss.jpeg
new file mode 100755
index 000000000..a16c333cc
Binary files /dev/null and b/examples/vit-b16/loss.jpeg differ
diff --git a/examples/vit-b16/mixup.py b/examples/vit-b16/mixup.py
new file mode 100644
index 000000000..822bc8659
--- /dev/null
+++ b/examples/vit-b16/mixup.py
@@ -0,0 +1,23 @@
+import torch.nn as nn
+from colossalai.registry import LOSSES
+
+@LOSSES.register_module
+class MixupLoss(nn.Module):
+    """Mixup loss: ``lam``-weighted sum of a base loss on two target sets.
+
+    Args:
+        loss_fn_cls: Class (or factory) called with no arguments to create
+            the base loss function.
+    """
+
+    def __init__(self, loss_fn_cls):
+        super().__init__()
+        self.loss_fn = loss_fn_cls()
+
+    def forward(self, inputs, *args):
+        """Return ``lam * loss(inputs, a) + (1 - lam) * loss(inputs, b)``.
+
+        ``args`` must be exactly ``(targets_a, targets_b, lam)``.
+        """
+        targets_a, targets_b, lam = args
+        return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b)
diff --git a/examples/vit-b16/train_dali.py b/examples/vit-b16/train_dali.py
new file mode 100644
index 000000000..fed39c3cc
--- /dev/null
+++ b/examples/vit-b16/train_dali.py
@@ -0,0 +1,84 @@
+import glob
+import os
+import colossalai
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_global_dist_logger
+from colossalai.trainer import Trainer
+from colossalai.utils import set_global_multitimer_status
+from dataloader.imagenet_dali_dataloader import DaliDataloader
+
+
+def _build_dali_loader(split, training, gpu_aug):
+    """Build a DaliDataloader for one dataset split under ``gpc.config.dali.root``.
+
+    Args:
+        split: Sub-directory name of the split ('train' or 'validation').
+        training: Forwarded to ``DaliDataloader(training=...)``.
+        gpu_aug: Forwarded to ``DaliDataloader(gpu_aug=...)``.
+
+    Raises:
+        FileNotFoundError: If no TFRecord file is found for the split.
+        ValueError: If the numbers of TFRecord and index files differ.
+    """
+    root = gpc.config.dali.root
+    record_pat = os.path.join(root, f'{split}/*')
+    index_pat = os.path.join(root, f'idx_files/{split}/*')
+    # Sort both lists so that record files and index files pair up by name.
+    records = sorted(glob.glob(record_pat))
+    indices = sorted(glob.glob(index_pat))
+    if not records:
+        raise FileNotFoundError(f'No TFRecord files match {record_pat!r}')
+    if len(records) != len(indices):
+        raise ValueError(
+            f'Found {len(records)} TFRecord files for {record_pat!r} '
+            f'but {len(indices)} index files for {index_pat!r}')
+    return DaliDataloader(
+        records,
+        indices,
+        batch_size=gpc.config.BATCH_SIZE,
+        shard_id=gpc.get_local_rank(ParallelMode.DATA),
+        num_shards=gpc.get_world_size(ParallelMode.DATA),
+        training=training,
+        gpu_aug=gpu_aug,
+        cuda=True,
+        mixup_alpha=gpc.config.dali.mixup_alpha
+    )
+
+
+def build_dali_train():
+    """Build the training DaliDataloader from the 'train' split."""
+    return _build_dali_loader('train', training=True,
+                              gpu_aug=gpc.config.dali.gpu_aug)
+
+
+def build_dali_test():
+    """Build the 'validation' DaliDataloader; always passes ``gpu_aug=False``."""
+    return _build_dali_loader('validation', training=False, gpu_aug=False)
+
+
+def main():
+    """Initialize ColossalAI with the DALI loader builders and run ``Trainer.fit``."""
+    engine, train_dataloader, test_dataloader = colossalai.initialize(
+        train_dataloader=build_dali_train,
+        test_dataloader=build_dali_test
+    )
+    logger = get_global_dist_logger()
+    set_global_multitimer_status(True)
+    timer = colossalai.utils.get_global_multitimer()
+    trainer = Trainer(engine=engine,
+                      verbose=True,
+                      timer=timer)
+
+    trainer.fit(
+        train_dataloader=train_dataloader,
+        test_dataloader=test_dataloader,
+        epochs=gpc.config.NUM_EPOCHS,
+        hooks_cfg=gpc.config.hooks,
+        display_progress=True,
+        test_interval=1
+    )
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/vit-b16/vit-b16.py b/examples/vit-b16/vit-b16.py
new file mode 100755
index 000000000..ac51e226e
--- /dev/null
+++ b/examples/vit-b16/vit-b16.py
@@ -0,0 +1,79 @@
+from colossalai.engine import AMP_TYPE
+from torch.nn import CrossEntropyLoss
+from mixup import MixupLoss
+from hooks import TotalBatchsizeHook
+from colossalai.registry import MODELS
+from timm.models import vit_base_patch16_224
+
+MODELS.register_module(vit_base_patch16_224)
+
+LOG_NAME = 'vit-b16-1k-32k-mixup-light2'
+# ViT Base
+# BATCH_SIZE is the batch size of each data loader, not the total batch size.
+BATCH_SIZE = 256
+DROP_RATE = 0.1
+NUM_EPOCHS = 300
+
+parallel = dict(
+    pipeline=dict(size=1),
+    tensor=dict(size=1, mode=None),
+)
+
+optimizer = dict(
+    type='Lamb',
+    lr=1.8e-2,
+    weight_decay=0.1,
+)
+
+
+loss = dict(
+    type='MixupLoss',
+    loss_fn_cls=CrossEntropyLoss
+)
+
+model = dict(
+    type='vit_base_patch16_224',
+    drop_rate=DROP_RATE,
+)
+
+hooks = [
+    dict(type='LogMetricByEpochHook'),
+    dict(type='AccuracyHook'),
+    dict(type='LossHook'),
+    dict(type='TotalBatchsizeHook'),
+    dict(type='TensorboardHook', log_dir=f'./tb_logs/{LOG_NAME}'),
+    dict(type='SaveCheckpointHook', interval=1,
+         checkpoint_dir=f'./ckpt/{LOG_NAME}'),
+    # dict(type='LoadCheckpointHook', epoch=10,
+    #      checkpoint_dir=f'./ckpt/{LOG_NAME}'),
+    dict(
+        type='LRSchedulerHook',
+        by_epoch=True,
+        lr_scheduler_cfg=dict(
+            type='LinearWarmupLR',
+            warmup_steps=150
+        )
+    ),
+]
+
+fp16 = dict(
+    mode=AMP_TYPE.TORCH,
+)
+
+
+logging = dict(
+    root_path=f"./logs/{LOG_NAME}"
+)
+
+dali = dict(
+    root='./dataset/ILSVRC2012_1k',
+    gpu_aug=True,
+    mixup_alpha=0.2
+)
+
+engine = dict(
+    schedule=None,
+    gradient_handlers=None,
+    gradient_accumulation=16,
+    gradient_clipping=1.0,
+)