[shardformer] update bert finetune example with HybridParallelPlugin (#4584)

* [shardformer] fix opt test hanging

* fix

* test

* test

* test

* fix test

* fix test

* remove print

* add fix

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] fix epoch change

* [shardformer] broadcast add pp group

* [shardformer] fix opt test hanging

* fix

* test

* test

* [shardformer] zero1+pp and the corresponding tests (#4517)

* pause

* finish pp+zero1

* Update test_shard_vit.py

* [shardformer/fix overlap bug] fix overlap bug, add overlap as an option in shardco… (#4516)

* fix overlap bug and support bert, add overlap as an option in shardconfig

* support overlap for chatglm and bloom

* [shardformer] fix emerged bugs after updating transformers (#4526)

* test

* fix test

* fix test

* remove print

* add fix

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] Add overlap support for gpt2 (#4535)

* add overlap support for gpt2

* remove unused code

* remove unused code

* [shardformer] support pp+tp+zero1 tests (#4531)

* [shardformer] fix opt test hanging

* fix

* test

* test

* test

* fix test

* fix test

* remove print

* add fix

* [shardformer] pp+tp+zero1

[shardformer] pp+tp+zero1

[shardformer] pp+tp+zero1

[shardformer] pp+tp+zero1

[shardformer] pp+tp+zero1

[shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1

* [shardformer] pp+tp+zero1

* [shardformer] fix submodule replacement bug when enabling pp (#4544)

* [shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540)

* implement sharded optimizer saving

* add more param info

* finish implementation of sharded optimizer saving

* fix bugs in optimizer sharded saving

* add pp+zero test

* param group loading

* greedy loading of optimizer

* fix bug when loading

* implement optimizer sharded saving

* add optimizer test & arrange checkpointIO utils

* fix gemini sharding state_dict

* add verbose option

* add loading of master params

* fix typehint

* fix master/working mapping in fp16 amp

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] add bert finetune example

* [shardformer] fix epoch change

* [shardformer] broadcast add pp group

* rebase feature/shardformer

* update pipeline

* [shardformer] fix

* [shardformer] fix

* [shardformer] bert finetune fix

* [shardformer] add all_reduce operation to loss

add all_reduce operation to loss

* [shardformer] make compatible with pytree.

make compatible with pytree.

* [shardformer] disable tp

disable tp

* [shardformer] add 3d plugin to ci test

* [shardformer] update num_microbatches to None

* [shardformer] update microbatchsize

* [shardformer] update assert

* update scheduler

* update scheduler

---------

Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn>

@@ -325,7 +325,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.schedule = None
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
-            assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
+            assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
             assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
             self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
             self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
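
For context, the relaxed assertion above means pipeline parallelism can now be configured with either knob. A minimal sketch of the two styles, assuming only the constructor arguments that actually appear in these hunks (tp_size, pp_size, num_microbatches, microbatch_size):

    from colossalai.booster.plugin import HybridParallelPlugin

    # Style 1: fix the number of microbatches; the schedule derives the microbatch size per batch.
    plugin_fixed_count = HybridParallelPlugin(tp_size=1, pp_size=2, num_microbatches=4)

    # Style 2: fix the microbatch size; pass num_microbatches=None and let the schedule derive the count.
    plugin_fixed_size = HybridParallelPlugin(tp_size=1, pp_size=2, num_microbatches=None, microbatch_size=1)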

@@ -46,6 +46,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         self.batch: Optional[Any] = None
         self.batch_size: Optional[int] = None
         self.microbatch_offset: Optional[int] = None
+        self._use_microbatch_size = num_microbatches is None

     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -60,7 +61,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         self.batch = batch
         self.batch_size = get_batch_size(batch)
         self.microbatch_offset = 0
-        if self.num_microbatches is not None:
+        if not self._use_microbatch_size:
             assert self.batch_size % self.num_microbatches == 0, \
                 "Batch size should divided by the number of microbatches"
             self.microbatch_size = self.batch_size // self.num_microbatches
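
The hunk above shows the num_microbatches branch; the complementary branch (when num_microbatches is None and only microbatch_size is set) is not visible here, but presumably divides the other way. A sketch of the per-batch arithmetic under that assumption, for a batch of 32 samples:

    batch_size = 32

    # num_microbatches given: the microbatch size is derived (batch must divide evenly)
    num_microbatches = 4
    microbatch_size = batch_size // num_microbatches    # -> 8

    # only microbatch_size given: the number of microbatches is derived instead (assumed)
    microbatch_size = 1
    num_microbatches = batch_size // microbatch_size    # -> 32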

@@ -1,12 +1,14 @@
 import argparse
-from typing import List, Union
+from contextlib import nullcontext
+from typing import Callable, List, Union

 import evaluate
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from data import GLUEDataBuilder
-from torch.optim import Optimizer
+from torch.optim import Adam, Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import (
@@ -18,8 +20,9 @@ from transformers import (
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
@@ -32,14 +35,26 @@ LEARNING_RATE = 2.4e-5
 WEIGHT_DECAY = 0.01
 WARMUP_FRACTION = 0.1

 output_transform_fn = lambda x: x
+criterion = lambda x: x.loss


 def move_to_cuda(batch):
     return {k: v.cuda() for k, v in batch.items()}


 @torch.no_grad()
-def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str,
-                   eval_splits: List[str], coordinator: DistCoordinator):
+def evaluate_model(
+    model: nn.Module,
+    optimizer,
+    criterion,
+    test_dataloader: Union[DataLoader, List[DataLoader]],
+    num_labels: int,
+    task_name: str,
+    eval_splits: List[str],
+    booster: Booster,
+    coordinator: DistCoordinator,
+):
     metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
     model.eval()
@@ -47,23 +62,66 @@ def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[Dat
         accum_loss = torch.zeros(1, device=get_current_device())
         for batch in dataloader:
             batch = move_to_cuda(batch)
-            outputs = model(**batch)
-            val_loss, logits = outputs[:2]
-            accum_loss.add_(val_loss)
-
-            if num_labels > 1:
-                preds = torch.argmax(logits, axis=1)
-            elif num_labels == 1:
-                preds = logits.squeeze()
-
             labels = batch["labels"]
-
-            metric.add_batch(predictions=preds, references=labels)
+            batch_size = batch["input_ids"].shape[0]
+            if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
+                pg_mesh = booster.plugin.pg_mesh
+                pp_group = booster.plugin.pp_group
+                current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
+                current_rank = dist.get_rank()
+                #TODO pass dataloader to execute_pipeline directly
+                batch = iter([batch])
+                outputs = booster.execute_pipeline(batch,
+                                                   model,
+                                                   criterion,
+                                                   optimizer,
+                                                   return_loss=True,
+                                                   return_outputs=True)
+                if booster.plugin.stage_manager.is_last_stage():
+                    val_loss = outputs["loss"]
+                    logits = outputs["outputs"]["logits"]
+                    accum_loss.add_(val_loss)
+
+                    if num_labels > 1:
+                        preds = torch.argmax(logits, axis=1)
+                    elif num_labels == 1:
+                        preds = logits.squeeze()
+
+                    dist.broadcast(preds, src=current_rank, group=pp_group)
+                    dist.broadcast(val_loss, src=current_rank, group=pp_group)
+
+                    metric.add_batch(predictions=preds, references=labels)
+                elif current_rank in current_pp_group_ranks:
+                    val_loss = torch.empty((1,), device=get_current_device())
+                    preds = torch.empty((batch_size,), dtype=torch.int64, device=get_current_device())
+
+                    dist.broadcast(preds, src=current_pp_group_ranks[-1], group=pp_group)
+                    dist.broadcast(val_loss, src=current_pp_group_ranks[-1], group=pp_group)
+
+                    accum_loss.add_(val_loss)
+                    metric.add_batch(predictions=preds, references=labels)
+            else:
+                batch = move_to_cuda(batch)
+                outputs = model(**batch)
+                val_loss, logits = outputs[:2]
+                accum_loss.add_(val_loss)
+
+                if num_labels > 1:
+                    preds = torch.argmax(logits, axis=1)
+                elif num_labels == 1:
+                    preds = logits.squeeze()
+
+                metric.add_batch(predictions=preds, references=labels)

         results = metric.compute()
         dist.all_reduce(accum_loss.div_(len(dataloader)))
-        if coordinator.is_master():
+        if coordinator.is_master() and results is not None:
             results['loss'] = accum_loss.item() / coordinator.world_size
         return results

     if isinstance(test_dataloader, DataLoader):
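
The broadcast pair above exists because, under pipeline parallelism, only the last stage produces the logits and the loss; it broadcasts them over pp_group, while every other rank in that group pre-allocates empty buffers of matching shape and receives them, so all ranks can feed metric.add_batch identical data. The two branches pass different src arguments, but both resolve to the last rank of the pipeline group under the example's assumption that stages are ordered within pp_group. A standalone sketch of that pattern with a hypothetical helper, not code from the example:

    import torch
    import torch.distributed as dist

    def sync_eval_outputs(preds: torch.Tensor, loss: torch.Tensor, pp_group, pp_ranks):
        # On the last pipeline stage, preds/loss hold real values; on every other rank
        # they are pre-allocated empty buffers. dist.broadcast fills the buffers in place
        # from the last rank of the pipeline group, so all ranks end up with the same data.
        dist.broadcast(preds, src=pp_ranks[-1], group=pp_group)
        dist.broadcast(loss, src=pp_ranks[-1], group=pp_group)
        return preds, loss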
@@ -77,25 +135,43 @@ def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[Dat
     return final_results


-def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader,
-                booster: Booster, coordinator: DistCoordinator):
+def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
+                train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):

     model.train()
-    with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar:
+    is_pp_last_stage = hasattr(
+        booster.plugin,
+        "stage_manager") and booster.plugin.stage_manager is not None and booster.plugin.stage_manager.is_last_stage()
+    with tqdm(train_dataloader,
+              desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
+              disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
         for batch in pbar:
             # Forward pass
             batch = move_to_cuda(batch)
-            outputs = model(**batch)
-            loss = outputs[0]
+            if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
+                #TODO pass train_dataloader to execute_pipeline directly
+                batch = iter([batch])
+                outputs = booster.execute_pipeline(batch,
+                                                   model,
+                                                   _criterion,
+                                                   optimizer,
+                                                   return_loss=True,
+                                                   return_outputs=True)
+                # Backward and optimize
+                if booster.plugin.stage_manager.is_last_stage():
+                    loss = outputs['loss']
+                    pbar.set_postfix({'loss': loss.item()})
+            else:
+                outputs = model(**batch)
+                loss = _criterion(outputs, None)
+                # Backward
+                booster.backward(loss, optimizer)
+                pbar.set_postfix({'loss': loss.item()})

-            # Backward and optimize
-            booster.backward(loss, optimizer)
             optimizer.step()
             optimizer.zero_grad()
             lr_scheduler.step()

-            # Print log info
-            pbar.set_postfix({'loss': loss.item()})


 def main():
     # ==============================
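
Two details in the train_epoch hunk are easy to miss. First, booster.execute_pipeline consumes a data iterator rather than a single batch (hence batch = iter([batch]) and the TODO about passing the dataloader directly). Second, the 1F1B schedule runs backward inside execute_pipeline, so the pipeline branch never calls booster.backward; only the optimizer and scheduler updates remain. A condensed sketch of that per-step flow, mirroring the names above but not a drop-in replacement for the example:

    def pipeline_train_step(batch, model, _criterion, optimizer, lr_scheduler, booster):
        outputs = booster.execute_pipeline(iter([batch]),    # schedule pulls microbatches from this iterator
                                           model,
                                           _criterion,
                                           optimizer,
                                           return_loss=True,
                                           return_outputs=True)
        # forward AND backward already ran inside the schedule; finish the step
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()
        return outputs.get('loss')    # populated only on the last pipeline stage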
@@ -107,7 +183,7 @@ def main():
         '--plugin',
         type=str,
         default='torch_ddp',
-        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'],
         help="plugin to use")
     parser.add_argument(
         "--model_type",
@@ -116,6 +192,7 @@ def main():
         help="bert or albert",
     )
     parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
+    parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context")
     args = parser.parse_args()

     if args.model_type == 'bert':
@@ -145,6 +222,17 @@ def main():
         plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
     elif args.plugin == 'low_level_zero':
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
+    elif args.plugin == 'hybrid_parallel':
+        # modify the param accordingly for finetuning test cases
+        plugin = HybridParallelPlugin(tp_size=1,
+                                      pp_size=2,
+                                      num_microbatches=None,
+                                      microbatch_size=1,
+                                      enable_all_optimization=True,
+                                      zero_stage=1,
+                                      precision='fp16',
+                                      initial_scale=1)

     booster = Booster(plugin=plugin, **booster_kwargs)
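
For the 4-GPU CI run in the test script at the bottom of this diff, this configuration implies: tp_size=1 and pp_size=2 leave a data-parallel degree of 2 (assuming the plugin factors the world size as dp x pp x tp), zero_stage=1 shards optimizer states across those 2 data-parallel replicas, and fp16 training starts with a loss scale of 1. A small sketch of the implied decomposition under that assumption:

    def implied_dp_size(world_size: int, tp_size: int = 1, pp_size: int = 2) -> int:
        # assumption: the plugin requires world_size to factor into dp * pp * tp
        assert world_size % (tp_size * pp_size) == 0, "world size must be divisible by tp_size * pp_size"
        return world_size // (tp_size * pp_size)

    # e.g. implied_dp_size(4) == 2 for the `torchrun --nproc_per_node 4` test script below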
@@ -165,8 +253,9 @@ def main():
     # bert pretrained model
     cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)

     if model_name == "bert-base-uncased":
-        model = BertForSequenceClassification.from_pretrained(model_name, config=cfg)
+        model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
     elif model_name == "albert-xxlarge-v2":
         model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
     else:
@@ -196,19 +285,27 @@ def main():
         num_training_steps=total_steps,
     )

+    def _criterion(outputs, inputs):
+        outputs = output_transform_fn(outputs)
+        loss = criterion(outputs)
+        return loss
+
     # ==============================
     # Boost with ColossalAI
     # ==============================
-    model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)
+    model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
+                                                                  optimizer,
+                                                                  criterion=_criterion,
+                                                                  lr_scheduler=lr_scheduler)

     # ==============================
     # Train model
     # ==============================
     for epoch in range(NUM_EPOCHS):
-        train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
+        train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)

-    results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
-                             coordinator)
+    results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task,
+                             data_builder.eval_splits, booster, coordinator)

     if coordinator.is_master():
         print(results)
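
The _criterion indirection is needed because booster.boost and execute_pipeline invoke the criterion as criterion(outputs, inputs), while BertForSequenceClassification already computes its own loss when labels are in the batch, so the wrapper ignores inputs and just extracts outputs.loss. A self-contained sketch of what it reduces to, using a stand-in for the transformers output object:

    from dataclasses import dataclass
    import torch

    @dataclass
    class FakeSeqClsOutput:       # stand-in for transformers' SequenceClassifierOutput
        loss: torch.Tensor
        logits: torch.Tensor

    output_transform_fn = lambda x: x
    criterion = lambda x: x.loss

    def _criterion(outputs, inputs):                # same shape as the wrapper in the hunk above
        outputs = output_transform_fn(outputs)      # identity hook, kept for symmetry
        return criterion(outputs)                   # the model's own loss; `inputs` is intentionally unused

    loss = _criterion(FakeSeqClsOutput(loss=torch.tensor(0.7), logits=torch.zeros(2, 2)), None)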

@@ -3,6 +3,6 @@ set -xe
 pip install -r requirements.txt

-for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
+for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do
     torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert"
 done
