mirror of https://github.com/hpcaitech/ColossalAI
[example] update vit example for hybrid parallel plugin (#4641)
* update vit example for hybrid plugin
* reset tp/pp size
* fix dataloader iteration bug
* update optimizer passing in evaluation/add grad_accum
* change criterion
* wrap tqdm
* change grad_accum to grad_checkpoint
* fix pbar

pull/4659/head
parent 660eed9124
commit 295b38fecf
@@ -884,6 +884,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger = logging.get_logger(__name__)
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False
@@ -1,9 +1,9 @@
import logging
import math
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -72,18 +72,17 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index:
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if output_attentions is not None:
            logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.')
            output_attentions = None
        if output_hidden_states is not None:
            logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.')
            output_hidden_states = None
        logger = logging.get_logger(__name__)

        # Preprocess passed in arguments
        if output_attentions:
            logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
            output_attentions = False
        if output_hidden_states:
            logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
            output_hidden_states = False

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
@@ -3,7 +3,7 @@
Vision Transformer is a class of Transformer model tailored for computer vision tasks. It was first proposed in paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) and achieved SOTA results on various tasks at that time.

In our example, we are using pretrained weights of ViT loaded from HuggingFace.
We adapt the ViT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin.
We adapt the ViT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin (DDP), LowLevelZeroPlugin (Zero1/Zero2), GeminiPlugin (Gemini) and HybridParallelPlugin (any combination of tensor/pipeline/data parallel).

## Run Demo
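For orientation, here is a minimal sketch of the Booster-based setup the README describes, using the HybridParallelPlugin added in this commit. The launch call, model choice, and the specific plugin/optimizer values below are illustrative placeholders rather than the example's exact configuration:

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import HybridAdam
from transformers import ViTForImageClassification

# Launch one process per GPU (e.g. via torchrun) before this point.
colossalai.launch_from_torch(config={})

model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
optimizer = HybridAdam(model.parameters(), lr=3e-4)


def criterion(outputs, inputs):
    # HuggingFace models return the loss directly when labels are passed in.
    return outputs.loss


# tp_size * pp_size must divide the number of launched processes;
# the remaining factor becomes the data-parallel size.
plugin = HybridParallelPlugin(tp_size=2,
                              pp_size=2,
                              num_microbatches=None,
                              microbatch_size=1,
                              precision='fp16',
                              initial_scale=1)
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)
```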
@@ -25,4 +25,4 @@ You can run benchmark for ViT model by running the following script:
```bash
bash run_benchmark.sh
```
The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your own set of hyperparameters for testing.
The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your own set of hyperparameters for testing.
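As a rough sketch of what "testing performance" means here: throughput is samples processed per second over the timed steps, and peak memory comes from CUDA's allocator statistics. The helper below is illustrative only (the benchmark's own timing lives in `vit_benchmark.py`); `run_step` and `global_batch_size` are placeholder names:

```python
import time

import torch


def measure(run_step, num_steps: int, global_batch_size: int):
    """Illustrative throughput / peak-memory measurement around a training-step callable."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_steps):
        run_step()    # one forward + backward + optimizer step
    torch.cuda.synchronize()
    elapsed = time.time() - start

    throughput = num_steps * global_batch_size / elapsed        # samples per second
    peak_mem_gb = torch.cuda.max_memory_allocated() / 1024**3   # per-GPU peak, in GB
    return throughput, peak_mem_gb
```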
@@ -1,124 +1,82 @@
from colossalai import get_default_parser


def parse_demo_args():

    parser = get_default_parser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="google/vit-base-patch16-224",
        help="Path to pretrained model or model identifier from huggingface.co/models."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./output_model.bin",
        help="The path of your saved model after finetuning."
    )
    parser.add_argument("--model_name_or_path",
                        type=str,
                        default="google/vit-base-patch16-224",
                        help="Path to pretrained model or model identifier from huggingface.co/models.")
    parser.add_argument("--output_path",
                        type=str,
                        default="./output_model",
                        help="The path of your saved model after finetuning.")
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
    )
    parser.add_argument(
        "--num_epoch",
        type=int,
        default=3,
        help="Number of epochs."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size (per dp group) for the training dataloader."
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=3e-4,
        help="Initial learning rate (after the potential warmup period) to use."
    )
    parser.add_argument(
        "--warmup_ratio",
        type=float,
        default=0.3,
        help="Ratio of warmup steps against total training steps."
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.1,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="A seed for reproducible training."
        help=
        "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'."
    )
    parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs.")
    parser.add_argument("--batch_size",
                        type=int,
                        default=32,
                        help="Batch size (per dp group) for the training dataloader.")
    parser.add_argument("--tp_size",
                        type=int,
                        default=1,
                        help="The size along tensor parallel dimension, only be used when enabling hybrid parallel.")
    parser.add_argument("--pp_size",
                        type=int,
                        default=1,
                        help="The size along pipeline parallel dimension, only be used when enabling hybrid parallel.")
    parser.add_argument("--learning_rate",
                        type=float,
                        default=3e-4,
                        help="Initial learning rate (after the potential warmup period) to use.")
    parser.add_argument("--warmup_ratio",
                        type=float,
                        default=0.3,
                        help="Ratio of warmup steps against total training steps.")
    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.")
    parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

    args = parser.parse_args()
    return args


def parse_benchmark_args():

    parser = get_default_parser()

    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="google/vit-base-patch16-224",
        help="Path to a pretrained model or model identifier from huggingface.co/models."
    )
    parser.add_argument("--model_name_or_path",
                        type=str,
                        default="google/vit-base-patch16-224",
                        help="Path to a pretrained model or model identifier from huggingface.co/models.")
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size (per dp group) for the training dataloader."
    )
    parser.add_argument(
        "--num_labels",
        type=int,
        default=10,
        help="Number of labels for classification."
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use."
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.0,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=20,
        help="Total number of training steps to perform."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--mem_cap",
        type=int,
        default=0,
        help="Limit on the usage of space for each GPU (in GB)."
        help=
        "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'."
    )
    parser.add_argument("--batch_size",
                        type=int,
                        default=8,
                        help="Batch size (per dp group) for the training dataloader.")
    parser.add_argument("--num_labels", type=int, default=10, help="Number of labels for classification.")
    parser.add_argument("--learning_rate",
                        type=float,
                        default=5e-5,
                        help="Initial learning rate (after the potential warmup period) to use.")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.")
    parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument("--mem_cap", type=int, default=0, help="Limit on the usage of space for each GPU (in GB).")
    args = parser.parse_args()

    return args
    return args
@@ -1,32 +1,38 @@
import torch
from torch.utils.data import Dataset
from datasets import load_dataset
from torch.utils.data import Dataset


class BeansDataset(Dataset):

    def __init__(self, image_processor, split='train'):
    def __init__(self, image_processor, tp_size=1, split='train'):

        super().__init__()
        self.image_processor = image_processor
        self.ds = load_dataset('beans')[split]
        self.label_names = self.ds.features['labels'].names
        while len(self.label_names) % tp_size != 0:
            # ensure that the number of labels is multiple of tp_size
            self.label_names.append(f"pad_label_{len(self.label_names)}")
        self.num_labels = len(self.label_names)
        self.inputs = []
        for example in self.ds:
            self.inputs.append(self.process_example(example))

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx]

    def process_example(self, example):
        input = self.image_processor(example['image'], return_tensors='pt')
        input['labels'] = example['labels']
        return input


def beans_collator(batch):
    return {'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0),
            'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64)}
    return {
        'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0),
        'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64)
    }
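The `tp_size` padding above keeps `num_labels` divisible by the tensor-parallel degree, so the classification head can be sharded evenly across tensor-parallel ranks. A small usage sketch of the dataset and collator defined in this file; the processor name and batch size are just example values:

```python
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor

from data import BeansDataset, beans_collator    # the module shown above

image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# With tp_size=2, label_names is padded until its length is divisible by 2.
train_dataset = BeansDataset(image_processor, tp_size=2, split='train')

train_dataloader = DataLoader(train_dataset,
                              batch_size=32,
                              shuffle=True,
                              collate_fn=beans_collator)

batch = next(iter(train_dataloader))
print(batch['pixel_values'].shape, batch['labels'].shape, train_dataset.num_labels)
```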
@@ -5,23 +5,20 @@ export BS=8
export MEMCAP=0
export GPUNUM=1

for BS in 8 32 128
for BS in 8 32
do
  for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
  do
    for GPUNUM in 1 4
  for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" "hybrid_parallel"
  do

    MODEL_PATH="google/vit-base-patch16-224"
    torchrun \
      --standalone \
      --nproc_per_node ${GPUNUM} \
      --nproc_per_node 4 \
      vit_benchmark.py \
      --model_name_or_path ${MODEL_PATH} \
      --mem_cap ${MEMCAP} \
      --plugin ${PLUGIN} \
      --batch_size ${BS}

  done

done
done
@@ -5,16 +5,21 @@ pip install -r requirements.txt
MODEL="google/vit-base-patch16-224"

# path for saving model
OUTPUT_PATH="./output_model.bin"
OUTPUT_PATH="./output_model"

# plugin(training strategy)
# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"
# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"/"hybrid_parallel"
PLUGIN="gemini"
#PLUGIN="hybrid_parallel"

# configuration of parallel group sizes, only used when setting PLUGIN to "hybrid_parallel"
TP_SIZE=2
PP_SIZE=2

# number of gpus to use
GPUNUM=4

# batch size per gpu
# batch size per data parallel group
BS=16

# learning rate
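The comments above imply how the GPUs are split when `PLUGIN="hybrid_parallel"`: tensor- and pipeline-parallel groups are carved out first, and whatever factor remains forms the data-parallel groups that each consume a batch of size `BS`. A small illustrative calculation; the decomposition rule here is the usual convention, stated as an assumption rather than quoted from the plugin:

```python
# Illustrative only: assumed decomposition of the launched processes.
GPUNUM, TP_SIZE, PP_SIZE, BS = 4, 2, 2, 16

assert GPUNUM % (TP_SIZE * PP_SIZE) == 0, "tp_size * pp_size must divide the GPU count"
DP_SIZE = GPUNUM // (TP_SIZE * PP_SIZE)    # here: 4 // (2 * 2) = 1 data-parallel group

# BS is the batch size per data-parallel group, so the global batch per step is:
global_batch_size = BS * DP_SIZE
print(DP_SIZE, global_batch_size)
```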
@@ -38,6 +43,8 @@ torchrun \
  --output_path ${OUTPUT_PATH} \
  --plugin ${PLUGIN} \
  --batch_size ${BS} \
  --tp_size ${TP_SIZE} \
  --pp_size ${PP_SIZE} \
  --num_epoch ${EPOCH} \
  --learning_rate ${LR} \
  --weight_decay ${WEIGHT_DECAY} \
@@ -2,18 +2,15 @@ set -xe
pip install -r requirements.txt

BS=8
for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
do
  for GPUNUM in 1 4
for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" "hybrid_parallel"
do

  torchrun \
    --standalone \
    --nproc_per_node ${GPUNUM} \
    --nproc_per_node 4 \
    vit_benchmark.py \
    --model_name_or_path "google/vit-base-patch16-224" \
    --plugin ${PLUGIN} \
    --batch_size ${BS}

done
done
@@ -1,14 +1,14 @@
import time

import torch
import tqdm
import transformers
from args import parse_benchmark_args
from tqdm import tqdm
from transformers import ViTConfig, ViTForImageClassification

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
@@ -24,7 +24,7 @@ def format_num(num: int, bytes=False):
        num /= factor


def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224):
    pixel_values = torch.randn(batch_size,
                               num_channels,
                               height,
@@ -32,7 +32,7 @@ def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
                               device=torch.cuda.current_device(),
                               dtype=torch.float)
    labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64)
    return pixel_values, labels
    return dict(pixel_values=pixel_values, labels=labels)


def colo_memory_cap(size_in_GB):
@@ -70,7 +70,8 @@ def main():
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()
    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()

    # Set plugin
    booster_kwargs = {}
@@ -82,34 +83,57 @@ def main():
        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    elif args.plugin == 'hybrid_parallel':
        plugin = HybridParallelPlugin(tp_size=2,
                                      pp_size=2,
                                      num_microbatches=None,
                                      microbatch_size=1,
                                      enable_all_optimization=True,
                                      precision='fp16',
                                      initial_scale=1)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Set optimizer
    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))

    # Set criterion (loss function)
    def criterion(outputs, inputs):
        return outputs.loss

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)

    # Start training.
    logger.info(f"Start testing", ranks=[0])
    progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())

    torch.cuda.synchronize()
    model.train()
    start_time = time.time()

    for _ in range(args.max_train_steps):
    with tqdm(range(args.max_train_steps), desc="Training Step", disable=not coordinator.is_master()) as pbar:
        for _ in pbar:
            optimizer.zero_grad()
            batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224)

        pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224)
        optimizer.zero_grad()
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs['loss']
        booster.backward(loss, optimizer)
        optimizer.step()
            if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
                # run pipeline forward backward
                batch = iter([batch])
                outputs = booster.execute_pipeline(batch,
                                                   model,
                                                   criterion,
                                                   optimizer,
                                                   return_loss=True,
                                                   return_outputs=True)
            else:
                outputs = model(**batch)
                loss = criterion(outputs, None)
                # Backward
                booster.backward(loss, optimizer)

        torch.cuda.synchronize()
        progress_bar.update(1)
            optimizer.step()

    torch.cuda.synchronize()

    # Compute Statistics
    end_time = time.time()
@@ -124,6 +148,8 @@ def main():
                f"maximum memory usage per gpu: {max_mem}.",
                ranks=[0])

    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
@@ -1,70 +1,111 @@
from typing import Any, Callable, Iterator

import torch
import torch.distributed as dist
import torch.nn as nn
import transformers
from args import parse_demo_args
from data import BeansDataset, beans_collator
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device


def move_to_cuda(batch, device):
    return {k: v.to(device) for k, v in batch.items()}


def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
def run_forward_backward(model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor],
                         data_iter: Iterator, booster: Booster):
    if optimizer is not None:
        optimizer.zero_grad()
    if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
        # run pipeline forward backward when enabling pp in hybrid parallel plugin
        output_dict = booster.execute_pipeline(data_iter,
                                               model,
                                               criterion,
                                               optimizer,
                                               return_loss=True,
                                               return_outputs=True)
        loss, outputs = output_dict['loss'], output_dict['outputs']
    else:
        batch = next(data_iter)
        batch = move_to_cuda(batch, torch.cuda.current_device())
        outputs = model(**batch)
        loss = criterion(outputs, None)
        if optimizer is not None:
            booster.backward(loss, optimizer)

    return loss, outputs


def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor],
                lr_scheduler: LRScheduler, dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):

    torch.cuda.synchronize()

    num_steps = len(dataloader)
    data_iter = iter(dataloader)
    enable_pbar = coordinator.is_master()
    if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
        # when using pp, only the last stage of master pipeline (dp_rank and tp_rank are both zero) shows pbar
        tp_rank = dist.get_rank(booster.plugin.tp_group)
        dp_rank = dist.get_rank(booster.plugin.dp_group)
        enable_pbar = tp_rank == 0 and dp_rank == 0 \
            and booster.plugin.stage_manager.is_last_stage()

    model.train()

    with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:

        for batch in pbar:

            # Foward
            optimizer.zero_grad()
            batch = move_to_cuda(batch, torch.cuda.current_device())
            outputs = model(**batch)
            loss = outputs['loss']

            # Backward
            booster.backward(loss, optimizer)
    with tqdm(range(num_steps), desc=f'Epoch [{epoch + 1}]', disable=not enable_pbar) as pbar:
        for _ in pbar:
            loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster)
            optimizer.step()
            lr_scheduler.step()

            # Print batch loss
            pbar.set_postfix({'loss': loss.item()})
            if enable_pbar:
                pbar.set_postfix({'loss': loss.item()})


@torch.no_grad()
def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator):
def evaluate_model(epoch: int, model: nn.Module, criterion: Callable[[Any, Any], torch.Tensor],
                   eval_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):

    torch.cuda.synchronize()
    model.eval()
    accum_loss = torch.zeros(1, device=get_current_device())
    total_num = torch.zeros(1, device=get_current_device())
    accum_correct = torch.zeros(1, device=get_current_device())
    accum_loss = torch.zeros(1, device=torch.cuda.current_device())
    total_num = torch.zeros(1, device=torch.cuda.current_device())
    accum_correct = torch.zeros(1, device=torch.cuda.current_device())

    for batch in eval_dataloader:
        batch = move_to_cuda(batch, torch.cuda.current_device())
        outputs = model(**batch)
        val_loss, logits = outputs[:2]
        accum_loss += (val_loss / len(eval_dataloader))
        if num_labels > 1:
            preds = torch.argmax(logits, dim=1)
        elif num_labels == 1:
            preds = logits.squeeze()
        loss, outputs = run_forward_backward(model, None, criterion, iter([batch]), booster)

        labels = batch["labels"]
        total_num += batch["labels"].shape[0]
        accum_correct += (torch.sum(preds == labels))
        to_accum = True
        if isinstance(booster.plugin, HybridParallelPlugin):
            # when using hybrid parallel, loss is only collected from last stage of pipeline with tp_rank == 0
            to_accum = to_accum and (dist.get_rank(booster.plugin.tp_group) == 0)
            if booster.plugin.pp_size > 1:
                to_accum = to_accum and booster.plugin.stage_manager.is_last_stage()

        if to_accum:
            accum_loss += (loss / len(eval_dataloader))
            logits = outputs["logits"]
            preds = torch.argmax(logits, dim=1)

            labels = batch["labels"]
            total_num += batch["labels"].shape[0]
            accum_correct += (torch.sum(preds == labels))

    dist.all_reduce(accum_loss)
    dist.all_reduce(total_num)
@@ -94,14 +135,20 @@ def main():
    else:
        transformers.utils.logging.set_verbosity_error()

    # Reset tp_size and pp_size to 1 if not using hybrid parallel.
    if args.plugin != 'hybrid_parallel':
        args.tp_size = 1
        args.pp_size = 1

    # Prepare Dataset
    image_processor = ViTImageProcessor.from_pretrained(args.model_name_or_path)
    train_dataset = BeansDataset(image_processor, split='train')
    eval_dataset = BeansDataset(image_processor, split='validation')
    train_dataset = BeansDataset(image_processor, args.tp_size, split='train')
    eval_dataset = BeansDataset(image_processor, args.tp_size, split='validation')
    num_labels = train_dataset.num_labels

    # Load pretrained ViT model
    config = ViTConfig.from_pretrained(args.model_name_or_path)
    config.num_labels = train_dataset.num_labels
    config.num_labels = num_labels
    config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)}
    config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)}
    model = ViTForImageClassification.from_pretrained(args.model_name_or_path,
@@ -110,7 +157,8 @@ def main():
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()
    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()

    # Set plugin
    booster_kwargs = {}
@@ -122,6 +170,16 @@ def main():
        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    elif args.plugin == 'hybrid_parallel':
        plugin = HybridParallelPlugin(tp_size=args.tp_size,
                                      pp_size=args.pp_size,
                                      num_microbatches=None,
                                      microbatch_size=1,
                                      enable_all_optimization=True,
                                      precision='fp16',
                                      initial_scale=1)
    else:
        raise ValueError(f"Plugin with name {args.plugin} is not supported!")
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Prepare dataloader
@@ -139,6 +197,10 @@ def main():
    # Set optimizer
    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)

    # Set criterion (loss function)
    def criterion(outputs, inputs):
        return outputs.loss

    # Set lr scheduler
    total_steps = len(train_dataloader) * args.num_epoch
    num_warmup_steps = int(args.warmup_ratio * total_steps)
@@ -148,20 +210,21 @@
    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model,
                                                                        optimizer=optimizer,
                                                                        dataloader=train_dataloader,
                                                                        lr_scheduler=lr_scheduler)
    model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(model=model,
                                                                                 optimizer=optimizer,
                                                                                 criterion=criterion,
                                                                                 dataloader=train_dataloader,
                                                                                 lr_scheduler=lr_scheduler)

    # Finetuning
    logger.info(f"Start finetuning", ranks=[0])
    for epoch in range(args.num_epoch):
        train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
        evaluate_model(epoch, model, eval_dataloader, eval_dataset.num_labels, coordinator)
        train_epoch(epoch, model, optimizer, criterion, lr_scheduler, train_dataloader, booster, coordinator)
        evaluate_model(epoch, model, criterion, eval_dataloader, booster, coordinator)
    logger.info(f"Finish finetuning", ranks=[0])

    # Save the finetuned model
    booster.save_model(model, args.output_path)
    booster.save_model(model, args.output_path, shard=True)
    logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])