From 295b38fecf3358b577b2e8c21eaf363d600dc38e Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 7 Sep 2023 17:38:45 +0800 Subject: [PATCH] [example] update vit example for hybrid parallel plugin (#4641) * update vit example for hybrid plugin * reset tp/pp size * fix dataloader iteration bug * update optimizer passing in evaluation/add grad_accum * change criterion * wrap tqdm * change grad_accum to grad_checkpoint * fix pbar --- colossalai/shardformer/modeling/gpt2.py | 1 + colossalai/shardformer/modeling/vit.py | 21 ++-- examples/images/vit/README.md | 4 +- examples/images/vit/args.py | 156 +++++++++--------------- examples/images/vit/data.py | 22 ++-- examples/images/vit/run_benchmark.sh | 11 +- examples/images/vit/run_demo.sh | 13 +- examples/images/vit/test_ci.sh | 7 +- examples/images/vit/vit_benchmark.py | 58 ++++++--- examples/images/vit/vit_train_demo.py | 145 +++++++++++++++------- 10 files changed, 246 insertions(+), 192 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 8ed367b25..9eb58df4d 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -884,6 +884,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): if self.gradient_checkpointing and self.training: if use_cache: + logger = logging.get_logger(__name__) logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") use_cache = False diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 9fc0b7488..2ce52163a 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -1,9 +1,9 @@ -import logging import math from typing import Dict, List, Optional, Set, Tuple, Union import torch from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder +from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -72,18 +72,17 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). """ - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if output_attentions is not None: - logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.') - output_attentions = None - if output_hidden_states is not None: - logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.') - output_hidden_states = None + logger = logging.get_logger(__name__) + + # Preprocess passed in arguments + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head diff --git a/examples/images/vit/README.md b/examples/images/vit/README.md index 7c4147b76..33c6454ad 100644 --- a/examples/images/vit/README.md +++ b/examples/images/vit/README.md @@ -3,7 +3,7 @@ Vision Transformer is a class of Transformer model tailored for computer vision tasks. It was first proposed in paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) and achieved SOTA results on various tasks at that time. In our example, we are using pretrained weights of ViT loaded from HuggingFace. -We adapt the ViT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin. +We adapt the ViT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin (DDP), LowLevelZeroPlugin (Zero1/Zero2), GeminiPlugin (Gemini) and HybridParallelPlugin (any combination of tensor/pipeline/data parallel). ## Run Demo @@ -25,4 +25,4 @@ You can run benchmark for ViT model by running the following script: ```bash bash run_benchmark.sh ``` -The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your own set of hyperparameters for testing. \ No newline at end of file +The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your own set of hyperparameters for testing. diff --git a/examples/images/vit/args.py b/examples/images/vit/args.py index e4a873a9e..e6c52c4e9 100644 --- a/examples/images/vit/args.py +++ b/examples/images/vit/args.py @@ -1,124 +1,82 @@ from colossalai import get_default_parser + def parse_demo_args(): parser = get_default_parser() - parser.add_argument( - "--model_name_or_path", - type=str, - default="google/vit-base-patch16-224", - help="Path to pretrained model or model identifier from huggingface.co/models." - ) - parser.add_argument( - "--output_path", - type=str, - default="./output_model.bin", - help="The path of your saved model after finetuning." - ) + parser.add_argument("--model_name_or_path", + type=str, + default="google/vit-base-patch16-224", + help="Path to pretrained model or model identifier from huggingface.co/models.") + parser.add_argument("--output_path", + type=str, + default="./output_model", + help="The path of your saved model after finetuning.") parser.add_argument( "--plugin", type=str, default="gemini", - help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'." - ) - parser.add_argument( - "--num_epoch", - type=int, - default=3, - help="Number of epochs." - ) - parser.add_argument( - "--batch_size", - type=int, - default=32, - help="Batch size (per dp group) for the training dataloader." - ) - parser.add_argument( - "--learning_rate", - type=float, - default=3e-4, - help="Initial learning rate (after the potential warmup period) to use." - ) - parser.add_argument( - "--warmup_ratio", - type=float, - default=0.3, - help="Ratio of warmup steps against total training steps." - ) - parser.add_argument( - "--weight_decay", - type=float, - default=0.1, - help="Weight decay to use." - ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="A seed for reproducible training." + help= + "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'." ) + parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs.") + parser.add_argument("--batch_size", + type=int, + default=32, + help="Batch size (per dp group) for the training dataloader.") + parser.add_argument("--tp_size", + type=int, + default=1, + help="The size along tensor parallel dimension, only be used when enabling hybrid parallel.") + parser.add_argument("--pp_size", + type=int, + default=1, + help="The size along pipeline parallel dimension, only be used when enabling hybrid parallel.") + parser.add_argument("--learning_rate", + type=float, + default=3e-4, + help="Initial learning rate (after the potential warmup period) to use.") + parser.add_argument("--warmup_ratio", + type=float, + default=0.3, + help="Ratio of warmup steps against total training steps.") + parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.") + parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") args = parser.parse_args() return args + def parse_benchmark_args(): parser = get_default_parser() - parser.add_argument( - "--model_name_or_path", - type=str, - default="google/vit-base-patch16-224", - help="Path to a pretrained model or model identifier from huggingface.co/models." - ) + parser.add_argument("--model_name_or_path", + type=str, + default="google/vit-base-patch16-224", + help="Path to a pretrained model or model identifier from huggingface.co/models.") parser.add_argument( "--plugin", type=str, default="gemini", - help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'." - ) - parser.add_argument( - "--batch_size", - type=int, - default=8, - help="Batch size (per dp group) for the training dataloader." - ) - parser.add_argument( - "--num_labels", - type=int, - default=10, - help="Number of labels for classification." - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use." - ) - parser.add_argument( - "--weight_decay", - type=float, - default=0.0, - help="Weight decay to use." - ) - parser.add_argument( - "--max_train_steps", - type=int, - default=20, - help="Total number of training steps to perform." - ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="A seed for reproducible training." - ) - parser.add_argument( - "--mem_cap", - type=int, - default=0, - help="Limit on the usage of space for each GPU (in GB)." + help= + "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'." ) + parser.add_argument("--batch_size", + type=int, + default=8, + help="Batch size (per dp group) for the training dataloader.") + parser.add_argument("--num_labels", type=int, default=10, help="Number of labels for classification.") + parser.add_argument("--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.") + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.") + parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument("--mem_cap", type=int, default=0, help="Limit on the usage of space for each GPU (in GB).") args = parser.parse_args() - return args \ No newline at end of file + return args diff --git a/examples/images/vit/data.py b/examples/images/vit/data.py index 00fde707b..77a8ad525 100644 --- a/examples/images/vit/data.py +++ b/examples/images/vit/data.py @@ -1,32 +1,38 @@ import torch -from torch.utils.data import Dataset from datasets import load_dataset +from torch.utils.data import Dataset + class BeansDataset(Dataset): - - def __init__(self, image_processor, split='train'): + + def __init__(self, image_processor, tp_size=1, split='train'): super().__init__() self.image_processor = image_processor self.ds = load_dataset('beans')[split] self.label_names = self.ds.features['labels'].names + while len(self.label_names) % tp_size != 0: + # ensure that the number of labels is multiple of tp_size + self.label_names.append(f"pad_label_{len(self.label_names)}") self.num_labels = len(self.label_names) self.inputs = [] for example in self.ds: self.inputs.append(self.process_example(example)) - + def __len__(self): return len(self.inputs) def __getitem__(self, idx): return self.inputs[idx] - + def process_example(self, example): input = self.image_processor(example['image'], return_tensors='pt') input['labels'] = example['labels'] return input - + def beans_collator(batch): - return {'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0), - 'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64)} + return { + 'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0), + 'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64) + } diff --git a/examples/images/vit/run_benchmark.sh b/examples/images/vit/run_benchmark.sh index 2487bf81e..41eab9c5a 100644 --- a/examples/images/vit/run_benchmark.sh +++ b/examples/images/vit/run_benchmark.sh @@ -5,23 +5,20 @@ export BS=8 export MEMCAP=0 export GPUNUM=1 -for BS in 8 32 128 +for BS in 8 32 do -for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" -do -for GPUNUM in 1 4 +for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" "hybrid_parallel" do MODEL_PATH="google/vit-base-patch16-224" torchrun \ --standalone \ - --nproc_per_node ${GPUNUM} \ + --nproc_per_node 4 \ vit_benchmark.py \ --model_name_or_path ${MODEL_PATH} \ --mem_cap ${MEMCAP} \ --plugin ${PLUGIN} \ --batch_size ${BS} - -done + done done diff --git a/examples/images/vit/run_demo.sh b/examples/images/vit/run_demo.sh index 2d140dd6e..9efe14759 100644 --- a/examples/images/vit/run_demo.sh +++ b/examples/images/vit/run_demo.sh @@ -5,16 +5,21 @@ pip install -r requirements.txt MODEL="google/vit-base-patch16-224" # path for saving model -OUTPUT_PATH="./output_model.bin" +OUTPUT_PATH="./output_model" # plugin(training strategy) -# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini" +# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"/"hybrid_parallel" PLUGIN="gemini" +#PLUGIN="hybrid_parallel" + +# configuration of parallel group sizes, only used when setting PLUGIN to "hybrid_parallel" +TP_SIZE=2 +PP_SIZE=2 # number of gpus to use GPUNUM=4 -# batch size per gpu +# batch size per data parallel group BS=16 # learning rate @@ -38,6 +43,8 @@ torchrun \ --output_path ${OUTPUT_PATH} \ --plugin ${PLUGIN} \ --batch_size ${BS} \ + --tp_size ${TP_SIZE} \ + --pp_size ${PP_SIZE} \ --num_epoch ${EPOCH} \ --learning_rate ${LR} \ --weight_decay ${WEIGHT_DECAY} \ diff --git a/examples/images/vit/test_ci.sh b/examples/images/vit/test_ci.sh index 8606015c0..570147606 100644 --- a/examples/images/vit/test_ci.sh +++ b/examples/images/vit/test_ci.sh @@ -2,18 +2,15 @@ set -xe pip install -r requirements.txt BS=8 -for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" -do -for GPUNUM in 1 4 +for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" "hybrid_parallel" do torchrun \ --standalone \ - --nproc_per_node ${GPUNUM} \ + --nproc_per_node 4 \ vit_benchmark.py \ --model_name_or_path "google/vit-base-patch16-224" \ --plugin ${PLUGIN} \ --batch_size ${BS} done -done diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index c2293b96a..d822fe23e 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -1,14 +1,14 @@ import time import torch -import tqdm import transformers from args import parse_benchmark_args +from tqdm import tqdm from transformers import ViTConfig, ViTForImageClassification import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam @@ -24,7 +24,7 @@ def format_num(num: int, bytes=False): num /= factor -def get_data(batch_size, num_labels, num_channels=3, height=224, width=224): +def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224): pixel_values = torch.randn(batch_size, num_channels, height, @@ -32,7 +32,7 @@ def get_data(batch_size, num_labels, num_channels=3, height=224, width=224): device=torch.cuda.current_device(), dtype=torch.float) labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64) - return pixel_values, labels + return dict(pixel_values=pixel_values, labels=labels) def colo_memory_cap(size_in_GB): @@ -70,7 +70,8 @@ def main(): logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) # Enable gradient checkpointing - model.gradient_checkpointing_enable() + if args.grad_checkpoint: + model.gradient_checkpointing_enable() # Set plugin booster_kwargs = {} @@ -82,34 +83,57 @@ def main(): plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) elif args.plugin == 'low_level_zero': plugin = LowLevelZeroPlugin(initial_scale=2**5) + elif args.plugin == 'hybrid_parallel': + plugin = HybridParallelPlugin(tp_size=2, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + precision='fp16', + initial_scale=1) logger.info(f"Set plugin as {args.plugin}", ranks=[0]) # Set optimizer optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size)) + # Set criterion (loss function) + def criterion(outputs, inputs): + return outputs.loss + # Set booster booster = Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _, _, _ = booster.boost(model, optimizer) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion) # Start training. logger.info(f"Start testing", ranks=[0]) - progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master()) torch.cuda.synchronize() model.train() start_time = time.time() - for _ in range(args.max_train_steps): + with tqdm(range(args.max_train_steps), desc="Training Step", disable=not coordinator.is_master()) as pbar: + for _ in pbar: + optimizer.zero_grad() + batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224) - pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224) - optimizer.zero_grad() - outputs = model(pixel_values=pixel_values, labels=labels) - loss = outputs['loss'] - booster.backward(loss, optimizer) - optimizer.step() + if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None: + # run pipeline forward backward + batch = iter([batch]) + outputs = booster.execute_pipeline(batch, + model, + criterion, + optimizer, + return_loss=True, + return_outputs=True) + else: + outputs = model(**batch) + loss = criterion(outputs, None) + # Backward + booster.backward(loss, optimizer) - torch.cuda.synchronize() - progress_bar.update(1) + optimizer.step() + + torch.cuda.synchronize() # Compute Statistics end_time = time.time() @@ -124,6 +148,8 @@ def main(): f"maximum memory usage per gpu: {max_mem}.", ranks=[0]) + torch.cuda.empty_cache() + if __name__ == "__main__": main() diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py index 4dc0f67f4..206d8694b 100644 --- a/examples/images/vit/vit_train_demo.py +++ b/examples/images/vit/vit_train_demo.py @@ -1,70 +1,111 @@ +from typing import Any, Callable, Iterator + import torch import torch.distributed as dist +import torch.nn as nn import transformers from args import parse_demo_args from data import BeansDataset, beans_collator +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader from tqdm import tqdm from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device def move_to_cuda(batch, device): return {k: v.to(device) for k, v in batch.items()} -def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator): +def run_forward_backward(model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor], + data_iter: Iterator, booster: Booster): + if optimizer is not None: + optimizer.zero_grad() + if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: + # run pipeline forward backward when enabling pp in hybrid parallel plugin + output_dict = booster.execute_pipeline(data_iter, + model, + criterion, + optimizer, + return_loss=True, + return_outputs=True) + loss, outputs = output_dict['loss'], output_dict['outputs'] + else: + batch = next(data_iter) + batch = move_to_cuda(batch, torch.cuda.current_device()) + outputs = model(**batch) + loss = criterion(outputs, None) + if optimizer is not None: + booster.backward(loss, optimizer) + + return loss, outputs + + +def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor], + lr_scheduler: LRScheduler, dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): torch.cuda.synchronize() + + num_steps = len(dataloader) + data_iter = iter(dataloader) + enable_pbar = coordinator.is_master() + if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: + # when using pp, only the last stage of master pipeline (dp_rank and tp_rank are both zero) shows pbar + tp_rank = dist.get_rank(booster.plugin.tp_group) + dp_rank = dist.get_rank(booster.plugin.dp_group) + enable_pbar = tp_rank == 0 and dp_rank == 0 \ + and booster.plugin.stage_manager.is_last_stage() + model.train() - with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar: - - for batch in pbar: - - # Foward - optimizer.zero_grad() - batch = move_to_cuda(batch, torch.cuda.current_device()) - outputs = model(**batch) - loss = outputs['loss'] - - # Backward - booster.backward(loss, optimizer) + with tqdm(range(num_steps), desc=f'Epoch [{epoch + 1}]', disable=not enable_pbar) as pbar: + for _ in pbar: + loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster) optimizer.step() lr_scheduler.step() # Print batch loss - pbar.set_postfix({'loss': loss.item()}) + if enable_pbar: + pbar.set_postfix({'loss': loss.item()}) @torch.no_grad() -def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator): +def evaluate_model(epoch: int, model: nn.Module, criterion: Callable[[Any, Any], torch.Tensor], + eval_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): + torch.cuda.synchronize() model.eval() - accum_loss = torch.zeros(1, device=get_current_device()) - total_num = torch.zeros(1, device=get_current_device()) - accum_correct = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=torch.cuda.current_device()) + total_num = torch.zeros(1, device=torch.cuda.current_device()) + accum_correct = torch.zeros(1, device=torch.cuda.current_device()) for batch in eval_dataloader: batch = move_to_cuda(batch, torch.cuda.current_device()) - outputs = model(**batch) - val_loss, logits = outputs[:2] - accum_loss += (val_loss / len(eval_dataloader)) - if num_labels > 1: - preds = torch.argmax(logits, dim=1) - elif num_labels == 1: - preds = logits.squeeze() + loss, outputs = run_forward_backward(model, None, criterion, iter([batch]), booster) - labels = batch["labels"] - total_num += batch["labels"].shape[0] - accum_correct += (torch.sum(preds == labels)) + to_accum = True + if isinstance(booster.plugin, HybridParallelPlugin): + # when using hybrid parallel, loss is only collected from last stage of pipeline with tp_rank == 0 + to_accum = to_accum and (dist.get_rank(booster.plugin.tp_group) == 0) + if booster.plugin.pp_size > 1: + to_accum = to_accum and booster.plugin.stage_manager.is_last_stage() + + if to_accum: + accum_loss += (loss / len(eval_dataloader)) + logits = outputs["logits"] + preds = torch.argmax(logits, dim=1) + + labels = batch["labels"] + total_num += batch["labels"].shape[0] + accum_correct += (torch.sum(preds == labels)) dist.all_reduce(accum_loss) dist.all_reduce(total_num) @@ -94,14 +135,20 @@ def main(): else: transformers.utils.logging.set_verbosity_error() + # Reset tp_size and pp_size to 1 if not using hybrid parallel. + if args.plugin != 'hybrid_parallel': + args.tp_size = 1 + args.pp_size = 1 + # Prepare Dataset image_processor = ViTImageProcessor.from_pretrained(args.model_name_or_path) - train_dataset = BeansDataset(image_processor, split='train') - eval_dataset = BeansDataset(image_processor, split='validation') + train_dataset = BeansDataset(image_processor, args.tp_size, split='train') + eval_dataset = BeansDataset(image_processor, args.tp_size, split='validation') + num_labels = train_dataset.num_labels # Load pretrained ViT model config = ViTConfig.from_pretrained(args.model_name_or_path) - config.num_labels = train_dataset.num_labels + config.num_labels = num_labels config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)} config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)} model = ViTForImageClassification.from_pretrained(args.model_name_or_path, @@ -110,7 +157,8 @@ def main(): logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) # Enable gradient checkpointing - model.gradient_checkpointing_enable() + if args.grad_checkpoint: + model.gradient_checkpointing_enable() # Set plugin booster_kwargs = {} @@ -122,6 +170,16 @@ def main(): plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) elif args.plugin == 'low_level_zero': plugin = LowLevelZeroPlugin(initial_scale=2**5) + elif args.plugin == 'hybrid_parallel': + plugin = HybridParallelPlugin(tp_size=args.tp_size, + pp_size=args.pp_size, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + precision='fp16', + initial_scale=1) + else: + raise ValueError(f"Plugin with name {args.plugin} is not supported!") logger.info(f"Set plugin as {args.plugin}", ranks=[0]) # Prepare dataloader @@ -139,6 +197,10 @@ def main(): # Set optimizer optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay) + # Set criterion (loss function) + def criterion(outputs, inputs): + return outputs.loss + # Set lr scheduler total_steps = len(train_dataloader) * args.num_epoch num_warmup_steps = int(args.warmup_ratio * total_steps) @@ -148,20 +210,21 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model, - optimizer=optimizer, - dataloader=train_dataloader, - lr_scheduler=lr_scheduler) + model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(model=model, + optimizer=optimizer, + criterion=criterion, + dataloader=train_dataloader, + lr_scheduler=lr_scheduler) # Finetuning logger.info(f"Start finetuning", ranks=[0]) for epoch in range(args.num_epoch): - train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator) - evaluate_model(epoch, model, eval_dataloader, eval_dataset.num_labels, coordinator) + train_epoch(epoch, model, optimizer, criterion, lr_scheduler, train_dataloader, booster, coordinator) + evaluate_model(epoch, model, criterion, eval_dataloader, booster, coordinator) logger.info(f"Finish finetuning", ranks=[0]) # Save the finetuned model - booster.save_model(model, args.output_path) + booster.save_model(model, args.output_path, shard=True) logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])