mirror of https://github.com/hpcaitech/ColossalAI
Baizhou Zhang
1 year ago
committed by
GitHub
17 changed files with 577 additions and 593 deletions
@ -1,61 +1,28 @@
|
||||
# Vision Transformer with ColoTensor |
||||
## Overview |
||||
|
||||
# Overview |
||||
Vision Transformer is a class of Transformer model tailored for computer vision tasks. It was first proposed in paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) and achieved SOTA results on various tasks at that time. |
||||
|
||||
In this example, we will run Vision Transformer with ColoTensor. |
||||
In our example, we are using pretrained weights of ViT loaded from HuggingFace. |
||||
We adapt the ViT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin. |
||||
|
||||
We use model **ViTForImageClassification** from Hugging Face [Link](https://huggingface.co/docs/transformers/model_doc/vit) for unit test. |
||||
You can change world size or decide whether use DDP in our code. |
||||
## Run Demo |
||||
|
||||
We use model **vision_transformer** from timm [Link](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) for training example. |
||||
|
||||
(2022/6/28) The default configuration now supports 2DP+2TP with gradient accumulation and checkpoint support. Zero is not supported at present. |
||||
|
||||
# Requirement |
||||
|
||||
Install colossalai version >= 0.1.11 |
||||
|
||||
## Unit test |
||||
To run unit test, you should install pytest, transformers with: |
||||
```shell |
||||
pip install pytest transformers |
||||
By running the following script: |
||||
```bash |
||||
bash run_demo.sh |
||||
``` |
||||
You will finetune a a [ViT-base](https://huggingface.co/google/vit-base-patch16-224) model on this [dataset](https://huggingface.co/datasets/beans), with more than 8000 images of bean leaves. This dataset is for image classification task and there are 3 labels: ['angular_leaf_spot', 'bean_rust', 'healthy']. |
||||
|
||||
## Training example |
||||
To run training example with ViT-S, you should install **NVIDIA DALI** from [Link](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html) for dataloader support. |
||||
You also need to install timm and titans for model/dataloader support with: |
||||
```shell |
||||
pip install timm titans |
||||
``` |
||||
The script can be modified if you want to try another set of hyperparameters or change to another ViT model with different size. |
||||
|
||||
### Data preparation |
||||
You can download the ImageNet dataset from the [ImageNet official website](https://www.image-net.org/download.php). You should get the raw images after downloading the dataset. As we use **NVIDIA DALI** to read data, we use the TFRecords dataset instead of raw Imagenet dataset. This offers better speedup to IO. If you don't have TFRecords dataset, follow [imagenet-tools](https://github.com/ver217/imagenet-tools) to build one. |
||||
The demo code refers to this [blog](https://huggingface.co/blog/fine-tune-vit). |
||||
|
||||
Before you start training, you need to set the environment variable `DATA` so that the script knows where to fetch the data for DALI dataloader. |
||||
```shell |
||||
export DATA=/path/to/ILSVRC2012 |
||||
``` |
||||
|
||||
|
||||
# How to run |
||||
## Run Benchmark |
||||
|
||||
## Unit test |
||||
In your terminal |
||||
```shell |
||||
pytest test_vit.py |
||||
You can run benchmark for ViT model by running the following script: |
||||
```bash |
||||
bash run_benchmark.sh |
||||
``` |
||||
|
||||
This will evaluate models with different **world_size** and **use_ddp**. |
||||
|
||||
## Training example |
||||
Modify the settings in run.sh according to your environment. |
||||
For example, if you set `--nproc_per_node=8` in `run.sh` and `TP_WORLD_SIZE=2` in your config file, |
||||
data parallel size will be automatically calculated as 4. |
||||
Thus, the parallel strategy is set to 4DP+2TP. |
||||
|
||||
Then in your terminal |
||||
```shell |
||||
sh run.sh |
||||
``` |
||||
|
||||
This will start ViT-S training with ImageNet. |
||||
The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your own set of hyperparameters for testing. |
@ -0,0 +1,124 @@
|
||||
from colossalai import get_default_parser |
||||
|
||||
def parse_demo_args(): |
||||
|
||||
parser = get_default_parser() |
||||
parser.add_argument( |
||||
"--model_name_or_path", |
||||
type=str, |
||||
default="google/vit-base-patch16-224", |
||||
help="Path to pretrained model or model identifier from huggingface.co/models." |
||||
) |
||||
parser.add_argument( |
||||
"--output_path", |
||||
type=str, |
||||
default="./output_model.bin", |
||||
help="The path of your saved model after finetuning." |
||||
) |
||||
parser.add_argument( |
||||
"--plugin", |
||||
type=str, |
||||
default="gemini", |
||||
help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'." |
||||
) |
||||
parser.add_argument( |
||||
"--num_epoch", |
||||
type=int, |
||||
default=3, |
||||
help="Number of epochs." |
||||
) |
||||
parser.add_argument( |
||||
"--batch_size", |
||||
type=int, |
||||
default=32, |
||||
help="Batch size (per dp group) for the training dataloader." |
||||
) |
||||
parser.add_argument( |
||||
"--learning_rate", |
||||
type=float, |
||||
default=3e-4, |
||||
help="Initial learning rate (after the potential warmup period) to use." |
||||
) |
||||
parser.add_argument( |
||||
"--warmup_ratio", |
||||
type=float, |
||||
default=0.3, |
||||
help="Ratio of warmup steps against total training steps." |
||||
) |
||||
parser.add_argument( |
||||
"--weight_decay", |
||||
type=float, |
||||
default=0.1, |
||||
help="Weight decay to use." |
||||
) |
||||
parser.add_argument( |
||||
"--seed", |
||||
type=int, |
||||
default=42, |
||||
help="A seed for reproducible training." |
||||
) |
||||
|
||||
args = parser.parse_args() |
||||
return args |
||||
|
||||
def parse_benchmark_args(): |
||||
|
||||
parser = get_default_parser() |
||||
|
||||
parser.add_argument( |
||||
"--model_name_or_path", |
||||
type=str, |
||||
default="google/vit-base-patch16-224", |
||||
help="Path to a pretrained model or model identifier from huggingface.co/models." |
||||
) |
||||
parser.add_argument( |
||||
"--plugin", |
||||
type=str, |
||||
default="gemini", |
||||
help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'." |
||||
) |
||||
parser.add_argument( |
||||
"--batch_size", |
||||
type=int, |
||||
default=8, |
||||
help="Batch size (per dp group) for the training dataloader." |
||||
) |
||||
parser.add_argument( |
||||
"--num_labels", |
||||
type=int, |
||||
default=10, |
||||
help="Number of labels for classification." |
||||
) |
||||
parser.add_argument( |
||||
"--learning_rate", |
||||
type=float, |
||||
default=5e-5, |
||||
help="Initial learning rate (after the potential warmup period) to use." |
||||
) |
||||
parser.add_argument( |
||||
"--weight_decay", |
||||
type=float, |
||||
default=0.0, |
||||
help="Weight decay to use." |
||||
) |
||||
parser.add_argument( |
||||
"--max_train_steps", |
||||
type=int, |
||||
default=20, |
||||
help="Total number of training steps to perform." |
||||
) |
||||
parser.add_argument( |
||||
"--seed", |
||||
type=int, |
||||
default=42, |
||||
help="A seed for reproducible training." |
||||
) |
||||
parser.add_argument( |
||||
"--mem_cap", |
||||
type=int, |
||||
default=0, |
||||
help="Limit on the usage of space for each GPU (in GB)." |
||||
) |
||||
args = parser.parse_args() |
||||
|
||||
return args |
@ -1,32 +0,0 @@
|
||||
from colossalai.amp import AMP_TYPE |
||||
|
||||
# hyperparameters |
||||
# BATCH_SIZE is as per GPU |
||||
# global batch size = BATCH_SIZE x data parallel size |
||||
BATCH_SIZE = 256 |
||||
LEARNING_RATE = 3e-3 |
||||
WEIGHT_DECAY = 0.3 |
||||
NUM_EPOCHS = 300 |
||||
WARMUP_EPOCHS = 32 |
||||
|
||||
# model config |
||||
IMG_SIZE = 224 |
||||
PATCH_SIZE = 16 |
||||
HIDDEN_SIZE = 384 |
||||
DEPTH = 12 |
||||
NUM_HEADS = 6 |
||||
MLP_RATIO = 4 |
||||
NUM_CLASSES = 1000 |
||||
CHECKPOINT = False |
||||
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token |
||||
|
||||
USE_DDP = True |
||||
TP_WORLD_SIZE = 2 |
||||
TP_TYPE = 'row' |
||||
parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),) |
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE) |
||||
clip_grad_norm = 1.0 |
||||
gradient_accumulation = 8 |
||||
|
||||
LOG_PATH = "./log" |
@ -1,32 +0,0 @@
|
||||
from colossalai.amp import AMP_TYPE |
||||
|
||||
# hyperparameters |
||||
# BATCH_SIZE is as per GPU |
||||
# global batch size = BATCH_SIZE x data parallel size |
||||
BATCH_SIZE = 8 |
||||
LEARNING_RATE = 3e-3 |
||||
WEIGHT_DECAY = 0.3 |
||||
NUM_EPOCHS = 3 |
||||
WARMUP_EPOCHS = 1 |
||||
|
||||
# model config |
||||
IMG_SIZE = 224 |
||||
PATCH_SIZE = 16 |
||||
HIDDEN_SIZE = 32 |
||||
DEPTH = 2 |
||||
NUM_HEADS = 4 |
||||
MLP_RATIO = 4 |
||||
NUM_CLASSES = 10 |
||||
CHECKPOINT = False |
||||
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token |
||||
|
||||
USE_DDP = True |
||||
TP_WORLD_SIZE = 2 |
||||
TP_TYPE = 'row' |
||||
parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),) |
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE) |
||||
clip_grad_norm = 1.0 |
||||
gradient_accumulation = 2 |
||||
|
||||
LOG_PATH = "./log_ci" |
@ -0,0 +1,32 @@
|
||||
import torch |
||||
from torch.utils.data import Dataset |
||||
from datasets import load_dataset |
||||
|
||||
class BeansDataset(Dataset): |
||||
|
||||
def __init__(self, image_processor, split='train'): |
||||
|
||||
super().__init__() |
||||
self.image_processor = image_processor |
||||
self.ds = load_dataset('beans')[split] |
||||
self.label_names = self.ds.features['labels'].names |
||||
self.num_labels = len(self.label_names) |
||||
self.inputs = [] |
||||
for example in self.ds: |
||||
self.inputs.append(self.process_example(example)) |
||||
|
||||
def __len__(self): |
||||
return len(self.inputs) |
||||
|
||||
def __getitem__(self, idx): |
||||
return self.inputs[idx] |
||||
|
||||
def process_example(self, example): |
||||
input = self.image_processor(example['image'], return_tensors='pt') |
||||
input['labels'] = example['labels'] |
||||
return input |
||||
|
||||
|
||||
def beans_collator(batch): |
||||
return {'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0), |
||||
'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64)} |
@ -1,8 +1,6 @@
|
||||
colossalai >= 0.1.12 |
||||
torch >= 1.8.1 |
||||
numpy>=1.24.1 |
||||
timm>=0.6.12 |
||||
titans>=0.0.7 |
||||
tqdm>=4.61.2 |
||||
transformers>=4.25.1 |
||||
nvidia-dali-cuda110>=1.8.0 --extra-index-url https://developer.download.nvidia.com/compute/redist |
||||
transformers>=4.20.0 |
||||
datasets |
@ -1,15 +0,0 @@
|
||||
export DATA=/data/scratch/imagenet/tf_records |
||||
export OMP_NUM_THREADS=4 |
||||
|
||||
# resume |
||||
# CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \ |
||||
# --nproc_per_node 4 train.py \ |
||||
# --config configs/vit_1d_tp2.py \ |
||||
# --resume_from checkpoint/epoch_10 \ |
||||
# --master_port 29598 | tee ./out 2>&1 |
||||
|
||||
# train |
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \ |
||||
--nproc_per_node 4 train.py \ |
||||
--config configs/vit_1d_tp2.py \ |
||||
--master_port 29598 | tee ./out 2>&1 |
@ -0,0 +1,27 @@
|
||||
set -xe |
||||
pip install -r requirements.txt |
||||
|
||||
export BS=8 |
||||
export MEMCAP=0 |
||||
export GPUNUM=1 |
||||
|
||||
for BS in 8 32 128 |
||||
do |
||||
for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" |
||||
do |
||||
for GPUNUM in 1 4 |
||||
do |
||||
|
||||
MODEL_PATH="google/vit-base-patch16-224" |
||||
torchrun \ |
||||
--standalone \ |
||||
--nproc_per_node ${GPUNUM} \ |
||||
vit_benchmark.py \ |
||||
--model_name_or_path ${MODEL_PATH} \ |
||||
--mem_cap ${MEMCAP} \ |
||||
--plugin ${PLUGIN} \ |
||||
--batch_size ${BS} |
||||
|
||||
done |
||||
done |
||||
done |
@ -0,0 +1,44 @@
|
||||
set -xe |
||||
pip install -r requirements.txt |
||||
|
||||
# model name or path |
||||
MODEL="google/vit-base-patch16-224" |
||||
|
||||
# path for saving model |
||||
OUTPUT_PATH="./output_model.bin" |
||||
|
||||
# plugin(training strategy) |
||||
# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini" |
||||
PLUGIN="gemini" |
||||
|
||||
# number of gpus to use |
||||
GPUNUM=4 |
||||
|
||||
# batch size per gpu |
||||
BS=16 |
||||
|
||||
# learning rate |
||||
LR="2e-4" |
||||
|
||||
# number of epoch |
||||
EPOCH=3 |
||||
|
||||
# weight decay |
||||
WEIGHT_DECAY=0.05 |
||||
|
||||
# ratio of warmup steps |
||||
WARMUP_RATIO=0.3 |
||||
|
||||
# run the script for demo |
||||
torchrun \ |
||||
--standalone \ |
||||
--nproc_per_node ${GPUNUM} \ |
||||
vit_train_demo.py \ |
||||
--model_name_or_path ${MODEL} \ |
||||
--output_path ${OUTPUT_PATH} \ |
||||
--plugin ${PLUGIN} \ |
||||
--batch_size ${BS} \ |
||||
--num_epoch ${EPOCH} \ |
||||
--learning_rate ${LR} \ |
||||
--weight_decay ${WEIGHT_DECAY} \ |
||||
--warmup_ratio ${WARMUP_RATIO} |
@ -1,9 +1,19 @@
|
||||
export OMP_NUM_THREADS=4 |
||||
|
||||
set -xe |
||||
pip install -r requirements.txt |
||||
|
||||
# train |
||||
colossalai run \ |
||||
--nproc_per_node 4 train.py \ |
||||
--config configs/vit_1d_tp2_ci.py \ |
||||
--dummy_data |
||||
BS=8 |
||||
for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" |
||||
do |
||||
for GPUNUM in 1 4 |
||||
do |
||||
|
||||
torchrun \ |
||||
--standalone \ |
||||
--nproc_per_node ${GPUNUM} \ |
||||
vit_benchmark.py \ |
||||
--model_name_or_path "google/vit-base-patch16-224" \ |
||||
--plugin ${PLUGIN} \ |
||||
--batch_size ${BS} |
||||
|
||||
done |
||||
done |
||||
|
@ -1,160 +0,0 @@
|
||||
import os |
||||
import random |
||||
|
||||
import numpy as np |
||||
import pytest |
||||
import torch |
||||
from torch.nn.parallel import DistributedDataParallel as DDP |
||||
from vit import get_training_components |
||||
|
||||
import colossalai |
||||
from colossalai.context import ParallelMode |
||||
from colossalai.context.parallel_mode import ParallelMode |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.nn.parallel.data_parallel import ColoDDP |
||||
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec |
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn |
||||
from colossalai.utils.cuda import get_current_device |
||||
from colossalai.zero import ColoInitContext |
||||
|
||||
|
||||
def set_seed(seed): |
||||
random.seed(seed) |
||||
os.environ['PYTHONHASHSEED'] = str(seed) |
||||
np.random.seed(seed) |
||||
torch.manual_seed(seed) |
||||
torch.cuda.manual_seed(seed) |
||||
torch.backends.cudnn.deterministic = True |
||||
|
||||
|
||||
def tensor_equal(A, B): |
||||
return torch.allclose(A, B, rtol=1e-3, atol=1e-1) |
||||
|
||||
|
||||
def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor): |
||||
assert tensor.ndim == shard.ndim |
||||
if tensor.shape == shard.shape: |
||||
return tensor_equal(tensor, shard) |
||||
else: |
||||
dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape)) |
||||
if dims_not_eq.numel() == 1: |
||||
# 1D shard |
||||
dim = dims_not_eq.item() |
||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) |
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) |
||||
return tensor_equal(tensor.chunk(world_size, dim)[rank], shard) |
||||
else: |
||||
raise |
||||
|
||||
|
||||
# Only for all Linear, it's 1d_row split because Linear will be transposed when calculating. |
||||
# But for other layers, it's 1d_col split. |
||||
# Layernorm is not supported for now. |
||||
# patch_embeddings.projection has nn.Conv2d |
||||
# https://github.com/huggingface/transformers/blob/dcb08b99f44919425f8ba9be9ddcc041af8ec25e/src/transformers/models/vit/modeling_vit.py#L182 |
||||
def init_1d_row_for_linear_weight_spec(model, world_size: int): |
||||
pg = ProcessGroup(tp_degree=world_size) |
||||
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) |
||||
with DistSpecManager.no_grad(): |
||||
for n, p in model.named_parameters(): |
||||
if 'weight' in n and 'layernorm' not in n and 'embeddings.patch_embeddings.projection.weight' not in n: |
||||
p.set_process_group(pg) |
||||
p.set_tensor_spec(*spec) |
||||
|
||||
|
||||
# Similarly, it's col split for Linear but row split for others. |
||||
def init_1d_col_for_linear_weight_bias_spec(model, world_size: int): |
||||
pg = ProcessGroup(tp_degree=world_size) |
||||
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) |
||||
with DistSpecManager.no_grad(): |
||||
for n, p in model.named_parameters(): |
||||
if ('weight' in n |
||||
or 'bias' in n) and 'layernorm' not in n and 'embeddings.patch_embeddings.projection' not in n: |
||||
p.set_process_group(pg) |
||||
p.set_tensor_spec(*spec) |
||||
|
||||
|
||||
def check_param_equal(model, torch_model): |
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()): |
||||
assert tensor_shard_equal(torch_p, p) |
||||
|
||||
|
||||
def check_grad_equal(model, torch_model): |
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()): |
||||
if (torch_p.grad.shape == p.grad.shape): |
||||
assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0) == True |
||||
else: |
||||
dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape)) |
||||
dim = dims_not_eq.item() |
||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) |
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) |
||||
assert torch.allclose(torch_p.grad.chunk(world_size, dim)[rank], p.grad, rtol=1e-3, atol=2.0) == True |
||||
|
||||
|
||||
def run_vit(init_spec_func, use_ddp): |
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_training_components() |
||||
with ColoInitContext(device=get_current_device()): |
||||
model = model_builder() |
||||
model = model.cuda() |
||||
torch_model = model_builder().cuda() |
||||
if use_ddp: |
||||
model = ColoDDP(model) |
||||
torch_model = DDP(torch_model, |
||||
device_ids=[gpc.get_global_rank()], |
||||
process_group=gpc.get_group(ParallelMode.DATA)) |
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()): |
||||
torch_p.data.copy_(p) |
||||
|
||||
world_size = torch.distributed.get_world_size() |
||||
init_spec_func(model, world_size) |
||||
|
||||
check_param_equal(model, torch_model) |
||||
model.train() |
||||
torch_model.train() |
||||
set_seed(gpc.get_local_rank(ParallelMode.DATA)) |
||||
|
||||
optimizer = optimizer_class(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) |
||||
torch_optimizer = optimizer_class(torch_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) |
||||
|
||||
for i, image_dict in enumerate(train_dataloader): |
||||
if use_ddp: |
||||
model.zero_grad() |
||||
else: |
||||
optimizer.zero_grad() |
||||
logits = model(image_dict['pixel_values']) |
||||
torch_logits = torch_model(image_dict['pixel_values']) |
||||
assert tensor_equal(torch_logits.logits, logits.logits) |
||||
loss = criterion(logits.logits, image_dict['label']) |
||||
torch_loss = criterion(torch_logits.logits, image_dict['label']) |
||||
if use_ddp: |
||||
model.backward(loss) |
||||
else: |
||||
loss.backward() |
||||
torch_loss.backward() |
||||
check_grad_equal(model, torch_model) |
||||
optimizer.step() |
||||
torch_optimizer.step() |
||||
check_param_equal(model, torch_model) |
||||
break |
||||
|
||||
|
||||
def run_dist(rank, world_size, port, use_ddp): |
||||
if use_ddp and world_size == 1: |
||||
return |
||||
tp_world_size = world_size // 2 if use_ddp else world_size |
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),)) |
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') |
||||
run_vit(init_1d_row_for_linear_weight_spec, use_ddp) |
||||
run_vit(init_1d_col_for_linear_weight_bias_spec, use_ddp) |
||||
|
||||
|
||||
@pytest.mark.dist |
||||
@pytest.mark.parametrize('world_size', [1, 4]) |
||||
@pytest.mark.parametrize('use_ddp', [False, True]) |
||||
@rerun_if_address_is_in_use() |
||||
def test_vit(world_size, use_ddp): |
||||
spawn(run_dist, world_size, use_ddp=use_ddp) |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
test_vit(1, False) |
@ -1,174 +0,0 @@
|
||||
import os |
||||
|
||||
import torch |
||||
import torch.distributed as dist |
||||
import torch.nn as nn |
||||
import torch.nn.functional as F |
||||
from timm.models.vision_transformer import _create_vision_transformer |
||||
from titans.dataloader.imagenet import build_dali_imagenet |
||||
from tqdm import tqdm |
||||
from vit import DummyDataLoader |
||||
|
||||
import colossalai |
||||
from colossalai.core import global_context as gpc |
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger |
||||
from colossalai.nn import CrossEntropyLoss |
||||
from colossalai.nn._ops import * |
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR |
||||
from colossalai.nn.optimizer import HybridAdam |
||||
from colossalai.nn.parallel.data_parallel import ColoDDP |
||||
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec |
||||
from colossalai.utils import get_current_device |
||||
from colossalai.zero import ColoInitContext |
||||
|
||||
|
||||
def init_1d_row_for_linear_weight_spec(model, world_size: int): |
||||
pg = ProcessGroup(tp_degree=world_size) |
||||
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) |
||||
with DistSpecManager.no_grad(): |
||||
for n, p in model.named_parameters(): |
||||
if 'weight' in n and 'norm' not in n and 'patch_embed.proj.weight' not in n: |
||||
p.set_process_group(pg) |
||||
p.set_tensor_spec(*spec) |
||||
|
||||
|
||||
# Similarly, it's col split for Linear but row split for others. |
||||
def init_1d_col_for_linear_weight_bias_spec(model, world_size: int): |
||||
pg = ProcessGroup(tp_degree=world_size) |
||||
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) |
||||
with DistSpecManager.no_grad(): |
||||
for n, p in model.named_parameters(): |
||||
if ('weight' in n or 'bias' in n) and 'norm' not in n and ('patch_embed.proj.weight' not in n |
||||
and 'patch_embed.proj.bias' not in n): |
||||
p.set_process_group(pg) |
||||
p.set_tensor_spec(*spec) |
||||
|
||||
|
||||
def init_spec_func(model, tp_type): |
||||
world_size = torch.distributed.get_world_size() |
||||
if tp_type == 'row': |
||||
init_1d_row_for_linear_weight_spec(model, world_size) |
||||
elif tp_type == 'col': |
||||
init_1d_col_for_linear_weight_bias_spec(model, world_size) |
||||
else: |
||||
raise NotImplemented |
||||
|
||||
|
||||
def train_imagenet(): |
||||
|
||||
parser = colossalai.get_default_parser() |
||||
parser.add_argument('--resume_from', default=False, action='store_true') |
||||
parser.add_argument('--dummy_data', default=False, action='store_true') |
||||
|
||||
args = parser.parse_args() |
||||
colossalai.launch_from_torch(config=args.config) |
||||
use_ddp = gpc.config.USE_DDP |
||||
|
||||
disable_existing_loggers() |
||||
|
||||
logger = get_dist_logger() |
||||
if hasattr(gpc.config, 'LOG_PATH'): |
||||
if gpc.get_global_rank() == 0: |
||||
log_path = gpc.config.LOG_PATH |
||||
if not os.path.exists(log_path): |
||||
os.mkdir(log_path) |
||||
logger.log_to_file(log_path) |
||||
|
||||
logger.info('Build data loader', ranks=[0]) |
||||
if not args.dummy_data: |
||||
root = os.environ['DATA'] |
||||
train_dataloader, test_dataloader = build_dali_imagenet(root, |
||||
train_batch_size=gpc.config.BATCH_SIZE, |
||||
test_batch_size=gpc.config.BATCH_SIZE) |
||||
else: |
||||
train_dataloader = DummyDataLoader(length=10, |
||||
batch_size=gpc.config.BATCH_SIZE, |
||||
category=gpc.config.NUM_CLASSES, |
||||
image_size=gpc.config.IMG_SIZE, |
||||
return_dict=False) |
||||
test_dataloader = DummyDataLoader(length=5, |
||||
batch_size=gpc.config.BATCH_SIZE, |
||||
category=gpc.config.NUM_CLASSES, |
||||
image_size=gpc.config.IMG_SIZE, |
||||
return_dict=False) |
||||
|
||||
logger.info('Build model', ranks=[0]) |
||||
|
||||
model_kwargs = dict(img_size=gpc.config.IMG_SIZE, |
||||
patch_size=gpc.config.PATCH_SIZE, |
||||
embed_dim=gpc.config.HIDDEN_SIZE, |
||||
depth=gpc.config.DEPTH, |
||||
num_heads=gpc.config.NUM_HEADS, |
||||
mlp_ratio=gpc.config.MLP_RATIO, |
||||
num_classes=gpc.config.NUM_CLASSES, |
||||
drop_rate=0.1, |
||||
attn_drop_rate=0.1, |
||||
weight_init='jax') |
||||
|
||||
with ColoInitContext(device=get_current_device()): |
||||
model = _create_vision_transformer('vit_small_patch16_224', pretrained=False, **model_kwargs) |
||||
init_spec_func(model, gpc.config.TP_TYPE) |
||||
|
||||
world_size = torch.distributed.get_world_size() |
||||
model = ColoDDP(module=model, process_group=ProcessGroup(tp_degree=world_size)) |
||||
logger.info('Build criterion, optimizer, lr_scheduler', ranks=[0]) |
||||
optimizer = HybridAdam(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) |
||||
|
||||
criterion = CrossEntropyLoss() |
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, |
||||
total_steps=gpc.config.NUM_EPOCHS, |
||||
warmup_steps=gpc.config.WARMUP_EPOCHS) |
||||
|
||||
start_epoch = 0 |
||||
if args.resume_from: |
||||
load_model = torch.load(args.resume_from + '_model.pth') |
||||
start_epoch = load_model['epoch'] |
||||
model.load_state_dict(load_model['model']) |
||||
load_optim = torch.load(args.resume_from + '_optim_rank_{}.pth'.format(dist.get_rank())) |
||||
optimizer.load_state_dict(load_optim['optim']) |
||||
|
||||
for epoch in range(start_epoch, gpc.config.NUM_EPOCHS): |
||||
model.train() |
||||
for index, (x, y) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False): |
||||
x, y = x.cuda(), y.cuda() |
||||
output = model(x) |
||||
loss = criterion(output, y) |
||||
loss = loss / gpc.config.gradient_accumulation |
||||
if use_ddp: |
||||
model.backward(loss) |
||||
else: |
||||
loss.backward() |
||||
if (index + 1) % gpc.config.gradient_accumulation == 0: |
||||
optimizer.step() |
||||
if use_ddp: |
||||
model.zero_grad() |
||||
else: |
||||
optimizer.zero_grad() |
||||
|
||||
logger.info( |
||||
f"Finish Train Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {loss.item():.3f} lr: {optimizer.state_dict()['param_groups'][0]['lr']}", |
||||
ranks=[0]) |
||||
|
||||
model.eval() |
||||
test_loss = 0 |
||||
correct = 0 |
||||
test_sum = 0 |
||||
with torch.no_grad(): |
||||
for index, (x, y) in tqdm(enumerate(test_dataloader), total=len(test_dataloader), leave=False): |
||||
x, y = x.cuda(), y.cuda() |
||||
output = model(x) |
||||
test_loss += F.cross_entropy(output, y, reduction='sum').item() |
||||
pred = output.argmax(dim=1, keepdim=True) |
||||
correct += pred.eq(y.view_as(pred)).sum().item() |
||||
test_sum += y.size(0) |
||||
|
||||
test_loss /= test_sum |
||||
logger.info( |
||||
f"Finish Test Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {test_loss:.3f} Accuracy: [{correct}/{test_sum}]({correct/test_sum:.3f})", |
||||
ranks=[0]) |
||||
|
||||
lr_scheduler.step() |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
train_imagenet() |
@ -1,95 +0,0 @@
|
||||
from abc import ABC, abstractmethod |
||||
|
||||
import torch |
||||
import torch.nn as nn |
||||
from transformers import ViTConfig, ViTForImageClassification |
||||
|
||||
from colossalai.utils.cuda import get_current_device |
||||
|
||||
|
||||
class DummyDataGenerator(ABC): |
||||
|
||||
def __init__(self, length=10): |
||||
self.length = length |
||||
|
||||
@abstractmethod |
||||
def generate(self): |
||||
pass |
||||
|
||||
def __iter__(self): |
||||
self.step = 0 |
||||
return self |
||||
|
||||
def __next__(self): |
||||
if self.step < self.length: |
||||
self.step += 1 |
||||
return self.generate() |
||||
else: |
||||
raise StopIteration |
||||
|
||||
def __len__(self): |
||||
return self.length |
||||
|
||||
|
||||
class DummyDataLoader(DummyDataGenerator): |
||||
|
||||
def __init__(self, length=10, batch_size=4, channel=3, category=8, image_size=224, return_dict=True): |
||||
super().__init__(length) |
||||
self.batch_size = batch_size |
||||
self.channel = channel |
||||
self.category = category |
||||
self.image_size = image_size |
||||
self.return_dict = return_dict |
||||
|
||||
def generate(self): |
||||
image_dict = {} |
||||
image_dict['pixel_values'] = torch.rand( |
||||
self.batch_size, self.channel, self.image_size, self.image_size, device=get_current_device()) * 2 - 1 |
||||
image_dict['label'] = torch.randint(self.category, (self.batch_size,), |
||||
dtype=torch.int64, |
||||
device=get_current_device()) |
||||
if not self.return_dict: |
||||
return image_dict['pixel_values'], image_dict['label'] |
||||
return image_dict |
||||
|
||||
|
||||
class ViTCVModel(nn.Module): |
||||
|
||||
def __init__(self, |
||||
hidden_size=768, |
||||
num_hidden_layers=12, |
||||
num_attention_heads=12, |
||||
image_size=224, |
||||
patch_size=16, |
||||
num_channels=3, |
||||
num_labels=8, |
||||
checkpoint=False): |
||||
super().__init__() |
||||
self.checkpoint = checkpoint |
||||
self.model = ViTForImageClassification( |
||||
ViTConfig(hidden_size=hidden_size, |
||||
num_hidden_layers=num_hidden_layers, |
||||
num_attention_heads=num_attention_heads, |
||||
image_size=image_size, |
||||
patch_size=patch_size, |
||||
num_channels=num_channels, |
||||
num_labels=num_labels)) |
||||
if checkpoint: |
||||
self.model.gradient_checkpointing_enable() |
||||
|
||||
def forward(self, pixel_values): |
||||
return self.model(pixel_values=pixel_values) |
||||
|
||||
|
||||
def vit_base_s(checkpoint=True): |
||||
return ViTCVModel(checkpoint=checkpoint) |
||||
|
||||
|
||||
def vit_base_micro(checkpoint=True): |
||||
return ViTCVModel(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, checkpoint=checkpoint) |
||||
|
||||
|
||||
def get_training_components(): |
||||
trainloader = DummyDataLoader() |
||||
testloader = DummyDataLoader() |
||||
return vit_base_micro, trainloader, testloader, torch.optim.Adam, torch.nn.functional.cross_entropy |
@ -0,0 +1,129 @@
|
||||
import time |
||||
|
||||
import torch |
||||
import transformers |
||||
from transformers import ViTConfig, ViTForImageClassification |
||||
import tqdm |
||||
|
||||
import colossalai |
||||
from colossalai.nn.optimizer import HybridAdam |
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger |
||||
from colossalai.utils import get_current_device |
||||
from colossalai.booster import Booster |
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin |
||||
from colossalai.cluster import DistCoordinator |
||||
|
||||
from args import parse_benchmark_args |
||||
|
||||
def format_num(num: int, bytes=False): |
||||
"""Scale bytes to its proper format, e.g. 1253656 => '1.20MB'""" |
||||
factor = 1024 if bytes else 1000 |
||||
suffix = "B" if bytes else "" |
||||
for unit in ["", " K", " M", " G", " T", " P"]: |
||||
if num < factor: |
||||
return f"{num:.2f}{unit}{suffix}" |
||||
num /= factor |
||||
|
||||
|
||||
def get_data(batch_size, num_labels, num_channels=3, height=224, width=224): |
||||
pixel_values = torch.randn(batch_size, num_channels, height, width, device=torch.cuda.current_device(), dtype=torch.float) |
||||
labels = torch.randint(0, num_labels, (batch_size, ), device=torch.cuda.current_device(), dtype=torch.int64) |
||||
return pixel_values, labels |
||||
|
||||
|
||||
def colo_memory_cap(size_in_GB): |
||||
from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device |
||||
cuda_capacity = colo_device_memory_capacity(get_current_device()) |
||||
if size_in_GB * (1024**3) < cuda_capacity: |
||||
colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) |
||||
print(f"Limiting GPU memory usage to {size_in_GB} GB") |
||||
|
||||
|
||||
def main(): |
||||
|
||||
args = parse_benchmark_args() |
||||
|
||||
# Launch ColossalAI |
||||
colossalai.launch_from_torch(config={}, seed=args.seed) |
||||
coordinator = DistCoordinator() |
||||
world_size = coordinator.world_size |
||||
|
||||
# Manage loggers |
||||
disable_existing_loggers() |
||||
logger = get_dist_logger() |
||||
if coordinator.is_master(): |
||||
transformers.utils.logging.set_verbosity_info() |
||||
else: |
||||
transformers.utils.logging.set_verbosity_error() |
||||
|
||||
# Whether to set limit on memory capacity |
||||
if args.mem_cap > 0: |
||||
colo_memory_cap(args.mem_cap) |
||||
|
||||
# Build ViT model |
||||
config = ViTConfig.from_pretrained(args.model_name_or_path) |
||||
model = ViTForImageClassification(config) |
||||
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) |
||||
|
||||
# Enable gradient checkpointing |
||||
model.gradient_checkpointing_enable() |
||||
|
||||
# Set plugin |
||||
booster_kwargs = {} |
||||
if args.plugin == 'torch_ddp_fp16': |
||||
booster_kwargs['mixed_precision'] = 'fp16' |
||||
if args.plugin.startswith('torch_ddp'): |
||||
plugin = TorchDDPPlugin() |
||||
elif args.plugin == 'gemini': |
||||
plugin = GeminiPlugin(device=get_current_device(), |
||||
placement_policy='cpu', |
||||
pin_memory=True, |
||||
strict_ddp_mode=True, |
||||
initial_scale=2**5) |
||||
elif args.plugin == 'low_level_zero': |
||||
plugin = LowLevelZeroPlugin(initial_scale=2**5) |
||||
logger.info(f"Set plugin as {args.plugin}", ranks=[0]) |
||||
|
||||
# Set optimizer |
||||
optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size)) |
||||
|
||||
# Set booster |
||||
booster = Booster(plugin=plugin, **booster_kwargs) |
||||
model, optimizer, _, _, _ = booster.boost(model, optimizer) |
||||
|
||||
|
||||
# Start training. |
||||
logger.info(f"Start testing", ranks=[0]) |
||||
progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master()) |
||||
|
||||
torch.cuda.synchronize() |
||||
model.train() |
||||
start_time = time.time() |
||||
|
||||
for _ in range(args.max_train_steps): |
||||
|
||||
pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224) |
||||
optimizer.zero_grad() |
||||
outputs = model(pixel_values=pixel_values, labels=labels) |
||||
loss = outputs['loss'] |
||||
booster.backward(loss, optimizer) |
||||
optimizer.step() |
||||
|
||||
torch.cuda.synchronize() |
||||
progress_bar.update(1) |
||||
|
||||
# Compute Statistics |
||||
end_time = time.time() |
||||
throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time)) |
||||
max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True) |
||||
|
||||
logger.info(f"Testing finished, " |
||||
f"batch size per gpu: {args.batch_size}, " |
||||
f"plugin: {args.plugin}, " |
||||
f"throughput: {throughput}, " |
||||
f"maximum memory usage per gpu: {max_mem}.", |
||||
ranks=[0]) |
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
main() |
@ -0,0 +1,177 @@
|
||||
import torch |
||||
import torch.distributed as dist |
||||
import transformers |
||||
from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor |
||||
from tqdm import tqdm |
||||
|
||||
import colossalai |
||||
from colossalai.nn.optimizer import HybridAdam |
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR |
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger |
||||
from colossalai.utils import get_current_device |
||||
from colossalai.booster import Booster |
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin |
||||
from colossalai.cluster import DistCoordinator |
||||
|
||||
from args import parse_demo_args |
||||
from data import BeansDataset, beans_collator |
||||
|
||||
|
||||
def move_to_cuda(batch, device): |
||||
return {k: v.to(device) for k, v in batch.items()} |
||||
|
||||
|
||||
def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator): |
||||
|
||||
torch.cuda.synchronize() |
||||
model.train() |
||||
|
||||
with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar: |
||||
|
||||
for batch in pbar: |
||||
|
||||
# Foward |
||||
optimizer.zero_grad() |
||||
batch = move_to_cuda(batch, torch.cuda.current_device()) |
||||
outputs = model(**batch) |
||||
loss = outputs['loss'] |
||||
|
||||
# Backward |
||||
booster.backward(loss, optimizer) |
||||
optimizer.step() |
||||
lr_scheduler.step() |
||||
|
||||
# Print batch loss |
||||
pbar.set_postfix({'loss': loss.item()}) |
||||
|
||||
|
||||
@torch.no_grad() |
||||
def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator): |
||||
|
||||
model.eval() |
||||
accum_loss = torch.zeros(1, device=get_current_device()) |
||||
total_num = torch.zeros(1, device=get_current_device()) |
||||
accum_correct = torch.zeros(1, device=get_current_device()) |
||||
|
||||
for batch in eval_dataloader: |
||||
batch = move_to_cuda(batch, torch.cuda.current_device()) |
||||
outputs = model(**batch) |
||||
val_loss, logits = outputs[:2] |
||||
accum_loss += (val_loss / len(eval_dataloader)) |
||||
if num_labels > 1: |
||||
preds = torch.argmax(logits, dim=1) |
||||
elif num_labels == 1: |
||||
preds = logits.squeeze() |
||||
|
||||
labels = batch["labels"] |
||||
total_num += batch["labels"].shape[0] |
||||
accum_correct += (torch.sum(preds == labels)) |
||||
|
||||
dist.all_reduce(accum_loss) |
||||
dist.all_reduce(total_num) |
||||
dist.all_reduce(accum_correct) |
||||
avg_loss = "{:.4f}".format(accum_loss.item()) |
||||
accuracy = "{:.4f}".format(accum_correct.item() / total_num.item()) |
||||
if coordinator.is_master(): |
||||
print(f"Evaluation result for epoch {epoch + 1}: \ |
||||
average_loss={avg_loss}, \ |
||||
accuracy={accuracy}.") |
||||
|
||||
|
||||
|
||||
|
||||
def main(): |
||||
|
||||
args = parse_demo_args() |
||||
|
||||
# Launch ColossalAI |
||||
colossalai.launch_from_torch(config={}, seed=args.seed) |
||||
coordinator = DistCoordinator() |
||||
world_size = coordinator.world_size |
||||
|
||||
# Manage loggers |
||||
disable_existing_loggers() |
||||
logger = get_dist_logger() |
||||
if coordinator.is_master(): |
||||
transformers.utils.logging.set_verbosity_info() |
||||
else: |
||||
transformers.utils.logging.set_verbosity_error() |
||||
|
||||
# Prepare Dataset |
||||
image_processor = ViTImageProcessor.from_pretrained(args.model_name_or_path) |
||||
train_dataset = BeansDataset(image_processor, split='train') |
||||
eval_dataset = BeansDataset(image_processor, split='validation') |
||||
|
||||
|
||||
# Load pretrained ViT model |
||||
config = ViTConfig.from_pretrained(args.model_name_or_path) |
||||
config.num_labels = train_dataset.num_labels |
||||
config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)} |
||||
config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)} |
||||
model = ViTForImageClassification.from_pretrained(args.model_name_or_path, |
||||
config=config, |
||||
ignore_mismatched_sizes=True) |
||||
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) |
||||
|
||||
# Enable gradient checkpointing |
||||
model.gradient_checkpointing_enable() |
||||
|
||||
# Set plugin |
||||
booster_kwargs = {} |
||||
if args.plugin == 'torch_ddp_fp16': |
||||
booster_kwargs['mixed_precision'] = 'fp16' |
||||
if args.plugin.startswith('torch_ddp'): |
||||
plugin = TorchDDPPlugin() |
||||
elif args.plugin == 'gemini': |
||||
plugin = GeminiPlugin(device=get_current_device(), |
||||
placement_policy='cpu', |
||||
pin_memory=True, |
||||
strict_ddp_mode=True, |
||||
initial_scale=2**5) |
||||
elif args.plugin == 'low_level_zero': |
||||
plugin = LowLevelZeroPlugin(initial_scale=2**5) |
||||
logger.info(f"Set plugin as {args.plugin}", ranks=[0]) |
||||
|
||||
# Prepare dataloader |
||||
train_dataloader = plugin.prepare_dataloader(train_dataset, |
||||
batch_size=args.batch_size, |
||||
shuffle=True, |
||||
drop_last=True, |
||||
collate_fn=beans_collator) |
||||
eval_dataloader = plugin.prepare_dataloader(eval_dataset, |
||||
batch_size=args.batch_size, |
||||
shuffle=True, |
||||
drop_last=True, |
||||
collate_fn=beans_collator) |
||||
|
||||
# Set optimizer |
||||
optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay) |
||||
|
||||
# Set lr scheduler |
||||
total_steps = len(train_dataloader) * args.num_epoch |
||||
num_warmup_steps = int(args.warmup_ratio * total_steps) |
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, |
||||
total_steps=(len(train_dataloader) * args.num_epoch), |
||||
warmup_steps=num_warmup_steps) |
||||
|
||||
# Set booster |
||||
booster = Booster(plugin=plugin, **booster_kwargs) |
||||
model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model, |
||||
optimizer=optimizer, |
||||
dataloader=train_dataloader, |
||||
lr_scheduler=lr_scheduler) |
||||
|
||||
# Finetuning |
||||
logger.info(f"Start finetuning", ranks=[0]) |
||||
for epoch in range(args.num_epoch): |
||||
train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator) |
||||
evaluate_model(epoch, model, eval_dataloader, eval_dataset.num_labels, coordinator) |
||||
logger.info(f"Finish finetuning", ranks=[0]) |
||||
|
||||
# Save the finetuned model |
||||
booster.save_model(model, args.output_path) |
||||
logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0]) |
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
main() |
Loading…
Reference in new issue