From 0a94fcd3514a6f7d4f287bba614fda3fb12c8802 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Mon, 4 Sep 2023 21:46:29 +0800
Subject: [PATCH] [shardformer] update bert finetune example with HybridParallelPlugin (#4584)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [shardformer] fix opt test hanging
* fix
* test
* test
* test
* fix test
* fix test
* remove print
* add fix
* [shardformer] add bert finetune example
* [shardformer] add bert finetune example
* [shardformer] add bert finetune example
* [shardformer] add bert finetune example
* [shardformer] add bert finetune example
* [shardformer] add bert finetune example
* [shardformer] fix epoch change
* [shardformer] broadcast add pp group
* [shardformer] fix opt test hanging
* fix
* test
* test
* [shardformer] zero1+pp and the corresponding tests (#4517)
* pause
* finish pp+zero1
* Update test_shard_vit.py
* [shardformer/fix overlap bug] fix overlap bug, add overlap as an option in shardco… (#4516)
* fix overlap bug and support bert, add overlap as an option in shardconfig
* support overlap for chatglm and bloom
* [shardformer] fix emerged bugs after updating transformers (#4526)
* test
* fix test
* fix test
* remove print
* add fix
* [shardformer] add bert finetune example
* [shardformer] add bert finetune example
* [shardformer] Add overlap support for gpt2 (#4535)
* add overlap support for gpt2
* remove unused code
* remove unused code
* [shardformer] support pp+tp+zero1 tests (#4531)
* [shardformer] fix opt test hanging
* fix
* test
* test
* test
* fix test
* fix test
* remove print
* add fix
* [shardformer] pp+tp+zero1

[shardformer] pp+tp+zero1

[shardformer] pp+tp+zero1

[shardformer] pp+tp+zero1

[shardformer] pp+tp+zero1

[shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1
* [shardformer] pp+tp+zero1
* [shardformer] pp+tp+zero1
* [shardformer] pp+tp+zero1
* [shardformer] fix submodule replacement bug when enabling pp (#4544)
* [shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540)
* implement sharded optimizer saving
* add more param info
* finish implementation of sharded optimizer saving
* fix bugs in optimizer sharded saving
* add pp+zero test
* param group loading
* greedy loading of optimizer
* fix bug when loading
* implement optimizer sharded saving
* add optimizer test & arrange checkpointIO utils
* fix gemini sharding state_dict
* add verbose option
* add loading of master params
* fix typehint
* fix master/working mapping in fp16 amp
* [shardformer] add bert finetune example
* [shardformer] add bert finetune example
* [shardformer] add bert finetune example
* [shardformer] add bert finetune example
* [shardformer] fix epoch change
* [shardformer] broadcast add pp group
* rebase feature/shardformer
* update pipeline
* [shardformer] fix
* [shardformer] fix
* [shardformer] bert finetune fix
* [shardformer] add all_reduce operation to loss

add all_reduce operation to loss

* [shardformer] make compatible with pytree.

make compatible with pytree.
* [shardformer] disable tp

disable tp

* [shardformer] add 3d plugin to ci test
* [shardformer] update num_microbatches to None
* [shardformer] update microbatchsize
* [shardformer] update assert
* update scheduler
* update scheduler

---------

Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Baizhou Zhang
---
 .../booster/plugin/hybrid_parallel_plugin.py |   2 +-
 colossalai/pipeline/schedule/one_f_one_b.py  |   3 +-
 examples/language/bert/finetune.py           | 163 ++++++++++++++----
 examples/language/bert/test_ci.sh            |   2 +-
 4 files changed, 134 insertions(+), 36 deletions(-)

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index c83e51b26..8ad9b7956 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -325,7 +325,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.schedule = None
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
-            assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
+            assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
             assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
             self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
             self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index ec53a6771..5db1c7f30 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -46,6 +46,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         self.batch: Optional[Any] = None
         self.batch_size: Optional[int] = None
         self.microbatch_offset: Optional[int] = None
+        self._use_microbatch_size = num_microbatches is None

     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -60,7 +61,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         self.batch = batch
         self.batch_size = get_batch_size(batch)
         self.microbatch_offset = 0
-        if self.num_microbatches is not None:
+        if not self._use_microbatch_size:
             assert self.batch_size % self.num_microbatches == 0, \
                 "Batch size should divided by the number of microbatches"
             self.microbatch_size = self.batch_size // self.num_microbatches
diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py
index b209ffde8..b9a3d5753 100644
--- a/examples/language/bert/finetune.py
+++ b/examples/language/bert/finetune.py
@@ -1,12 +1,14 @@
 import argparse
-from typing import List, Union
+from contextlib import nullcontext
+from typing import Callable, List, Union

 import evaluate
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from data import GLUEDataBuilder
-from torch.optim import Optimizer
+from torch.optim import Adam, Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import (
@@ -18,8 +20,9 @@ from transformers import (

 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device

@@ -32,14 +35,26 @@ LEARNING_RATE = 2.4e-5
 WEIGHT_DECAY = 0.01
 WARMUP_FRACTION = 0.1

+output_transform_fn = lambda x: x
+criterion = lambda x: x.loss
+

 def move_to_cuda(batch):
     return {k: v.cuda() for k, v in batch.items()}


 @torch.no_grad()
-def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str,
-                   eval_splits: List[str], coordinator: DistCoordinator):
+def evaluate_model(
+    model: nn.Module,
+    optimizer,
+    criterion,
+    test_dataloader: Union[DataLoader, List[DataLoader]],
+    num_labels: int,
+    task_name: str,
+    eval_splits: List[str],
+    booster: Booster,
+    coordinator: DistCoordinator,
+):
     metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
     model.eval()

@@ -47,23 +62,66 @@ def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[Dat
         accum_loss = torch.zeros(1, device=get_current_device())
         for batch in dataloader:
             batch = move_to_cuda(batch)
-            outputs = model(**batch)
-            val_loss, logits = outputs[:2]
-            accum_loss.add_(val_loss)
-
-            if num_labels > 1:
-                preds = torch.argmax(logits, axis=1)
-            elif num_labels == 1:
-                preds = logits.squeeze()
-
             labels = batch["labels"]
-
-            metric.add_batch(predictions=preds, references=labels)
+            batch_size = batch["input_ids"].shape[0]
+            if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
+                pg_mesh = booster.plugin.pg_mesh
+                pp_group = booster.plugin.pp_group
+                current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
+                current_rank = dist.get_rank()
+                #TODO pass dataloader to execute_pipeline directly
+                batch = iter([batch])
+                outputs = booster.execute_pipeline(batch,
+                                                   model,
+                                                   criterion,
+                                                   optimizer,
+                                                   return_loss=True,
+                                                   return_outputs=True)
+
+                if booster.plugin.stage_manager.is_last_stage():
+                    val_loss = outputs["loss"]
+
+                    logits = outputs["outputs"]["logits"]
+
+                    accum_loss.add_(val_loss)
+
+                    if num_labels > 1:
+                        preds = torch.argmax(logits, axis=1)
+                    elif num_labels == 1:
+                        preds = logits.squeeze()
+
+                    dist.broadcast(preds, src=current_rank, group=pp_group)
+                    dist.broadcast(val_loss, src=current_rank, group=pp_group)
+
+                    metric.add_batch(predictions=preds, references=labels)
+                elif current_rank in current_pp_group_ranks:
+                    val_loss = torch.empty((1,), device=get_current_device())
+                    preds = torch.empty((batch_size,), dtype=torch.int64, device=get_current_device())
+
+                    dist.broadcast(preds, src=current_pp_group_ranks[-1], group=pp_group)
+                    dist.broadcast(val_loss, src=current_pp_group_ranks[-1], group=pp_group)
+
+                    accum_loss.add_(val_loss)
+                    metric.add_batch(predictions=preds, references=labels)
+
+            else:
+                batch = move_to_cuda(batch)
+                outputs = model(**batch)
+                val_loss, logits = outputs[:2]
+                accum_loss.add_(val_loss)
+
+                if num_labels > 1:
+                    preds = torch.argmax(logits, axis=1)
+                elif num_labels == 1:
+                    preds = logits.squeeze()
+
+                metric.add_batch(predictions=preds, references=labels)

         results = metric.compute()
         dist.all_reduce(accum_loss.div_(len(dataloader)))
-        if coordinator.is_master():
+        if coordinator.is_master() and results is not None:
             results['loss'] = accum_loss.item() / coordinator.world_size
+
         return results

     if isinstance(test_dataloader, DataLoader):
@@ -77,25 +135,43 @@ def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[Dat
         return final_results


-def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader,
-                booster: Booster, coordinator: DistCoordinator):
+def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
+                train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
+
     model.train()
-    with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar:
+    is_pp_last_stage = hasattr(
+        booster.plugin,
+        "stage_manager") and booster.plugin.stage_manager is not None and booster.plugin.stage_manager.is_last_stage()
+    with tqdm(train_dataloader,
+              desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
+              disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
         for batch in pbar:
             # Forward pass
             batch = move_to_cuda(batch)
-            outputs = model(**batch)
-            loss = outputs[0]
+            if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
+                #TODO pass train_dataloader to execute_pipeline directly
+                batch = iter([batch])
+                outputs = booster.execute_pipeline(batch,
+                                                   model,
+                                                   _criterion,
+                                                   optimizer,
+                                                   return_loss=True,
+                                                   return_outputs=True)
+                # Backward and optimize
+                if booster.plugin.stage_manager.is_last_stage():
+                    loss = outputs['loss']
+                    pbar.set_postfix({'loss': loss.item()})
+            else:
+                outputs = model(**batch)
+                loss = _criterion(outputs, None)
+                # Backward
+                booster.backward(loss, optimizer)
+                pbar.set_postfix({'loss': loss.item()})

-            # Backward and optimize
-            booster.backward(loss, optimizer)
             optimizer.step()
             optimizer.zero_grad()
             lr_scheduler.step()

-            # Print log info
-            pbar.set_postfix({'loss': loss.item()})
-

 def main():
     # ==============================
@@ -107,7 +183,7 @@ def main():
         '--plugin',
         type=str,
         default='torch_ddp',
-        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'],
         help="plugin to use")
     parser.add_argument(
         "--model_type",
@@ -116,6 +192,7 @@ def main():
         help="bert or albert",
     )
     parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
+    parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context")
     args = parser.parse_args()

     if args.model_type == 'bert':
@@ -145,6 +222,17 @@ def main():
         plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
     elif args.plugin == 'low_level_zero':
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
+    elif args.plugin == 'hybrid_parallel':
+
+        # modify the param accordingly for finetuning test cases
+        plugin = HybridParallelPlugin(tp_size=1,
+                                      pp_size=2,
+                                      num_microbatches=None,
+                                      microbatch_size=1,
+                                      enable_all_optimization=True,
+                                      zero_stage=1,
+                                      precision='fp16',
+                                      initial_scale=1)

     booster = Booster(plugin=plugin, **booster_kwargs)

@@ -165,8 +253,9 @@ def main():

     # bert pretrained model
     cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
+
     if model_name == "bert-base-uncased":
-        model = BertForSequenceClassification.from_pretrained(model_name, config=cfg)
+        model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
     elif model_name == "albert-xxlarge-v2":
         model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
     else:
@@ -196,19 +285,27 @@ def main():
         num_training_steps=total_steps,
     )

+    def _criterion(outputs, inputs):
+        outputs = output_transform_fn(outputs)
+        loss = criterion(outputs)
+        return loss
+
     # ==============================
     # Boost with ColossalAI
     # ==============================
-    model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)
+    model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
+                                                                  optimizer,
+                                                                  criterion=_criterion,
+                                                                  lr_scheduler=lr_scheduler)

     # ==============================
     # Train model
     # ==============================
     for epoch in range(NUM_EPOCHS):
-        train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
+        train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)

-    results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
-                             coordinator)
+    results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task,
+                             data_builder.eval_splits, booster, coordinator)

     if coordinator.is_master():
         print(results)
diff --git a/examples/language/bert/test_ci.sh b/examples/language/bert/test_ci.sh
index 7fc6daabb..394ff831b 100755
--- a/examples/language/bert/test_ci.sh
+++ b/examples/language/bert/test_ci.sh
@@ -3,6 +3,6 @@ set -xe

 pip install -r requirements.txt

-for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
+for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do
     torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert"
 done
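
Usage note (illustrative, not part of the diff): a minimal sketch of how the hybrid_parallel path added above is driven end to end under torchrun with 2 ranks. The plugin arguments mirror the new branch in finetune.py; train_dataloader is a placeholder for the GLUE DataLoader that GLUEDataBuilder builds in the full example, and the launch/step flow is condensed from the patched train_epoch.

    # Minimal sketch (assumptions: launched via torchrun with 2 ranks; train_dataloader
    # is a placeholder for the GLUE DataLoader built by GLUEDataBuilder in the example).
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin
    from colossalai.nn.optimizer import HybridAdam
    from transformers import BertForSequenceClassification

    colossalai.launch_from_torch(config={})

    # With pp_size > 1, only one of num_microbatches / microbatch_size has to be set,
    # matching the relaxed assert in HybridParallelPlugin above.
    plugin = HybridParallelPlugin(tp_size=1,
                                  pp_size=2,
                                  num_microbatches=None,
                                  microbatch_size=1,
                                  zero_stage=1,
                                  precision='fp16',
                                  initial_scale=1)
    booster = Booster(plugin=plugin)

    model = BertForSequenceClassification.from_pretrained("bert-base-uncased").cuda()
    optimizer = HybridAdam(model.parameters(), lr=2.4e-5)
    criterion = lambda outputs, inputs: outputs.loss    # same shape as _criterion in finetune.py

    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)

    for batch in train_dataloader:    # placeholder: batches of tokenized GLUE samples moved to CUDA
        # forward + backward both run inside execute_pipeline when pipeline parallelism is on
        outputs = booster.execute_pipeline(iter([batch]),
                                           model,
                                           criterion,
                                           optimizer,
                                           return_loss=True,
                                           return_outputs=True)
        if booster.plugin.stage_manager.is_last_stage():
            print(outputs['loss'].item())
        optimizer.step()
        optimizer.zero_grad()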