ColossalAI/examples/vit-b16
Directory contents: dataloader/, README.md, acc.jpeg, hooks.py, loss.jpeg, mixup.py, train_dali.py, vit-b16.py

README.md

Overview

Here is an example of training ViT-B/16 on ImageNet-1K with a batch size of 32K. We use 8x NVIDIA A100 GPUs in this example.

How to run

Using Slurm:

srun python train_dali.py --local_rank=$SLURM_PROCID --world_size=$SLURM_NPROCS --host=$HOST --port=29500 --config=vit-b16.py

Results

Loss curve (loss.jpeg) and accuracy curve (acc.jpeg).

Details

vit-b16.py

This is the config file used by Colossal-AI to define all kinds of training arguments, such as the model, dataset, and training method (optimizer, lr_scheduler, number of epochs, etc.). Config contents can be accessed at runtime through gpc.config.
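For example, once Colossal-AI has been initialized, any value defined in the config can be read from anywhere in the training script. The field names NUM_EPOCHS and optimizer below are only illustrative assumptions; the actual keys are whatever vit-b16.py defines.

from colossalai.core import global_context as gpc

# Attributes defined in vit-b16.py become accessible on gpc.config after
# initialization; NUM_EPOCHS and optimizer are assumed names for illustration.
num_epochs = gpc.config.NUM_EPOCHS
optimizer_cfg = gpc.config.optimizer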

In this example, we train the ViT-Base patch 16 model for 300 epochs on ImageNet-1K. The batch size is set to 32K through data parallelism (4K on each GPU, obtained from 16x gradient accumulation with a per-step batch size of 256). Since this batch size is far larger than common usage and leads to convergence difficulties, we use the large-batch optimizer LAMB, which lets us scale the batch size to 32K with only a small accuracy loss. The learning rate and weight decay of the optimizer are set to 1.8e-2 and 0.1, respectively. We use a linear warmup learning rate scheduler with 150 warmup epochs. We use FP16 mixed precision to accelerate training and gradient clipping to help convergence. For simplicity and speed, we did not apply RandAug and only used Mixup for data augmentation.
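As a rough sketch, the hyperparameters described above might be expressed in a Colossal-AI-style config along the following lines; the exact keys, structure, and import paths in vit-b16.py may differ, so treat every field name here as an assumption rather than the file's real contents.

from colossalai.amp import AMP_TYPE   # import path assumed; adjust to your Colossal-AI version

NUM_EPOCHS = 300
BATCH_SIZE = 256                  # per-GPU, per-step batch size
gradient_accumulation = 16        # 256 x 16 x 8 GPUs = 32K effective batch size

optimizer = dict(
    type='Lamb',                  # large-batch optimizer
    lr=1.8e-2,
    weight_decay=0.1,
)

fp16 = dict(mode=AMP_TYPE.TORCH)  # FP16 mixed precision
clip_grad_norm = 1.0              # gradient clipping to help convergence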

If you have enough computing resources, you can conveniently scale this example up with data parallelism alone, without gradient accumulation, and finish the training process even within one hour.

imagenet_dali_dataloader.py

To accelerate the training process, we use DALI as the data loader. Note that it requires the dataset in TFRecord format; this avoids reading raw images, which would reduce the efficiency of the file system.
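Below is a minimal sketch of what a DALI pipeline reading ImageNet from TFRecord shards can look like; the paths, feature keys, and augmentations are assumptions for illustration and do not reproduce the example's actual dataloader.

from nvidia.dali import pipeline_def, fn, types
import nvidia.dali.tfrecord as tfrec

@pipeline_def(batch_size=256, num_threads=4, device_id=0)
def imagenet_tfrecord_pipeline(record_path, index_path):
    # Read encoded JPEGs and labels from TFRecord shards (paths are hypothetical).
    inputs = fn.readers.tfrecord(
        path=record_path,
        index_path=index_path,
        features={
            "image/encoded": tfrec.FixedLenFeature((), tfrec.string, ""),
            "image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1),
        },
        random_shuffle=True,
        name="Reader",
    )
    # Decode on the GPU, then apply a random resized crop and normalization.
    images = fn.decoders.image(inputs["image/encoded"], device="mixed")
    images = fn.random_resized_crop(images, size=(224, 224))
    images = fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
        mirror=fn.random.coin_flip(),
    )
    return images, inputs["image/class/label"]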

train_dali.py

Here we build the DALI data loader and the training process with Colossal-AI.

mixup.py

Since we use Mixup, we define the Mixup loss in this file.
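The idea is sketched below, assuming the inputs are blended somewhere in the data path and the criterion combines the losses against both label sets; the class and function names are illustrative, not the exact contents of mixup.py.

import torch
import torch.nn as nn

class MixupLoss(nn.Module):
    """Convex combination of cross-entropy losses against both label sets."""
    def __init__(self):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()

    def forward(self, logits, targets_a, targets_b, lam):
        return lam * self.ce(logits, targets_a) + (1 - lam) * self.ce(logits, targets_b)

def mixup_data(x, y, alpha=0.2):
    # Draw lambda from Beta(alpha, alpha) and blend each sample with a shuffled copy.
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam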

hooks.py

We also define useful hooks to log information and help with debugging.
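As a standalone illustration (the real hooks.py plugs into Colossal-AI's trainer hook interface, whose exact method signatures are not reproduced here), a throughput-logging hook could look roughly like this:

import time
import logging

logger = logging.getLogger(__name__)

class ThroughputLoggingHook:
    """Logs iteration time and throughput every `log_interval` training steps."""

    def __init__(self, batch_size, log_interval=50):
        self.batch_size = batch_size
        self.log_interval = log_interval
        self._start = None
        self._step = 0

    def before_train_iter(self):
        self._start = time.time()

    def after_train_iter(self, loss):
        elapsed = time.time() - self._start
        self._step += 1
        if self._step % self.log_interval == 0:
            logger.info(
                "step %d | loss %.4f | %.1f samples/s",
                self._step, loss, self.batch_size / elapsed,
            )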