[example] update vit example for hybrid parallel plugin (#4641)

* update vit example for hybrid plugin

* reset tp/pp size

* fix dataloader iteration bug

* update optimizer passing in evaluation/add grad_accum

* change criterion

* wrap tqdm

* change grad_accum to grad_checkpoint

* fix pbar
Baizhou Zhang, 2023-09-07 17:38:45 +08:00, committed by GitHub
commit 295b38fecf (parent 660eed9124)
10 changed files with 246 additions and 192 deletions

@@ -884,6 +884,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
         if self.gradient_checkpointing and self.training:
             if use_cache:
+                logger = logging.get_logger(__name__)
                 logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                 use_cache = False

@@ -1,9 +1,9 @@
-import logging
 import math
 from typing import Dict, List, Optional, Set, Tuple, Union

 import torch
 from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
+from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager

@@ -72,18 +72,17 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index:
             bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
                 Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (output_hidden_states
-                                if output_hidden_states is not None else self.config.output_hidden_states)
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        logger = logging.get_logger(__name__)

-        if output_attentions is not None:
-            logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.')
-            output_attentions = None
-        if output_hidden_states is not None:
-            logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.')
-            output_hidden_states = None
+        # Preprocess passed in arguments
+        if output_attentions:
+            logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+            output_attentions = False
+        if output_hidden_states:
+            logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+            output_hidden_states = False

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head

@@ -3,7 +3,7 @@
 Vision Transformer is a class of Transformer model tailored for computer vision tasks. It was first proposed in paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) and achieved SOTA results on various tasks at that time.

 In our example, we are using pretrained weights of ViT loaded from HuggingFace.
-We adapt the ViT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin.
+We adapt the ViT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin (DDP), LowLevelZeroPlugin (Zero1/Zero2), GeminiPlugin (Gemini) and HybridParallelPlugin (any combination of tensor/pipeline/data parallel).

 ## Run Demo

@@ -25,4 +25,4 @@ You can run benchmark for ViT model by running the following script:
 ```bash
 bash run_benchmark.sh
 ```
 The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your own set of hyperparameters for testing.
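
For orientation, the hybrid parallel path that the updated example wires together boils down to the pattern below. This is a minimal sketch distilled from the diffs in this commit, not the example scripts themselves: it assumes 4 GPUs launched via torchrun with tp_size=2 and pp_size=2, a randomly initialised ViT, and a synthetic batch.

```python
import torch
import colossalai
from transformers import ViTConfig, ViTForImageClassification

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import HybridAdam

# assumes launch via: torchrun --standalone --nproc_per_node 4 this_sketch.py
colossalai.launch_from_torch(config={})

model = ViTForImageClassification(ViTConfig(num_labels=10))    # random init, no download needed
optimizer = HybridAdam(model.parameters(), lr=3e-4)

def criterion(outputs, inputs):
    # the HuggingFace model computes the loss internally when labels are passed
    return outputs.loss

plugin = HybridParallelPlugin(tp_size=2,
                              pp_size=2,
                              num_microbatches=None,
                              microbatch_size=1,
                              enable_all_optimization=True,
                              precision='fp16',
                              initial_scale=1)
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)

# one synthetic training step
batch = dict(pixel_values=torch.randn(8, 3, 224, 224, device=torch.cuda.current_device()),
             labels=torch.randint(0, 10, (8,), device=torch.cuda.current_device()))
optimizer.zero_grad()
if booster.plugin.pp_size > 1:
    # with pipeline parallelism the whole step goes through execute_pipeline,
    # which also runs the backward pass when an optimizer is given
    booster.execute_pipeline(iter([batch]), model, criterion, optimizer,
                             return_loss=True, return_outputs=True)
else:
    loss = criterion(model(**batch), None)
    booster.backward(loss, optimizer)
optimizer.step()
```

With torch_ddp, gemini or low_level_zero the same loop simply takes the else branch, which is what the updated vit_benchmark.py does.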

@@ -1,124 +1,82 @@
 from colossalai import get_default_parser


 def parse_demo_args():

     parser = get_default_parser()
-    parser.add_argument(
-        "--model_name_or_path",
-        type=str,
-        default="google/vit-base-patch16-224",
-        help="Path to pretrained model or model identifier from huggingface.co/models."
-    )
-    parser.add_argument(
-        "--output_path",
-        type=str,
-        default="./output_model.bin",
-        help="The path of your saved model after finetuning."
-    )
+    parser.add_argument("--model_name_or_path",
+                        type=str,
+                        default="google/vit-base-patch16-224",
+                        help="Path to pretrained model or model identifier from huggingface.co/models.")
+    parser.add_argument("--output_path",
+                        type=str,
+                        default="./output_model",
+                        help="The path of your saved model after finetuning.")
     parser.add_argument(
         "--plugin",
         type=str,
         default="gemini",
-        help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
+        help=
+        "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'."
     )
-    parser.add_argument(
-        "--num_epoch",
-        type=int,
-        default=3,
-        help="Number of epochs."
-    )
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=32,
-        help="Batch size (per dp group) for the training dataloader."
-    )
-    parser.add_argument(
-        "--learning_rate",
-        type=float,
-        default=3e-4,
-        help="Initial learning rate (after the potential warmup period) to use."
-    )
-    parser.add_argument(
-        "--warmup_ratio",
-        type=float,
-        default=0.3,
-        help="Ratio of warmup steps against total training steps."
-    )
-    parser.add_argument(
-        "--weight_decay",
-        type=float,
-        default=0.1,
-        help="Weight decay to use."
-    )
-    parser.add_argument(
-        "--seed",
-        type=int,
-        default=42,
-        help="A seed for reproducible training."
-    )
+    parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs.")
+    parser.add_argument("--batch_size",
+                        type=int,
+                        default=32,
+                        help="Batch size (per dp group) for the training dataloader.")
+    parser.add_argument("--tp_size",
+                        type=int,
+                        default=1,
+                        help="The size along tensor parallel dimension, only be used when enabling hybrid parallel.")
+    parser.add_argument("--pp_size",
+                        type=int,
+                        default=1,
+                        help="The size along pipeline parallel dimension, only be used when enabling hybrid parallel.")
+    parser.add_argument("--learning_rate",
+                        type=float,
+                        default=3e-4,
+                        help="Initial learning rate (after the potential warmup period) to use.")
+    parser.add_argument("--warmup_ratio",
+                        type=float,
+                        default=0.3,
+                        help="Ratio of warmup steps against total training steps.")
+    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.")
+    parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.")
+    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
     args = parser.parse_args()
     return args

 def parse_benchmark_args():

     parser = get_default_parser()
-    parser.add_argument(
-        "--model_name_or_path",
-        type=str,
-        default="google/vit-base-patch16-224",
-        help="Path to a pretrained model or model identifier from huggingface.co/models."
-    )
+    parser.add_argument("--model_name_or_path",
+                        type=str,
+                        default="google/vit-base-patch16-224",
+                        help="Path to a pretrained model or model identifier from huggingface.co/models.")
     parser.add_argument(
         "--plugin",
         type=str,
         default="gemini",
-        help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
+        help=
+        "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'."
     )
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=8,
-        help="Batch size (per dp group) for the training dataloader."
-    )
-    parser.add_argument(
-        "--num_labels",
-        type=int,
-        default=10,
-        help="Number of labels for classification."
-    )
-    parser.add_argument(
-        "--learning_rate",
-        type=float,
-        default=5e-5,
-        help="Initial learning rate (after the potential warmup period) to use."
-    )
-    parser.add_argument(
-        "--weight_decay",
-        type=float,
-        default=0.0,
-        help="Weight decay to use."
-    )
-    parser.add_argument(
-        "--max_train_steps",
-        type=int,
-        default=20,
-        help="Total number of training steps to perform."
-    )
-    parser.add_argument(
-        "--seed",
-        type=int,
-        default=42,
-        help="A seed for reproducible training."
-    )
-    parser.add_argument(
-        "--mem_cap",
-        type=int,
-        default=0,
-        help="Limit on the usage of space for each GPU (in GB)."
-    )
+    parser.add_argument("--batch_size",
+                        type=int,
+                        default=8,
+                        help="Batch size (per dp group) for the training dataloader.")
+    parser.add_argument("--num_labels", type=int, default=10, help="Number of labels for classification.")
+    parser.add_argument("--learning_rate",
+                        type=float,
+                        default=5e-5,
+                        help="Initial learning rate (after the potential warmup period) to use.")
+    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
+    parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.")
+    parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.")
+    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+    parser.add_argument("--mem_cap", type=int, default=0, help="Limit on the usage of space for each GPU (in GB).")
     args = parser.parse_args()
     return args

@@ -1,32 +1,38 @@
 import torch
-from torch.utils.data import Dataset
 from datasets import load_dataset
+from torch.utils.data import Dataset


 class BeansDataset(Dataset):

-    def __init__(self, image_processor, split='train'):
+    def __init__(self, image_processor, tp_size=1, split='train'):
         super().__init__()
         self.image_processor = image_processor
         self.ds = load_dataset('beans')[split]
         self.label_names = self.ds.features['labels'].names
+        while len(self.label_names) % tp_size != 0:
+            # ensure that the number of labels is multiple of tp_size
+            self.label_names.append(f"pad_label_{len(self.label_names)}")
         self.num_labels = len(self.label_names)
         self.inputs = []
         for example in self.ds:
             self.inputs.append(self.process_example(example))

     def __len__(self):
         return len(self.inputs)

     def __getitem__(self, idx):
         return self.inputs[idx]

     def process_example(self, example):
         input = self.image_processor(example['image'], return_tensors='pt')
         input['labels'] = example['labels']
         return input


 def beans_collator(batch):
-    return {'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0),
-            'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64)}
+    return {
+        'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0),
+        'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64)
+    }
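
The while-loop added to BeansDataset pads the label list so that its length is divisible by tp_size. The likely reason (an inference, not stated in the diff) is that the classification head gets sharded along the label dimension under tensor parallelism, so every rank must receive an equal slice. A toy illustration with the three beans classes and a hypothetical tp_size of 2:

```python
import torch

tp_size = 2
label_names = ['angular_leaf_spot', 'bean_rust', 'healthy']    # the three classes of the HuggingFace beans dataset

# same padding rule as the updated BeansDataset.__init__
while len(label_names) % tp_size != 0:
    label_names.append(f"pad_label_{len(label_names)}")

# a classification head of shape (num_labels, hidden_size) can now be split evenly
head = torch.nn.Linear(768, len(label_names))
shards = torch.chunk(head.weight, tp_size, dim=0)
assert all(s.shape[0] == len(label_names) // tp_size for s in shards)
print(label_names)    # ['angular_leaf_spot', 'bean_rust', 'healthy', 'pad_label_3']
```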

@@ -5,23 +5,20 @@ export BS=8
 export MEMCAP=0
 export GPUNUM=1

-for BS in 8 32 128
+for BS in 8 32
 do
-for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
-do
-for GPUNUM in 1 4
+for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" "hybrid_parallel"
 do

 MODEL_PATH="google/vit-base-patch16-224"
 torchrun \
   --standalone \
-  --nproc_per_node ${GPUNUM} \
+  --nproc_per_node 4 \
   vit_benchmark.py \
   --model_name_or_path ${MODEL_PATH} \
   --mem_cap ${MEMCAP} \
   --plugin ${PLUGIN} \
   --batch_size ${BS}

-done
 done
 done

@@ -5,16 +5,21 @@ pip install -r requirements.txt
 MODEL="google/vit-base-patch16-224"

 # path for saving model
-OUTPUT_PATH="./output_model.bin"
+OUTPUT_PATH="./output_model"

 # plugin(training strategy)
-# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"
+# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"/"hybrid_parallel"
 PLUGIN="gemini"
+#PLUGIN="hybrid_parallel"
+
+# configuration of parallel group sizes, only used when setting PLUGIN to "hybrid_parallel"
+TP_SIZE=2
+PP_SIZE=2

 # number of gpus to use
 GPUNUM=4

-# batch size per gpu
+# batch size per data parallel group
 BS=16

 # learning rate

@@ -38,6 +43,8 @@ torchrun \
   --output_path ${OUTPUT_PATH} \
   --plugin ${PLUGIN} \
   --batch_size ${BS} \
+  --tp_size ${TP_SIZE} \
+  --pp_size ${PP_SIZE} \
   --num_epoch ${EPOCH} \
   --learning_rate ${LR} \
   --weight_decay ${WEIGHT_DECAY} \

@@ -2,18 +2,15 @@ set -xe
 pip install -r requirements.txt

 BS=8
-for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
-do
-for GPUNUM in 1 4
+for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" "hybrid_parallel"
 do

 torchrun \
   --standalone \
-  --nproc_per_node ${GPUNUM} \
+  --nproc_per_node 4 \
   vit_benchmark.py \
   --model_name_or_path "google/vit-base-patch16-224" \
   --plugin ${PLUGIN} \
   --batch_size ${BS}

-done
 done

@@ -1,14 +1,14 @@
 import time

 import torch
-import tqdm
 import transformers
 from args import parse_benchmark_args
+from tqdm import tqdm
 from transformers import ViTConfig, ViTForImageClassification

 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
@@ -24,7 +24,7 @@ def format_num(num: int, bytes=False):
         num /= factor


-def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
+def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224):
     pixel_values = torch.randn(batch_size,
                                num_channels,
                                height,

@@ -32,7 +32,7 @@ def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
                                device=torch.cuda.current_device(),
                                dtype=torch.float)
     labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64)
-    return pixel_values, labels
+    return dict(pixel_values=pixel_values, labels=labels)


 def colo_memory_cap(size_in_GB):

@@ -70,7 +70,8 @@ def main():
     logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

     # Enable gradient checkpointing
-    model.gradient_checkpointing_enable()
+    if args.grad_checkpoint:
+        model.gradient_checkpointing_enable()

     # Set plugin
     booster_kwargs = {}
@@ -82,34 +83,57 @@ def main():
         plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
     elif args.plugin == 'low_level_zero':
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
+    elif args.plugin == 'hybrid_parallel':
+        plugin = HybridParallelPlugin(tp_size=2,
+                                      pp_size=2,
+                                      num_microbatches=None,
+                                      microbatch_size=1,
+                                      enable_all_optimization=True,
+                                      precision='fp16',
+                                      initial_scale=1)
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])

     # Set optimizer
     optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))

+    # Set criterion (loss function)
+    def criterion(outputs, inputs):
+        return outputs.loss
+
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
-    model, optimizer, _, _, _ = booster.boost(model, optimizer)
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)

     # Start training.
     logger.info(f"Start testing", ranks=[0])
-    progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())

     torch.cuda.synchronize()
     model.train()
     start_time = time.time()

-    for _ in range(args.max_train_steps):
-
-        pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224)
-        optimizer.zero_grad()
-        outputs = model(pixel_values=pixel_values, labels=labels)
-        loss = outputs['loss']
-        booster.backward(loss, optimizer)
-        optimizer.step()
-
-        torch.cuda.synchronize()
-        progress_bar.update(1)
+    with tqdm(range(args.max_train_steps), desc="Training Step", disable=not coordinator.is_master()) as pbar:
+        for _ in pbar:
+            optimizer.zero_grad()
+            batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224)
+
+            if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
+                # run pipeline forward backward
+                batch = iter([batch])
+                outputs = booster.execute_pipeline(batch,
+                                                   model,
+                                                   criterion,
+                                                   optimizer,
+                                                   return_loss=True,
+                                                   return_outputs=True)
+            else:
+                outputs = model(**batch)
+                loss = criterion(outputs, None)
+                # Backward
+                booster.backward(loss, optimizer)
+
+            optimizer.step()
+
+    torch.cuda.synchronize()

     # Compute Statistics
     end_time = time.time()

@@ -124,6 +148,8 @@ def main():
                 f"maximum memory usage per gpu: {max_mem}.",
                 ranks=[0])

+    torch.cuda.empty_cache()
+

 if __name__ == "__main__":
     main()
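
The benchmark's closing statistics (only the logging lines appear as context above) are what run_benchmark.sh sweeps over: samples-per-second throughput and peak GPU memory. Roughly, the computation amounts to the sketch below; the function and variable names are illustrative, not taken from the file.

```python
import torch

def report_statistics(start_time: float, end_time: float, num_steps: int, batch_size: int, world_size: int):
    elapsed = end_time - start_time
    # samples processed per second across all ranks
    throughput = num_steps * batch_size * world_size / elapsed
    # peak GPU memory on this rank, in GiB
    max_mem_gib = torch.cuda.max_memory_allocated() / 1024**3
    return throughput, max_mem_gib
```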

@@ -1,70 +1,111 @@
+from typing import Any, Callable, Iterator
+
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 import transformers
 from args import parse_demo_args
 from data import BeansDataset, beans_collator
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor

 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device


 def move_to_cuda(batch, device):
     return {k: v.to(device) for k, v in batch.items()}


-def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
+def run_forward_backward(model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor],
+                         data_iter: Iterator, booster: Booster):
+    if optimizer is not None:
+        optimizer.zero_grad()
+
+    if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
+        # run pipeline forward backward when enabling pp in hybrid parallel plugin
+        output_dict = booster.execute_pipeline(data_iter,
+                                               model,
+                                               criterion,
+                                               optimizer,
+                                               return_loss=True,
+                                               return_outputs=True)
+        loss, outputs = output_dict['loss'], output_dict['outputs']
+    else:
+        batch = next(data_iter)
+        batch = move_to_cuda(batch, torch.cuda.current_device())
+        outputs = model(**batch)
+        loss = criterion(outputs, None)
+        if optimizer is not None:
+            booster.backward(loss, optimizer)
+
+    return loss, outputs
+
+
+def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor],
+                lr_scheduler: LRScheduler, dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):

     torch.cuda.synchronize()
+
+    num_steps = len(dataloader)
+    data_iter = iter(dataloader)
+    enable_pbar = coordinator.is_master()
+    if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
+        # when using pp, only the last stage of master pipeline (dp_rank and tp_rank are both zero) shows pbar
+        tp_rank = dist.get_rank(booster.plugin.tp_group)
+        dp_rank = dist.get_rank(booster.plugin.dp_group)
+        enable_pbar = tp_rank == 0 and dp_rank == 0 \
+            and booster.plugin.stage_manager.is_last_stage()
+
     model.train()

-    with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
-
-        for batch in pbar:
-
-            # Foward
-            optimizer.zero_grad()
-            batch = move_to_cuda(batch, torch.cuda.current_device())
-            outputs = model(**batch)
-            loss = outputs['loss']
-
-            # Backward
-            booster.backward(loss, optimizer)
+    with tqdm(range(num_steps), desc=f'Epoch [{epoch + 1}]', disable=not enable_pbar) as pbar:
+        for _ in pbar:
+            loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster)
             optimizer.step()
             lr_scheduler.step()

             # Print batch loss
-            pbar.set_postfix({'loss': loss.item()})
+            if enable_pbar:
+                pbar.set_postfix({'loss': loss.item()})


 @torch.no_grad()
-def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator):
+def evaluate_model(epoch: int, model: nn.Module, criterion: Callable[[Any, Any], torch.Tensor],
+                   eval_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):

+    torch.cuda.synchronize()
     model.eval()
-    accum_loss = torch.zeros(1, device=get_current_device())
-    total_num = torch.zeros(1, device=get_current_device())
-    accum_correct = torch.zeros(1, device=get_current_device())
+    accum_loss = torch.zeros(1, device=torch.cuda.current_device())
+    total_num = torch.zeros(1, device=torch.cuda.current_device())
+    accum_correct = torch.zeros(1, device=torch.cuda.current_device())

     for batch in eval_dataloader:
         batch = move_to_cuda(batch, torch.cuda.current_device())
-        outputs = model(**batch)
-        val_loss, logits = outputs[:2]
-        accum_loss += (val_loss / len(eval_dataloader))
-
-        if num_labels > 1:
-            preds = torch.argmax(logits, dim=1)
-        elif num_labels == 1:
-            preds = logits.squeeze()
-
-        labels = batch["labels"]
-        total_num += batch["labels"].shape[0]
-        accum_correct += (torch.sum(preds == labels))
+        loss, outputs = run_forward_backward(model, None, criterion, iter([batch]), booster)
+
+        to_accum = True
+        if isinstance(booster.plugin, HybridParallelPlugin):
+            # when using hybrid parallel, loss is only collected from last stage of pipeline with tp_rank == 0
+            to_accum = to_accum and (dist.get_rank(booster.plugin.tp_group) == 0)
+            if booster.plugin.pp_size > 1:
+                to_accum = to_accum and booster.plugin.stage_manager.is_last_stage()
+
+        if to_accum:
+            accum_loss += (loss / len(eval_dataloader))
+            logits = outputs["logits"]
+            preds = torch.argmax(logits, dim=1)
+
+            labels = batch["labels"]
+            total_num += batch["labels"].shape[0]
+            accum_correct += (torch.sum(preds == labels))

     dist.all_reduce(accum_loss)
     dist.all_reduce(total_num)
@@ -94,14 +135,20 @@ def main():
     else:
         transformers.utils.logging.set_verbosity_error()

+    # Reset tp_size and pp_size to 1 if not using hybrid parallel.
+    if args.plugin != 'hybrid_parallel':
+        args.tp_size = 1
+        args.pp_size = 1
+
     # Prepare Dataset
     image_processor = ViTImageProcessor.from_pretrained(args.model_name_or_path)
-    train_dataset = BeansDataset(image_processor, split='train')
-    eval_dataset = BeansDataset(image_processor, split='validation')
+    train_dataset = BeansDataset(image_processor, args.tp_size, split='train')
+    eval_dataset = BeansDataset(image_processor, args.tp_size, split='validation')
+    num_labels = train_dataset.num_labels

     # Load pretrained ViT model
     config = ViTConfig.from_pretrained(args.model_name_or_path)
-    config.num_labels = train_dataset.num_labels
+    config.num_labels = num_labels
     config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)}
     config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)}
     model = ViTForImageClassification.from_pretrained(args.model_name_or_path,

@@ -110,7 +157,8 @@ def main():
     logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

     # Enable gradient checkpointing
-    model.gradient_checkpointing_enable()
+    if args.grad_checkpoint:
+        model.gradient_checkpointing_enable()

     # Set plugin
     booster_kwargs = {}
@@ -122,6 +170,16 @@ def main():
         plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
     elif args.plugin == 'low_level_zero':
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
+    elif args.plugin == 'hybrid_parallel':
+        plugin = HybridParallelPlugin(tp_size=args.tp_size,
+                                      pp_size=args.pp_size,
+                                      num_microbatches=None,
+                                      microbatch_size=1,
+                                      enable_all_optimization=True,
+                                      precision='fp16',
+                                      initial_scale=1)
+    else:
+        raise ValueError(f"Plugin with name {args.plugin} is not supported!")
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])

     # Prepare dataloader

@@ -139,6 +197,10 @@ def main():
     # Set optimizer
     optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)

+    # Set criterion (loss function)
+    def criterion(outputs, inputs):
+        return outputs.loss
+
     # Set lr scheduler
     total_steps = len(train_dataloader) * args.num_epoch
     num_warmup_steps = int(args.warmup_ratio * total_steps)
@@ -148,20 +210,21 @@ def main():
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
-    model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model,
-                                                                        optimizer=optimizer,
-                                                                        dataloader=train_dataloader,
-                                                                        lr_scheduler=lr_scheduler)
+    model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(model=model,
+                                                                                 optimizer=optimizer,
+                                                                                 criterion=criterion,
+                                                                                 dataloader=train_dataloader,
+                                                                                 lr_scheduler=lr_scheduler)

     # Finetuning
     logger.info(f"Start finetuning", ranks=[0])
     for epoch in range(args.num_epoch):
-        train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
-        evaluate_model(epoch, model, eval_dataloader, eval_dataset.num_labels, coordinator)
+        train_epoch(epoch, model, optimizer, criterion, lr_scheduler, train_dataloader, booster, coordinator)
+        evaluate_model(epoch, model, criterion, eval_dataloader, booster, coordinator)
     logger.info(f"Finish finetuning", ranks=[0])

     # Save the finetuned model
-    booster.save_model(model, args.output_path)
+    booster.save_model(model, args.output_path, shard=True)
     logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])