mirror of https://github.com/hpcaitech/ColossalAI
[example] update ViT example using booster api (#3940)
parent 1aadeedeea
commit b3ab7fbabf
@@ -1,61 +1,28 @@
-# Vision Transformer with ColoTensor
-
-# Overview
-
-In this example, we will run Vision Transformer with ColoTensor.
-
-We use model **ViTForImageClassification** from Hugging Face [Link](https://huggingface.co/docs/transformers/model_doc/vit) for unit test.
-You can change world size or decide whether use DDP in our code.
-
-We use model **vision_transformer** from timm [Link](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) for training example.
-
-(2022/6/28) The default configuration now supports 2DP+2TP with gradient accumulation and checkpoint support. Zero is not supported at present.
-
-# Requirement
-
-Install colossalai version >= 0.1.11
-
-## Unit test
-To run unit test, you should install pytest, transformers with:
-```shell
-pip install pytest transformers
-```
-
-## Training example
-To run training example with ViT-S, you should install **NVIDIA DALI** from [Link](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html) for dataloader support.
-You also need to install timm and titans for model/dataloader support with:
-```shell
-pip install timm titans
-```
-
-### Data preparation
-You can download the ImageNet dataset from the [ImageNet official website](https://www.image-net.org/download.php). You should get the raw images after downloading the dataset. As we use **NVIDIA DALI** to read data, we use the TFRecords dataset instead of raw Imagenet dataset. This offers better speedup to IO. If you don't have TFRecords dataset, follow [imagenet-tools](https://github.com/ver217/imagenet-tools) to build one.
-
-Before you start training, you need to set the environment variable `DATA` so that the script knows where to fetch the data for DALI dataloader.
-```shell
-export DATA=/path/to/ILSVRC2012
-```
-
-# How to run
-
-## Unit test
-In your terminal
-```shell
-pytest test_vit.py
-```
-
-This will evaluate models with different **world_size** and **use_ddp**.
-
-## Training example
-Modify the settings in run.sh according to your environment.
-For example, if you set `--nproc_per_node=8` in `run.sh` and `TP_WORLD_SIZE=2` in your config file,
-data parallel size will be automatically calculated as 4.
-Thus, the parallel strategy is set to 4DP+2TP.
-
-Then in your terminal
-```shell
-sh run.sh
-```
-
-This will start ViT-S training with ImageNet.
+## Overview
+
+Vision Transformer is a class of Transformer model tailored for computer vision tasks. It was first proposed in paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) and achieved SOTA results on various tasks at that time.
+
+In our example, we are using pretrained weights of ViT loaded from HuggingFace.
+We adapt the ViT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin.
+
+## Run Demo
+
+By running the following script:
+```bash
+bash run_demo.sh
+```
+You will finetune a [ViT-base](https://huggingface.co/google/vit-base-patch16-224) model on this [dataset](https://huggingface.co/datasets/beans), with more than 8000 images of bean leaves. This dataset is for the image classification task and there are 3 labels: ['angular_leaf_spot', 'bean_rust', 'healthy'].
+
+The script can be modified if you want to try another set of hyperparameters or change to another ViT model with different size.
+
+The demo code refers to this [blog](https://huggingface.co/blog/fine-tune-vit).
+
+## Run Benchmark
+
+You can run benchmark for ViT model by running the following script:
+```bash
+bash run_benchmark.sh
+```
+The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your own set of hyperparameters for testing.
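The Boosting API workflow mentioned in the new README comes down to picking a plugin, wrapping the model and optimizer with a `Booster`, and then routing gradients through `booster.backward` instead of `loss.backward`. A minimal sketch of that pattern follows; the plugin choice and learning rate here are illustrative only, and the complete version lives in `vit_train_demo.py` further down in this diff:

```python
# Minimal Booster sketch: the chosen plugin and hyperparameters are illustrative only.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin    # or GeminiPlugin / LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam
from transformers import ViTForImageClassification

colossalai.launch_from_torch(config={})
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
optimizer = HybridAdam(model.parameters(), lr=3e-4)

plugin = TorchDDPPlugin()                               # each plugin encodes one training strategy
booster = Booster(plugin=plugin)
model, optimizer, _, _, _ = booster.boost(model, optimizer)

# inside the training loop, gradients go through the booster:
#   booster.backward(loss, optimizer)
#   optimizer.step()
```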
@@ -0,0 +1,124 @@
+from colossalai import get_default_parser
+
+
+def parse_demo_args():
+
+    parser = get_default_parser()
+    parser.add_argument(
+        "--model_name_or_path",
+        type=str,
+        default="google/vit-base-patch16-224",
+        help="Path to pretrained model or model identifier from huggingface.co/models."
+    )
+    parser.add_argument(
+        "--output_path",
+        type=str,
+        default="./output_model.bin",
+        help="The path of your saved model after finetuning."
+    )
+    parser.add_argument(
+        "--plugin",
+        type=str,
+        default="gemini",
+        help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
+    )
+    parser.add_argument(
+        "--num_epoch",
+        type=int,
+        default=3,
+        help="Number of epochs."
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=32,
+        help="Batch size (per dp group) for the training dataloader."
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=3e-4,
+        help="Initial learning rate (after the potential warmup period) to use."
+    )
+    parser.add_argument(
+        "--warmup_ratio",
+        type=float,
+        default=0.3,
+        help="Ratio of warmup steps against total training steps."
+    )
+    parser.add_argument(
+        "--weight_decay",
+        type=float,
+        default=0.1,
+        help="Weight decay to use."
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="A seed for reproducible training."
+    )
+
+    args = parser.parse_args()
+    return args
+
+
+def parse_benchmark_args():
+
+    parser = get_default_parser()
+
+    parser.add_argument(
+        "--model_name_or_path",
+        type=str,
+        default="google/vit-base-patch16-224",
+        help="Path to a pretrained model or model identifier from huggingface.co/models."
+    )
+    parser.add_argument(
+        "--plugin",
+        type=str,
+        default="gemini",
+        help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=8,
+        help="Batch size (per dp group) for the training dataloader."
+    )
+    parser.add_argument(
+        "--num_labels",
+        type=int,
+        default=10,
+        help="Number of labels for classification."
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=5e-5,
+        help="Initial learning rate (after the potential warmup period) to use."
+    )
+    parser.add_argument(
+        "--weight_decay",
+        type=float,
+        default=0.0,
+        help="Weight decay to use."
+    )
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=20,
+        help="Total number of training steps to perform."
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="A seed for reproducible training."
+    )
+    parser.add_argument(
+        "--mem_cap",
+        type=int,
+        default=0,
+        help="Limit on the usage of space for each GPU (in GB)."
+    )
+    args = parser.parse_args()
+
+    return args
@@ -1,32 +0,0 @@
-from colossalai.amp import AMP_TYPE
-
-# hyperparameters
-# BATCH_SIZE is as per GPU
-# global batch size = BATCH_SIZE x data parallel size
-BATCH_SIZE = 256
-LEARNING_RATE = 3e-3
-WEIGHT_DECAY = 0.3
-NUM_EPOCHS = 300
-WARMUP_EPOCHS = 32
-
-# model config
-IMG_SIZE = 224
-PATCH_SIZE = 16
-HIDDEN_SIZE = 384
-DEPTH = 12
-NUM_HEADS = 6
-MLP_RATIO = 4
-NUM_CLASSES = 1000
-CHECKPOINT = False
-SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1    # add 1 for cls token
-
-USE_DDP = True
-TP_WORLD_SIZE = 2
-TP_TYPE = 'row'
-parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),)
-
-fp16 = dict(mode=AMP_TYPE.NAIVE)
-clip_grad_norm = 1.0
-gradient_accumulation = 8
-
-LOG_PATH = "./log"
@@ -1,32 +0,0 @@
-from colossalai.amp import AMP_TYPE
-
-# hyperparameters
-# BATCH_SIZE is as per GPU
-# global batch size = BATCH_SIZE x data parallel size
-BATCH_SIZE = 8
-LEARNING_RATE = 3e-3
-WEIGHT_DECAY = 0.3
-NUM_EPOCHS = 3
-WARMUP_EPOCHS = 1
-
-# model config
-IMG_SIZE = 224
-PATCH_SIZE = 16
-HIDDEN_SIZE = 32
-DEPTH = 2
-NUM_HEADS = 4
-MLP_RATIO = 4
-NUM_CLASSES = 10
-CHECKPOINT = False
-SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1    # add 1 for cls token
-
-USE_DDP = True
-TP_WORLD_SIZE = 2
-TP_TYPE = 'row'
-parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),)
-
-fp16 = dict(mode=AMP_TYPE.NAIVE)
-clip_grad_norm = 1.0
-gradient_accumulation = 2
-
-LOG_PATH = "./log_ci"
@@ -0,0 +1,32 @@
+import torch
+from torch.utils.data import Dataset
+from datasets import load_dataset
+
+
+class BeansDataset(Dataset):
+
+    def __init__(self, image_processor, split='train'):
+
+        super().__init__()
+        self.image_processor = image_processor
+        self.ds = load_dataset('beans')[split]
+        self.label_names = self.ds.features['labels'].names
+        self.num_labels = len(self.label_names)
+        self.inputs = []
+        for example in self.ds:
+            self.inputs.append(self.process_example(example))
+
+    def __len__(self):
+        return len(self.inputs)
+
+    def __getitem__(self, idx):
+        return self.inputs[idx]
+
+    def process_example(self, example):
+        input = self.image_processor(example['image'], return_tensors='pt')
+        input['labels'] = example['labels']
+        return input
+
+
+def beans_collator(batch):
+    return {'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0),
+            'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64)}
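The collator above concatenates the per-example `pixel_values` tensors and gathers the integer labels into a single batch dict. A quick way to sanity-check the dataset/collator pair on its own is sketched below; the demo itself builds its dataloaders through `plugin.prepare_dataloader` in `vit_train_demo.py`, the batch size here is arbitrary, and the `beans` dataset is downloaded on first use:

```python
# Standalone sanity check of BeansDataset + beans_collator (illustrative only).
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor

processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
train_set = BeansDataset(processor, split='train')
loader = DataLoader(train_set, batch_size=4, shuffle=True, collate_fn=beans_collator)

batch = next(iter(loader))
# pixel_values: [4, 3, 224, 224], labels: [4]
print(batch['pixel_values'].shape, batch['labels'].shape)
```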
@@ -1,8 +1,6 @@
 colossalai >= 0.1.12
 torch >= 1.8.1
 numpy>=1.24.1
-timm>=0.6.12
-titans>=0.0.7
 tqdm>=4.61.2
-transformers>=4.25.1
-nvidia-dali-cuda110>=1.8.0 --extra-index-url https://developer.download.nvidia.com/compute/redist
+transformers>=4.20.0
+datasets
@@ -1,15 +0,0 @@
-export DATA=/data/scratch/imagenet/tf_records
-export OMP_NUM_THREADS=4
-
-# resume
-# CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \
-#   --nproc_per_node 4 train.py \
-#   --config configs/vit_1d_tp2.py \
-#   --resume_from checkpoint/epoch_10 \
-#   --master_port 29598 | tee ./out 2>&1
-
-# train
-CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \
-  --nproc_per_node 4 train.py \
-  --config configs/vit_1d_tp2.py \
-  --master_port 29598 | tee ./out 2>&1
@@ -0,0 +1,27 @@
+set -xe
+pip install -r requirements.txt
+
+export BS=8
+export MEMCAP=0
+export GPUNUM=1
+
+for BS in 8 32 128
+do
+for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
+do
+for GPUNUM in 1 4
+do
+
+MODEL_PATH="google/vit-base-patch16-224"
+torchrun \
+  --standalone \
+  --nproc_per_node ${GPUNUM} \
+  vit_benchmark.py \
+  --model_name_or_path ${MODEL_PATH} \
+  --mem_cap ${MEMCAP} \
+  --plugin ${PLUGIN} \
+  --batch_size ${BS}
+
+done
+done
+done
@@ -0,0 +1,44 @@
+set -xe
+pip install -r requirements.txt
+
+# model name or path
+MODEL="google/vit-base-patch16-224"
+
+# path for saving model
+OUTPUT_PATH="./output_model.bin"
+
+# plugin (training strategy)
+# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"
+PLUGIN="gemini"
+
+# number of gpus to use
+GPUNUM=4
+
+# batch size per gpu
+BS=16
+
+# learning rate
+LR="2e-4"
+
+# number of epoch
+EPOCH=3
+
+# weight decay
+WEIGHT_DECAY=0.05
+
+# ratio of warmup steps
+WARMUP_RATIO=0.3
+
+# run the script for demo
+torchrun \
+  --standalone \
+  --nproc_per_node ${GPUNUM} \
+  vit_train_demo.py \
+  --model_name_or_path ${MODEL} \
+  --output_path ${OUTPUT_PATH} \
+  --plugin ${PLUGIN} \
+  --batch_size ${BS} \
+  --num_epoch ${EPOCH} \
+  --learning_rate ${LR} \
+  --weight_decay ${WEIGHT_DECAY} \
+  --warmup_ratio ${WARMUP_RATIO}
@@ -1,9 +1,19 @@
-export OMP_NUM_THREADS=4
+set -xe
 
 pip install -r requirements.txt
 
-# train
-colossalai run \
-  --nproc_per_node 4 train.py \
-  --config configs/vit_1d_tp2_ci.py \
-  --dummy_data
+BS=8
+for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
+do
+for GPUNUM in 1 4
+do
+torchrun \
+  --standalone \
+  --nproc_per_node ${GPUNUM} \
+  vit_benchmark.py \
+  --model_name_or_path "google/vit-base-patch16-224" \
+  --plugin ${PLUGIN} \
+  --batch_size ${BS}
+
+done
+done
@@ -1,160 +0,0 @@
-import os
-import random
-
-import numpy as np
-import pytest
-import torch
-from torch.nn.parallel import DistributedDataParallel as DDP
-from vit import get_training_components
-
-import colossalai
-from colossalai.context import ParallelMode
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.parallel.data_parallel import ColoDDP
-from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils.cuda import get_current_device
-from colossalai.zero import ColoInitContext
-
-
-def set_seed(seed):
-    random.seed(seed)
-    os.environ['PYTHONHASHSEED'] = str(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.backends.cudnn.deterministic = True
-
-
-def tensor_equal(A, B):
-    return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
-
-
-def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
-    assert tensor.ndim == shard.ndim
-    if tensor.shape == shard.shape:
-        return tensor_equal(tensor, shard)
-    else:
-        dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
-        if dims_not_eq.numel() == 1:
-            # 1D shard
-            dim = dims_not_eq.item()
-            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
-        else:
-            raise
-
-
-# Only for all Linear, it's 1d_row split because Linear will be transposed when calculating.
-# But for other layers, it's 1d_col split.
-# Layernorm is not supported for now.
-# patch_embeddings.projection has nn.Conv2d
-# https://github.com/huggingface/transformers/blob/dcb08b99f44919425f8ba9be9ddcc041af8ec25e/src/transformers/models/vit/modeling_vit.py#L182
-def init_1d_row_for_linear_weight_spec(model, world_size: int):
-    pg = ProcessGroup(tp_degree=world_size)
-    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        for n, p in model.named_parameters():
-            if 'weight' in n and 'layernorm' not in n and 'embeddings.patch_embeddings.projection.weight' not in n:
-                p.set_process_group(pg)
-                p.set_tensor_spec(*spec)
-
-
-# Similarly, it's col split for Linear but row split for others.
-def init_1d_col_for_linear_weight_bias_spec(model, world_size: int):
-    pg = ProcessGroup(tp_degree=world_size)
-    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        for n, p in model.named_parameters():
-            if ('weight' in n
-                    or 'bias' in n) and 'layernorm' not in n and 'embeddings.patch_embeddings.projection' not in n:
-                p.set_process_group(pg)
-                p.set_tensor_spec(*spec)
-
-
-def check_param_equal(model, torch_model):
-    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
-        assert tensor_shard_equal(torch_p, p)
-
-
-def check_grad_equal(model, torch_model):
-    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
-        if (torch_p.grad.shape == p.grad.shape):
-            assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0) == True
-        else:
-            dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape))
-            dim = dims_not_eq.item()
-            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-            assert torch.allclose(torch_p.grad.chunk(world_size, dim)[rank], p.grad, rtol=1e-3, atol=2.0) == True
-
-
-def run_vit(init_spec_func, use_ddp):
-    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_training_components()
-    with ColoInitContext(device=get_current_device()):
-        model = model_builder()
-    model = model.cuda()
-    torch_model = model_builder().cuda()
-    if use_ddp:
-        model = ColoDDP(model)
-        torch_model = DDP(torch_model,
-                          device_ids=[gpc.get_global_rank()],
-                          process_group=gpc.get_group(ParallelMode.DATA))
-    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
-        torch_p.data.copy_(p)
-
-    world_size = torch.distributed.get_world_size()
-    init_spec_func(model, world_size)
-
-    check_param_equal(model, torch_model)
-    model.train()
-    torch_model.train()
-    set_seed(gpc.get_local_rank(ParallelMode.DATA))
-
-    optimizer = optimizer_class(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
-    torch_optimizer = optimizer_class(torch_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
-
-    for i, image_dict in enumerate(train_dataloader):
-        if use_ddp:
-            model.zero_grad()
-        else:
-            optimizer.zero_grad()
-        logits = model(image_dict['pixel_values'])
-        torch_logits = torch_model(image_dict['pixel_values'])
-        assert tensor_equal(torch_logits.logits, logits.logits)
-        loss = criterion(logits.logits, image_dict['label'])
-        torch_loss = criterion(torch_logits.logits, image_dict['label'])
-        if use_ddp:
-            model.backward(loss)
-        else:
-            loss.backward()
-        torch_loss.backward()
-        check_grad_equal(model, torch_model)
-        optimizer.step()
-        torch_optimizer.step()
-        check_param_equal(model, torch_model)
-        break
-
-
-def run_dist(rank, world_size, port, use_ddp):
-    if use_ddp and world_size == 1:
-        return
-    tp_world_size = world_size // 2 if use_ddp else world_size
-    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_vit(init_1d_row_for_linear_weight_spec, use_ddp)
-    run_vit(init_1d_col_for_linear_weight_bias_spec, use_ddp)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.parametrize('use_ddp', [False, True])
-@rerun_if_address_is_in_use()
-def test_vit(world_size, use_ddp):
-    spawn(run_dist, world_size, use_ddp=use_ddp)
-
-
-if __name__ == '__main__':
-    test_vit(1, False)
@@ -1,174 +0,0 @@
-import os
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-import torch.nn.functional as F
-from timm.models.vision_transformer import _create_vision_transformer
-from titans.dataloader.imagenet import build_dali_imagenet
-from tqdm import tqdm
-from vit import DummyDataLoader
-
-import colossalai
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn import CrossEntropyLoss
-from colossalai.nn._ops import *
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.nn.parallel.data_parallel import ColoDDP
-from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
-from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext
-
-
-def init_1d_row_for_linear_weight_spec(model, world_size: int):
-    pg = ProcessGroup(tp_degree=world_size)
-    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        for n, p in model.named_parameters():
-            if 'weight' in n and 'norm' not in n and 'patch_embed.proj.weight' not in n:
-                p.set_process_group(pg)
-                p.set_tensor_spec(*spec)
-
-
-# Similarly, it's col split for Linear but row split for others.
-def init_1d_col_for_linear_weight_bias_spec(model, world_size: int):
-    pg = ProcessGroup(tp_degree=world_size)
-    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        for n, p in model.named_parameters():
-            if ('weight' in n or 'bias' in n) and 'norm' not in n and ('patch_embed.proj.weight' not in n
-                                                                       and 'patch_embed.proj.bias' not in n):
-                p.set_process_group(pg)
-                p.set_tensor_spec(*spec)
-
-
-def init_spec_func(model, tp_type):
-    world_size = torch.distributed.get_world_size()
-    if tp_type == 'row':
-        init_1d_row_for_linear_weight_spec(model, world_size)
-    elif tp_type == 'col':
-        init_1d_col_for_linear_weight_bias_spec(model, world_size)
-    else:
-        raise NotImplemented
-
-
-def train_imagenet():
-
-    parser = colossalai.get_default_parser()
-    parser.add_argument('--resume_from', default=False, action='store_true')
-    parser.add_argument('--dummy_data', default=False, action='store_true')
-
-    args = parser.parse_args()
-    colossalai.launch_from_torch(config=args.config)
-    use_ddp = gpc.config.USE_DDP
-
-    disable_existing_loggers()
-
-    logger = get_dist_logger()
-    if hasattr(gpc.config, 'LOG_PATH'):
-        if gpc.get_global_rank() == 0:
-            log_path = gpc.config.LOG_PATH
-            if not os.path.exists(log_path):
-                os.mkdir(log_path)
-            logger.log_to_file(log_path)
-
-    logger.info('Build data loader', ranks=[0])
-    if not args.dummy_data:
-        root = os.environ['DATA']
-        train_dataloader, test_dataloader = build_dali_imagenet(root,
-                                                                train_batch_size=gpc.config.BATCH_SIZE,
-                                                                test_batch_size=gpc.config.BATCH_SIZE)
-    else:
-        train_dataloader = DummyDataLoader(length=10,
-                                           batch_size=gpc.config.BATCH_SIZE,
-                                           category=gpc.config.NUM_CLASSES,
-                                           image_size=gpc.config.IMG_SIZE,
-                                           return_dict=False)
-        test_dataloader = DummyDataLoader(length=5,
-                                          batch_size=gpc.config.BATCH_SIZE,
-                                          category=gpc.config.NUM_CLASSES,
-                                          image_size=gpc.config.IMG_SIZE,
-                                          return_dict=False)
-
-    logger.info('Build model', ranks=[0])
-
-    model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
-                        patch_size=gpc.config.PATCH_SIZE,
-                        embed_dim=gpc.config.HIDDEN_SIZE,
-                        depth=gpc.config.DEPTH,
-                        num_heads=gpc.config.NUM_HEADS,
-                        mlp_ratio=gpc.config.MLP_RATIO,
-                        num_classes=gpc.config.NUM_CLASSES,
-                        drop_rate=0.1,
-                        attn_drop_rate=0.1,
-                        weight_init='jax')
-
-    with ColoInitContext(device=get_current_device()):
-        model = _create_vision_transformer('vit_small_patch16_224', pretrained=False, **model_kwargs)
-    init_spec_func(model, gpc.config.TP_TYPE)
-
-    world_size = torch.distributed.get_world_size()
-    model = ColoDDP(module=model, process_group=ProcessGroup(tp_degree=world_size))
-    logger.info('Build criterion, optimizer, lr_scheduler', ranks=[0])
-    optimizer = HybridAdam(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
-
-    criterion = CrossEntropyLoss()
-    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
-                                           total_steps=gpc.config.NUM_EPOCHS,
-                                           warmup_steps=gpc.config.WARMUP_EPOCHS)
-
-    start_epoch = 0
-    if args.resume_from:
-        load_model = torch.load(args.resume_from + '_model.pth')
-        start_epoch = load_model['epoch']
-        model.load_state_dict(load_model['model'])
-        load_optim = torch.load(args.resume_from + '_optim_rank_{}.pth'.format(dist.get_rank()))
-        optimizer.load_state_dict(load_optim['optim'])
-
-    for epoch in range(start_epoch, gpc.config.NUM_EPOCHS):
-        model.train()
-        for index, (x, y) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False):
-            x, y = x.cuda(), y.cuda()
-            output = model(x)
-            loss = criterion(output, y)
-            loss = loss / gpc.config.gradient_accumulation
-            if use_ddp:
-                model.backward(loss)
-            else:
-                loss.backward()
-            if (index + 1) % gpc.config.gradient_accumulation == 0:
-                optimizer.step()
-                if use_ddp:
-                    model.zero_grad()
-                else:
-                    optimizer.zero_grad()
-
-        logger.info(
-            f"Finish Train Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {loss.item():.3f} lr: {optimizer.state_dict()['param_groups'][0]['lr']}",
-            ranks=[0])
-
-        model.eval()
-        test_loss = 0
-        correct = 0
-        test_sum = 0
-        with torch.no_grad():
-            for index, (x, y) in tqdm(enumerate(test_dataloader), total=len(test_dataloader), leave=False):
-                x, y = x.cuda(), y.cuda()
-                output = model(x)
-                test_loss += F.cross_entropy(output, y, reduction='sum').item()
-                pred = output.argmax(dim=1, keepdim=True)
-                correct += pred.eq(y.view_as(pred)).sum().item()
-                test_sum += y.size(0)
-
-        test_loss /= test_sum
-        logger.info(
-            f"Finish Test Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {test_loss:.3f} Accuracy: [{correct}/{test_sum}]({correct/test_sum:.3f})",
-            ranks=[0])
-
-        lr_scheduler.step()
-
-
-if __name__ == '__main__':
-    train_imagenet()
@@ -1,95 +0,0 @@
-from abc import ABC, abstractmethod
-
-import torch
-import torch.nn as nn
-from transformers import ViTConfig, ViTForImageClassification
-
-from colossalai.utils.cuda import get_current_device
-
-
-class DummyDataGenerator(ABC):
-
-    def __init__(self, length=10):
-        self.length = length
-
-    @abstractmethod
-    def generate(self):
-        pass
-
-    def __iter__(self):
-        self.step = 0
-        return self
-
-    def __next__(self):
-        if self.step < self.length:
-            self.step += 1
-            return self.generate()
-        else:
-            raise StopIteration
-
-    def __len__(self):
-        return self.length
-
-
-class DummyDataLoader(DummyDataGenerator):
-
-    def __init__(self, length=10, batch_size=4, channel=3, category=8, image_size=224, return_dict=True):
-        super().__init__(length)
-        self.batch_size = batch_size
-        self.channel = channel
-        self.category = category
-        self.image_size = image_size
-        self.return_dict = return_dict
-
-    def generate(self):
-        image_dict = {}
-        image_dict['pixel_values'] = torch.rand(
-            self.batch_size, self.channel, self.image_size, self.image_size, device=get_current_device()) * 2 - 1
-        image_dict['label'] = torch.randint(self.category, (self.batch_size,),
-                                            dtype=torch.int64,
-                                            device=get_current_device())
-        if not self.return_dict:
-            return image_dict['pixel_values'], image_dict['label']
-        return image_dict
-
-
-class ViTCVModel(nn.Module):
-
-    def __init__(self,
-                 hidden_size=768,
-                 num_hidden_layers=12,
-                 num_attention_heads=12,
-                 image_size=224,
-                 patch_size=16,
-                 num_channels=3,
-                 num_labels=8,
-                 checkpoint=False):
-        super().__init__()
-        self.checkpoint = checkpoint
-        self.model = ViTForImageClassification(
-            ViTConfig(hidden_size=hidden_size,
-                      num_hidden_layers=num_hidden_layers,
-                      num_attention_heads=num_attention_heads,
-                      image_size=image_size,
-                      patch_size=patch_size,
-                      num_channels=num_channels,
-                      num_labels=num_labels))
-        if checkpoint:
-            self.model.gradient_checkpointing_enable()
-
-    def forward(self, pixel_values):
-        return self.model(pixel_values=pixel_values)
-
-
-def vit_base_s(checkpoint=True):
-    return ViTCVModel(checkpoint=checkpoint)
-
-
-def vit_base_micro(checkpoint=True):
-    return ViTCVModel(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, checkpoint=checkpoint)
-
-
-def get_training_components():
-    trainloader = DummyDataLoader()
-    testloader = DummyDataLoader()
-    return vit_base_micro, trainloader, testloader, torch.optim.Adam, torch.nn.functional.cross_entropy
@@ -0,0 +1,129 @@
+import time
+
+import torch
+import transformers
+from transformers import ViTConfig, ViTForImageClassification
+import tqdm
+
+import colossalai
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.utils import get_current_device
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.cluster import DistCoordinator
+
+from args import parse_benchmark_args
+
+
+def format_num(num: int, bytes=False):
+    """Scale bytes to its proper format, e.g. 1253656 => '1.20MB'"""
+    factor = 1024 if bytes else 1000
+    suffix = "B" if bytes else ""
+    for unit in ["", " K", " M", " G", " T", " P"]:
+        if num < factor:
+            return f"{num:.2f}{unit}{suffix}"
+        num /= factor
+
+
+def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
+    pixel_values = torch.randn(batch_size, num_channels, height, width, device=torch.cuda.current_device(), dtype=torch.float)
+    labels = torch.randint(0, num_labels, (batch_size, ), device=torch.cuda.current_device(), dtype=torch.int64)
+    return pixel_values, labels
+
+
+def colo_memory_cap(size_in_GB):
+    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
+    cuda_capacity = colo_device_memory_capacity(get_current_device())
+    if size_in_GB * (1024**3) < cuda_capacity:
+        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
+        print(f"Limiting GPU memory usage to {size_in_GB} GB")
+
+
+def main():
+
+    args = parse_benchmark_args()
+
+    # Launch ColossalAI
+    colossalai.launch_from_torch(config={}, seed=args.seed)
+    coordinator = DistCoordinator()
+    world_size = coordinator.world_size
+
+    # Manage loggers
+    disable_existing_loggers()
+    logger = get_dist_logger()
+    if coordinator.is_master():
+        transformers.utils.logging.set_verbosity_info()
+    else:
+        transformers.utils.logging.set_verbosity_error()
+
+    # Whether to set limit on memory capacity
+    if args.mem_cap > 0:
+        colo_memory_cap(args.mem_cap)
+
+    # Build ViT model
+    config = ViTConfig.from_pretrained(args.model_name_or_path)
+    model = ViTForImageClassification(config)
+    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+    # Enable gradient checkpointing
+    model.gradient_checkpointing_enable()
+
+    # Set plugin
+    booster_kwargs = {}
+    if args.plugin == 'torch_ddp_fp16':
+        booster_kwargs['mixed_precision'] = 'fp16'
+    if args.plugin.startswith('torch_ddp'):
+        plugin = TorchDDPPlugin()
+    elif args.plugin == 'gemini':
+        plugin = GeminiPlugin(device=get_current_device(),
+                              placement_policy='cpu',
+                              pin_memory=True,
+                              strict_ddp_mode=True,
+                              initial_scale=2**5)
+    elif args.plugin == 'low_level_zero':
+        plugin = LowLevelZeroPlugin(initial_scale=2**5)
+    logger.info(f"Set plugin as {args.plugin}", ranks=[0])
+
+    # Set optimizer
+    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))
+
+    # Set booster
+    booster = Booster(plugin=plugin, **booster_kwargs)
+    model, optimizer, _, _, _ = booster.boost(model, optimizer)
+
+    # Start training.
+    logger.info(f"Start testing", ranks=[0])
+    progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())
+
+    torch.cuda.synchronize()
+    model.train()
+    start_time = time.time()
+
+    for _ in range(args.max_train_steps):
+
+        pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224)
+        optimizer.zero_grad()
+        outputs = model(pixel_values=pixel_values, labels=labels)
+        loss = outputs['loss']
+        booster.backward(loss, optimizer)
+        optimizer.step()
+
+        torch.cuda.synchronize()
+        progress_bar.update(1)
+
+    # Compute Statistics
+    end_time = time.time()
+    throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
+    max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)
+
+    logger.info(f"Testing finished, "
+                f"batch size per gpu: {args.batch_size}, "
+                f"plugin: {args.plugin}, "
+                f"throughput: {throughput}, "
+                f"maximum memory usage per gpu: {max_mem}.",
+                ranks=[0])
+
+
+if __name__ == "__main__":
+    main()
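For reference, the throughput reported at the end of the benchmark is samples per second aggregated over all ranks: with `GPUNUM=4`, `--batch_size 32`, and the default 20 measured steps finishing in, say, 8 seconds, the script would report 4 * 20 * 32 / 8 = 320 samples per second (the timing here is hypothetical and only illustrates the formula).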
@@ -0,0 +1,177 @@
+import torch
+import torch.distributed as dist
+import transformers
+from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor
+from tqdm import tqdm
+
+import colossalai
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.utils import get_current_device
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.cluster import DistCoordinator
+
+from args import parse_demo_args
+from data import BeansDataset, beans_collator
+
+
+def move_to_cuda(batch, device):
+    return {k: v.to(device) for k, v in batch.items()}
+
+
+def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
+
+    torch.cuda.synchronize()
+    model.train()
+
+    with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
+
+        for batch in pbar:
+
+            # Forward
+            optimizer.zero_grad()
+            batch = move_to_cuda(batch, torch.cuda.current_device())
+            outputs = model(**batch)
+            loss = outputs['loss']
+
+            # Backward
+            booster.backward(loss, optimizer)
+            optimizer.step()
+            lr_scheduler.step()
+
+            # Print batch loss
+            pbar.set_postfix({'loss': loss.item()})
+
+
+@torch.no_grad()
+def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator):
+
+    model.eval()
+    accum_loss = torch.zeros(1, device=get_current_device())
+    total_num = torch.zeros(1, device=get_current_device())
+    accum_correct = torch.zeros(1, device=get_current_device())
+
+    for batch in eval_dataloader:
+        batch = move_to_cuda(batch, torch.cuda.current_device())
+        outputs = model(**batch)
+        val_loss, logits = outputs[:2]
+        accum_loss += (val_loss / len(eval_dataloader))
+        if num_labels > 1:
+            preds = torch.argmax(logits, dim=1)
+        elif num_labels == 1:
+            preds = logits.squeeze()
+
+        labels = batch["labels"]
+        total_num += batch["labels"].shape[0]
+        accum_correct += (torch.sum(preds == labels))
+
+    dist.all_reduce(accum_loss)
+    dist.all_reduce(total_num)
+    dist.all_reduce(accum_correct)
+    avg_loss = "{:.4f}".format(accum_loss.item())
+    accuracy = "{:.4f}".format(accum_correct.item() / total_num.item())
+    if coordinator.is_master():
+        print(f"Evaluation result for epoch {epoch + 1}: \
+                average_loss={avg_loss}, \
+                accuracy={accuracy}.")
+
+
+def main():
+
+    args = parse_demo_args()
+
+    # Launch ColossalAI
+    colossalai.launch_from_torch(config={}, seed=args.seed)
+    coordinator = DistCoordinator()
+    world_size = coordinator.world_size
+
+    # Manage loggers
+    disable_existing_loggers()
+    logger = get_dist_logger()
+    if coordinator.is_master():
+        transformers.utils.logging.set_verbosity_info()
+    else:
+        transformers.utils.logging.set_verbosity_error()
+
+    # Prepare Dataset
+    image_processor = ViTImageProcessor.from_pretrained(args.model_name_or_path)
+    train_dataset = BeansDataset(image_processor, split='train')
+    eval_dataset = BeansDataset(image_processor, split='validation')
+
+    # Load pretrained ViT model
+    config = ViTConfig.from_pretrained(args.model_name_or_path)
+    config.num_labels = train_dataset.num_labels
+    config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)}
+    config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)}
+    model = ViTForImageClassification.from_pretrained(args.model_name_or_path,
+                                                      config=config,
+                                                      ignore_mismatched_sizes=True)
+    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+    # Enable gradient checkpointing
+    model.gradient_checkpointing_enable()
+
+    # Set plugin
+    booster_kwargs = {}
+    if args.plugin == 'torch_ddp_fp16':
+        booster_kwargs['mixed_precision'] = 'fp16'
+    if args.plugin.startswith('torch_ddp'):
+        plugin = TorchDDPPlugin()
+    elif args.plugin == 'gemini':
+        plugin = GeminiPlugin(device=get_current_device(),
+                              placement_policy='cpu',
+                              pin_memory=True,
+                              strict_ddp_mode=True,
+                              initial_scale=2**5)
+    elif args.plugin == 'low_level_zero':
+        plugin = LowLevelZeroPlugin(initial_scale=2**5)
+    logger.info(f"Set plugin as {args.plugin}", ranks=[0])
+
+    # Prepare dataloader
+    train_dataloader = plugin.prepare_dataloader(train_dataset,
+                                                 batch_size=args.batch_size,
+                                                 shuffle=True,
+                                                 drop_last=True,
+                                                 collate_fn=beans_collator)
+    eval_dataloader = plugin.prepare_dataloader(eval_dataset,
+                                                batch_size=args.batch_size,
+                                                shuffle=True,
+                                                drop_last=True,
+                                                collate_fn=beans_collator)
+
+    # Set optimizer
+    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)
+
+    # Set lr scheduler
+    total_steps = len(train_dataloader) * args.num_epoch
+    num_warmup_steps = int(args.warmup_ratio * total_steps)
+    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
+                                           total_steps=(len(train_dataloader) * args.num_epoch),
+                                           warmup_steps=num_warmup_steps)
+
+    # Set booster
+    booster = Booster(plugin=plugin, **booster_kwargs)
+    model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model,
+                                                                        optimizer=optimizer,
+                                                                        dataloader=train_dataloader,
+                                                                        lr_scheduler=lr_scheduler)
+
+    # Finetuning
+    logger.info(f"Start finetuning", ranks=[0])
+    for epoch in range(args.num_epoch):
+        train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
+        evaluate_model(epoch, model, eval_dataloader, eval_dataset.num_labels, coordinator)
+    logger.info(f"Finish finetuning", ranks=[0])
+
+    # Save the finetuned model
+    booster.save_model(model, args.output_path)
+    logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])
+
+
+if __name__ == "__main__":
+    main()
@@ -67,17 +67,8 @@ def main():
         colo_memory_cap(args.mem_cap)
 
     # Build OPT model
-    # Initialize the model under ColoInitContext if using GeminiPlugin
     config = AutoConfig.from_pretrained(args.model_name_or_path)
-    if args.plugin == 'gemini':
-        shard_pg = ProcessGroup(tp_degree=world_size)
-        default_dist_spec = ShardSpec([-1], [world_size])
-        with ColoInitContext(device='cpu',
-                             default_dist_spec=default_dist_spec,
-                             default_pg=shard_pg):
-            model = OPTForCausalLM(config)
-    else:
-        model = OPTForCausalLM(config)
+    model = OPTForCausalLM(config=config)
     logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
 
     # Enable gradient checkpointing
@@ -74,17 +74,8 @@ def main():
         transformers.utils.logging.set_verbosity_error()
 
    # Build OPT model
-    # Initialize the model under ColoInitContext if using GeminiPlugin
     config = AutoConfig.from_pretrained(args.model_name_or_path)
-    if args.plugin == 'gemini':
-        shard_pg = ProcessGroup(tp_degree=world_size)
-        default_dist_spec = ShardSpec([-1], [world_size])
-        with ColoInitContext(device='cpu',
-                             default_dist_spec=default_dist_spec,
-                             default_pg=shard_pg):
-            model = OPTForCausalLM(config)
-    else:
-        model = OPTForCausalLM(config)
+    model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
     logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
 
     # Enable gradient checkpointing
@@ -116,7 +107,9 @@ def main():
                                            collate_fn=netflix_collator)
 
     # Set optimizer
-    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))
+    optimizer = HybridAdam(model.parameters(),
+                           lr=(args.learning_rate * world_size),
+                           weight_decay=args.weight_decay)
 
     # Set lr scheduler
     total_steps = len(dataloader) * args.num_epoch