ColossalAI/examples/images/vit
This example directory contains: README.md, args.py, data.py, requirements.txt, run_benchmark.sh, run_demo.sh, test_ci.sh, vit_benchmark.py, vit_train_demo.py

README.md

Overview

Vision Transformer (ViT) is a class of Transformer models tailored for computer vision tasks. It was first proposed in the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale and achieved SOTA results on various vision tasks at the time.

In our example, we use pretrained ViT weights loaded from HuggingFace. We adapt the ViT training code to ColossalAI by leveraging the Booster API with a chosen plugin, where each plugin corresponds to a specific training strategy. This example supports the TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin plugins.
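For context, below is a minimal sketch of the Booster workflow (the tiny Linear model, the HybridAdam optimizer, and the dummy batch are placeholders for illustration; the actual wiring for ViT lives in vit_train_demo.py and vit_benchmark.py):

import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin  # TorchDDPPlugin / LowLevelZeroPlugin are drop-in alternatives
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})              # one process per GPU, launched e.g. via torchrun

model = torch.nn.Linear(768, 3)                      # stand-in for the ViT model used by the real scripts
optimizer = HybridAdam(model.parameters(), lr=3e-4)  # Gemini works with HybridAdam; DDP accepts any optimizer
criterion = torch.nn.CrossEntropyLoss()

# The plugin encodes the training strategy; swapping plugins leaves the rest of the loop unchanged.
booster = Booster(plugin=GeminiPlugin())
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

features = torch.rand(8, 768).cuda()                 # dummy batch standing in for the real dataloader
labels = torch.randint(0, 3, (8,)).cuda()
loss = criterion(model(features), labels)
booster.backward(loss, optimizer)                    # used instead of loss.backward() so the plugin can manage gradients
optimizer.step()
optimizer.zero_grad()

Swapping GeminiPlugin() for TorchDDPPlugin() or LowLevelZeroPlugin() is all it takes to switch strategies.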

Run Demo

By running the following script:

bash run_demo.sh

You will finetune a ViT-base model on a dataset of more than 8,000 bean-leaf images. This is an image classification task with 3 labels: ['angular_leaf_spot', 'bean_rust', 'healthy'].
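If you want to inspect the data or the pretrained checkpoint outside the script, the HuggingFace libraries make this easy. The snippet below is purely illustrative: the beans dataset name and the google/vit-base-patch16-224 checkpoint are assumed here, so check run_demo.sh and args.py for what the demo actually uses:

from datasets import load_dataset
from transformers import ViTForImageClassification, ViTImageProcessor

ds = load_dataset("beans")                                 # bean-leaf images with the 3 labels listed above
print(ds["train"].features["labels"].names)

model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)  # resizing / normalization matching the checkpoint
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=3,                                          # re-head the classifier for the 3 leaf classes
    ignore_mismatched_sizes=True,                          # discard the original 1000-class ImageNet head
)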

The script can be modified if you want to try another set of hyperparameters or switch to a ViT model of a different size.

The demo code refers to this blog.

Run Benchmark

You can benchmark the ViT model by running the following script:

bash run_benchmark.sh

The script will measure performance (throughput and peak memory usage) for each combination of hyperparameters. You can also modify this script to test your own set of hyperparameters.
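If you want to reproduce these two metrics by hand, they are typically measured along the following lines (a rough sketch, not the exact logic of vit_benchmark.py; run_one_training_step is a hypothetical helper standing in for one forward + backward + optimizer step):

import time
import torch

num_steps, batch_size = 20, 32

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.time()

for _ in range(num_steps):
    run_one_training_step()          # hypothetical helper: one forward + backward + optimizer step

torch.cuda.synchronize()             # make sure all GPU work has finished before stopping the clock
elapsed = time.time() - start

print(f"throughput: {num_steps * batch_size / elapsed:.1f} samples/s")
print(f"peak memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.0f} MB")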