[example] update vit example for hybrid parallel plugin (#4641)

* update vit example for hybrid plugin

* reset tp/pp size

* fix dataloader iteration bug

* update optimizer passing in evaluation/add grad_accum

* change criterion

* wrap tqdm

* change grad_accum to grad_checkpoint

* fix pbar
Baizhou Zhang 2023-09-07 17:38:45 +08:00 committed by GitHub
parent 660eed9124
commit 295b38fecf
10 changed files with 246 additions and 192 deletions


@@ -884,6 +884,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
if self.gradient_checkpointing and self.training:
if use_cache:
logger = logging.get_logger(__name__)
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False


@@ -1,9 +1,9 @@
import logging
import math
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -72,18 +72,17 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index:
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions is not None:
logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.')
output_attentions = None
if output_hidden_states is not None:
logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.')
output_hidden_states = None
logger = logging.get_logger(__name__)
# Preprocess passed in arguments
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head


@@ -3,7 +3,7 @@
Vision Transformer (ViT) is a class of Transformer models tailored for computer vision tasks. It was first proposed in the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) and achieved SOTA results on various tasks at the time.
In our example, we use pretrained ViT weights loaded from HuggingFace.
We adapt the ViT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin.
We adapt the ViT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin (DDP), LowLevelZeroPlugin (Zero1/Zero2), GeminiPlugin (Gemini) and HybridParallelPlugin (any combination of tensor/pipeline/data parallel).
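For reference, the snippet below is a minimal sketch of how a chosen plugin is handed to the Booster, following the pattern used in `vit_train_demo.py`; the randomly initialized model, dummy dataloader and hyperparameters are placeholders for illustration only.

```python
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import HybridAdam
from torch.utils.data import DataLoader, TensorDataset
from transformers import ViTConfig, ViTForImageClassification

# launch the distributed environment, e.g. via `torchrun --nproc_per_node 2 sketch.py`
colossalai.launch_from_torch(config={})

model = ViTForImageClassification(ViTConfig(num_labels=10))
optimizer = HybridAdam(model.parameters(), lr=3e-4)

# dummy data, just to keep the sketch self-contained
dataset = TensorDataset(torch.randn(8, 3, 224, 224), torch.randint(0, 10, (8,)))
dataloader = DataLoader(dataset, batch_size=4)

def criterion(outputs, inputs):
    return outputs.loss

# tp_size * pp_size must divide the number of processes launched by torchrun
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision='fp16', initial_scale=1)
booster = Booster(plugin=plugin)
model, optimizer, criterion, dataloader, _ = booster.boost(
    model=model, optimizer=optimizer, criterion=criterion, dataloader=dataloader)
```

Swapping in `GeminiPlugin()`, `LowLevelZeroPlugin()` or `TorchDDPPlugin()` at the `plugin = ...` line is all it takes to change the training strategy; the training loop only changes when pipeline parallelism is enabled, in which case `booster.execute_pipeline` drives the forward/backward pass (see `vit_train_demo.py`).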
## Run Demo
@@ -25,4 +25,4 @@ You can run benchmark for ViT model by running the following script:
```bash
bash run_benchmark.sh
```
The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your own set of hyperparameters for testing.
The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your own set of hyperparameters for testing.
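For orientation, throughput and peak memory are computed roughly as in the sketch below; the helper name and arguments are illustrative, and the actual bookkeeping lives in `vit_benchmark.py`.

```python
import time

import torch


def measure(run_step, num_steps: int, global_batch_size: int):
    """Time `num_steps` training iterations of the callable `run_step` and
    report samples/sec plus peak GPU memory (illustrative helper only)."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_steps):
        run_step()  # one forward/backward/optimizer step
    torch.cuda.synchronize()
    elapsed = time.time() - start
    throughput = num_steps * global_batch_size / elapsed       # samples per second
    peak_mem_gb = torch.cuda.max_memory_allocated() / 1024**3  # peak memory on this GPU
    return throughput, peak_mem_gb
```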


@@ -1,124 +1,82 @@
from colossalai import get_default_parser
def parse_demo_args():
parser = get_default_parser()
parser.add_argument(
"--model_name_or_path",
type=str,
default="google/vit-base-patch16-224",
help="Path to pretrained model or model identifier from huggingface.co/models."
)
parser.add_argument(
"--output_path",
type=str,
default="./output_model.bin",
help="The path of your saved model after finetuning."
)
parser.add_argument("--model_name_or_path",
type=str,
default="google/vit-base-patch16-224",
help="Path to pretrained model or model identifier from huggingface.co/models.")
parser.add_argument("--output_path",
type=str,
default="./output_model",
help="The path of your saved model after finetuning.")
parser.add_argument(
"--plugin",
type=str,
default="gemini",
help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
)
parser.add_argument(
"--num_epoch",
type=int,
default=3,
help="Number of epochs."
)
parser.add_argument(
"--batch_size",
type=int,
default=32,
help="Batch size (per dp group) for the training dataloader."
)
parser.add_argument(
"--learning_rate",
type=float,
default=3e-4,
help="Initial learning rate (after the potential warmup period) to use."
)
parser.add_argument(
"--warmup_ratio",
type=float,
default=0.3,
help="Ratio of warmup steps against total training steps."
)
parser.add_argument(
"--weight_decay",
type=float,
default=0.1,
help="Weight decay to use."
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="A seed for reproducible training."
help=
"Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'."
)
parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs.")
parser.add_argument("--batch_size",
type=int,
default=32,
help="Batch size (per dp group) for the training dataloader.")
parser.add_argument("--tp_size",
type=int,
default=1,
help="The size along tensor parallel dimension, only be used when enabling hybrid parallel.")
parser.add_argument("--pp_size",
type=int,
default=1,
help="The size along pipeline parallel dimension, only be used when enabling hybrid parallel.")
parser.add_argument("--learning_rate",
type=float,
default=3e-4,
help="Initial learning rate (after the potential warmup period) to use.")
parser.add_argument("--warmup_ratio",
type=float,
default=0.3,
help="Ratio of warmup steps against total training steps.")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.")
parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.")
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
args = parser.parse_args()
return args
def parse_benchmark_args():
parser = get_default_parser()
parser.add_argument(
"--model_name_or_path",
type=str,
default="google/vit-base-patch16-224",
help="Path to a pretrained model or model identifier from huggingface.co/models."
)
parser.add_argument("--model_name_or_path",
type=str,
default="google/vit-base-patch16-224",
help="Path to a pretrained model or model identifier from huggingface.co/models.")
parser.add_argument(
"--plugin",
type=str,
default="gemini",
help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
)
parser.add_argument(
"--batch_size",
type=int,
default=8,
help="Batch size (per dp group) for the training dataloader."
)
parser.add_argument(
"--num_labels",
type=int,
default=10,
help="Number of labels for classification."
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-5,
help="Initial learning rate (after the potential warmup period) to use."
)
parser.add_argument(
"--weight_decay",
type=float,
default=0.0,
help="Weight decay to use."
)
parser.add_argument(
"--max_train_steps",
type=int,
default=20,
help="Total number of training steps to perform."
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="A seed for reproducible training."
)
parser.add_argument(
"--mem_cap",
type=int,
default=0,
help="Limit on the usage of space for each GPU (in GB)."
help=
"Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'."
)
parser.add_argument("--batch_size",
type=int,
default=8,
help="Batch size (per dp group) for the training dataloader.")
parser.add_argument("--num_labels", type=int, default=10, help="Number of labels for classification.")
parser.add_argument("--learning_rate",
type=float,
default=5e-5,
help="Initial learning rate (after the potential warmup period) to use.")
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.")
parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.")
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
parser.add_argument("--mem_cap", type=int, default=0, help="Limit on the usage of space for each GPU (in GB).")
args = parser.parse_args()
return args
return args


@@ -1,32 +1,38 @@
import torch
from torch.utils.data import Dataset
from datasets import load_dataset
from torch.utils.data import Dataset
class BeansDataset(Dataset):
def __init__(self, image_processor, split='train'):
def __init__(self, image_processor, tp_size=1, split='train'):
super().__init__()
self.image_processor = image_processor
self.ds = load_dataset('beans')[split]
self.label_names = self.ds.features['labels'].names
while len(self.label_names) % tp_size != 0:
# ensure that the number of labels is a multiple of tp_size
self.label_names.append(f"pad_label_{len(self.label_names)}")
self.num_labels = len(self.label_names)
self.inputs = []
for example in self.ds:
self.inputs.append(self.process_example(example))
def __len__(self):
return len(self.inputs)
def __getitem__(self, idx):
return self.inputs[idx]
def process_example(self, example):
input = self.image_processor(example['image'], return_tensors='pt')
input['labels'] = example['labels']
return input
def beans_collator(batch):
return {'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0),
'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64)}
return {
'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0),
'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64)
}
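A minimal usage sketch of `BeansDataset` and `beans_collator`, assuming a hypothetical script living next to `data.py`; the label padding above keeps the number of classes divisible by the tensor parallel size so the classification head can be sharded evenly, and the batch size and `tp_size` here are placeholders.

```python
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor

from data import BeansDataset, beans_collator

image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
# with tp_size=2, label_names is padded until its length is a multiple of 2
train_dataset = BeansDataset(image_processor, tp_size=2, split='train')
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=beans_collator)

batch = next(iter(train_dataloader))
print(batch['pixel_values'].shape, batch['labels'].shape)  # (32, 3, 224, 224) and (32,)
```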


@@ -5,23 +5,20 @@ export BS=8
export MEMCAP=0
export GPUNUM=1
for BS in 8 32 128
for BS in 8 32
do
for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
do
for GPUNUM in 1 4
for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" "hybrid_parallel"
do
MODEL_PATH="google/vit-base-patch16-224"
torchrun \
--standalone \
--nproc_per_node ${GPUNUM} \
--nproc_per_node 4 \
vit_benchmark.py \
--model_name_or_path ${MODEL_PATH} \
--mem_cap ${MEMCAP} \
--plugin ${PLUGIN} \
--batch_size ${BS}
done
done
done


@@ -5,16 +5,21 @@ pip install -r requirements.txt
MODEL="google/vit-base-patch16-224"
# path for saving model
OUTPUT_PATH="./output_model.bin"
OUTPUT_PATH="./output_model"
# plugin(training strategy)
# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"
# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"/"hybrid_parallel"
PLUGIN="gemini"
#PLUGIN="hybrid_parallel"
# configuration of parallel group sizes, only used when setting PLUGIN to "hybrid_parallel"
TP_SIZE=2
PP_SIZE=2
# number of gpus to use
GPUNUM=4
# batch size per gpu
# batch size per data parallel group
BS=16
# learning rate
@@ -38,6 +43,8 @@ torchrun \
--output_path ${OUTPUT_PATH} \
--plugin ${PLUGIN} \
--batch_size ${BS} \
--tp_size ${TP_SIZE} \
--pp_size ${PP_SIZE} \
--num_epoch ${EPOCH} \
--learning_rate ${LR} \
--weight_decay ${WEIGHT_DECAY} \


@@ -2,18 +2,15 @@ set -xe
pip install -r requirements.txt
BS=8
for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
do
for GPUNUM in 1 4
for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" "hybrid_parallel"
do
torchrun \
--standalone \
--nproc_per_node ${GPUNUM} \
--nproc_per_node 4 \
vit_benchmark.py \
--model_name_or_path "google/vit-base-patch16-224" \
--plugin ${PLUGIN} \
--batch_size ${BS}
done
done


@@ -1,14 +1,14 @@
import time
import torch
import tqdm
import transformers
from args import parse_benchmark_args
from tqdm import tqdm
from transformers import ViTConfig, ViTForImageClassification
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
@@ -24,7 +24,7 @@ def format_num(num: int, bytes=False):
num /= factor
def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224):
pixel_values = torch.randn(batch_size,
num_channels,
height,
@@ -32,7 +32,7 @@ def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
device=torch.cuda.current_device(),
dtype=torch.float)
labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64)
return pixel_values, labels
return dict(pixel_values=pixel_values, labels=labels)
def colo_memory_cap(size_in_GB):
@@ -70,7 +70,8 @@ def main():
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
# Set plugin
booster_kwargs = {}
@@ -82,34 +83,57 @@ def main():
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
elif args.plugin == 'hybrid_parallel':
plugin = HybridParallelPlugin(tp_size=2,
pp_size=2,
num_microbatches=None,
microbatch_size=1,
enable_all_optimization=True,
precision='fp16',
initial_scale=1)
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
# Set optimizer
optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))
# Set criterion (loss function)
def criterion(outputs, inputs):
return outputs.loss
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, _, _ = booster.boost(model, optimizer)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)
# Start training.
logger.info(f"Start testing", ranks=[0])
progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())
torch.cuda.synchronize()
model.train()
start_time = time.time()
for _ in range(args.max_train_steps):
with tqdm(range(args.max_train_steps), desc="Training Step", disable=not coordinator.is_master()) as pbar:
for _ in pbar:
optimizer.zero_grad()
batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224)
pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224)
optimizer.zero_grad()
outputs = model(pixel_values=pixel_values, labels=labels)
loss = outputs['loss']
booster.backward(loss, optimizer)
optimizer.step()
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
# run pipeline forward backward
batch = iter([batch])
outputs = booster.execute_pipeline(batch,
model,
criterion,
optimizer,
return_loss=True,
return_outputs=True)
else:
outputs = model(**batch)
loss = criterion(outputs, None)
# Backward
booster.backward(loss, optimizer)
torch.cuda.synchronize()
progress_bar.update(1)
optimizer.step()
torch.cuda.synchronize()
# Compute Statistics
end_time = time.time()
@@ -124,6 +148,8 @@ def main():
f"maximum memory usage per gpu: {max_mem}.",
ranks=[0])
torch.cuda.empty_cache()
if __name__ == "__main__":
main()


@@ -1,70 +1,111 @@
from typing import Any, Callable, Iterator
import torch
import torch.distributed as dist
import torch.nn as nn
import transformers
from args import parse_demo_args
from data import BeansDataset, beans_collator
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
def move_to_cuda(batch, device):
return {k: v.to(device) for k, v in batch.items()}
def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
def run_forward_backward(model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor],
data_iter: Iterator, booster: Booster):
if optimizer is not None:
optimizer.zero_grad()
if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
# run pipeline forward backward when enabling pp in hybrid parallel plugin
output_dict = booster.execute_pipeline(data_iter,
model,
criterion,
optimizer,
return_loss=True,
return_outputs=True)
loss, outputs = output_dict['loss'], output_dict['outputs']
else:
batch = next(data_iter)
batch = move_to_cuda(batch, torch.cuda.current_device())
outputs = model(**batch)
loss = criterion(outputs, None)
if optimizer is not None:
booster.backward(loss, optimizer)
return loss, outputs
def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor],
lr_scheduler: LRScheduler, dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
torch.cuda.synchronize()
num_steps = len(dataloader)
data_iter = iter(dataloader)
enable_pbar = coordinator.is_master()
if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
# when using pp, only the last stage of the master pipeline (dp_rank and tp_rank are both zero) shows the pbar
tp_rank = dist.get_rank(booster.plugin.tp_group)
dp_rank = dist.get_rank(booster.plugin.dp_group)
enable_pbar = tp_rank == 0 and dp_rank == 0 \
and booster.plugin.stage_manager.is_last_stage()
model.train()
with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
for batch in pbar:
# Forward
optimizer.zero_grad()
batch = move_to_cuda(batch, torch.cuda.current_device())
outputs = model(**batch)
loss = outputs['loss']
# Backward
booster.backward(loss, optimizer)
with tqdm(range(num_steps), desc=f'Epoch [{epoch + 1}]', disable=not enable_pbar) as pbar:
for _ in pbar:
loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster)
optimizer.step()
lr_scheduler.step()
# Print batch loss
pbar.set_postfix({'loss': loss.item()})
if enable_pbar:
pbar.set_postfix({'loss': loss.item()})
@torch.no_grad()
def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator):
def evaluate_model(epoch: int, model: nn.Module, criterion: Callable[[Any, Any], torch.Tensor],
eval_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
torch.cuda.synchronize()
model.eval()
accum_loss = torch.zeros(1, device=get_current_device())
total_num = torch.zeros(1, device=get_current_device())
accum_correct = torch.zeros(1, device=get_current_device())
accum_loss = torch.zeros(1, device=torch.cuda.current_device())
total_num = torch.zeros(1, device=torch.cuda.current_device())
accum_correct = torch.zeros(1, device=torch.cuda.current_device())
for batch in eval_dataloader:
batch = move_to_cuda(batch, torch.cuda.current_device())
outputs = model(**batch)
val_loss, logits = outputs[:2]
accum_loss += (val_loss / len(eval_dataloader))
if num_labels > 1:
preds = torch.argmax(logits, dim=1)
elif num_labels == 1:
preds = logits.squeeze()
loss, outputs = run_forward_backward(model, None, criterion, iter([batch]), booster)
labels = batch["labels"]
total_num += batch["labels"].shape[0]
accum_correct += (torch.sum(preds == labels))
to_accum = True
if isinstance(booster.plugin, HybridParallelPlugin):
# when using hybrid parallel, loss is only collected from the last pipeline stage with tp_rank == 0
to_accum = to_accum and (dist.get_rank(booster.plugin.tp_group) == 0)
if booster.plugin.pp_size > 1:
to_accum = to_accum and booster.plugin.stage_manager.is_last_stage()
if to_accum:
accum_loss += (loss / len(eval_dataloader))
logits = outputs["logits"]
preds = torch.argmax(logits, dim=1)
labels = batch["labels"]
total_num += batch["labels"].shape[0]
accum_correct += (torch.sum(preds == labels))
dist.all_reduce(accum_loss)
dist.all_reduce(total_num)
@@ -94,14 +135,20 @@ def main():
else:
transformers.utils.logging.set_verbosity_error()
# Reset tp_size and pp_size to 1 if not using hybrid parallel.
if args.plugin != 'hybrid_parallel':
args.tp_size = 1
args.pp_size = 1
# Prepare Dataset
image_processor = ViTImageProcessor.from_pretrained(args.model_name_or_path)
train_dataset = BeansDataset(image_processor, split='train')
eval_dataset = BeansDataset(image_processor, split='validation')
train_dataset = BeansDataset(image_processor, args.tp_size, split='train')
eval_dataset = BeansDataset(image_processor, args.tp_size, split='validation')
num_labels = train_dataset.num_labels
# Load pretrained ViT model
config = ViTConfig.from_pretrained(args.model_name_or_path)
config.num_labels = train_dataset.num_labels
config.num_labels = num_labels
config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)}
config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)}
model = ViTForImageClassification.from_pretrained(args.model_name_or_path,
@@ -110,7 +157,8 @@ def main():
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
# Set plugin
booster_kwargs = {}
@@ -122,6 +170,16 @@ def main():
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
elif args.plugin == 'hybrid_parallel':
plugin = HybridParallelPlugin(tp_size=args.tp_size,
pp_size=args.pp_size,
num_microbatches=None,
microbatch_size=1,
enable_all_optimization=True,
precision='fp16',
initial_scale=1)
else:
raise ValueError(f"Plugin with name {args.plugin} is not supported!")
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
# Prepare dataloader
@@ -139,6 +197,10 @@ def main():
# Set optimizer
optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)
# Set criterion (loss function)
def criterion(outputs, inputs):
return outputs.loss
# Set lr scheduler
total_steps = len(train_dataloader) * args.num_epoch
num_warmup_steps = int(args.warmup_ratio * total_steps)
@@ -148,20 +210,21 @@ def main():
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model,
optimizer=optimizer,
dataloader=train_dataloader,
lr_scheduler=lr_scheduler)
model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(model=model,
optimizer=optimizer,
criterion=criterion,
dataloader=train_dataloader,
lr_scheduler=lr_scheduler)
# Finetuning
logger.info(f"Start finetuning", ranks=[0])
for epoch in range(args.num_epoch):
train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
evaluate_model(epoch, model, eval_dataloader, eval_dataset.num_labels, coordinator)
train_epoch(epoch, model, optimizer, criterion, lr_scheduler, train_dataloader, booster, coordinator)
evaluate_model(epoch, model, criterion, eval_dataloader, booster, coordinator)
logger.info(f"Finish finetuning", ranks=[0])
# Save the finetuned model
booster.save_model(model, args.output_path)
booster.save_model(model, args.output_path, shard=True)
logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])