[example] add finetune bert with booster example (#3693)

pull/3695/head
Hongxin Liu 2023-05-06 11:53:13 +08:00 committed by GitHub
parent 65bdc3159f
commit d556648885
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 364 additions and 0 deletions

View File

@ -0,0 +1,33 @@
# Finetune BERT on GLUE
## 🚀 Quick Start
This example provides a training script, which provides an example of finetuning BERT on GLUE dataset.
- Training Arguments
- `-t`, `--task`: GLUE task to run. Defaults to `mrpc`.
- `-p`, `--plugin`: Plugin to use. Choices: `torch_ddp`, `torch_ddp_fp16`, `gemini`, `low_level_zero`. Defaults to `torch_ddp`.
- `--target_f1`: Target f1 score. Raise exception if not reached. Defaults to `None`.
### Train
```bash
# train with torch DDP with fp32
colossalai run --nproc_per_node 4 finetune.py
# train with torch DDP with mixed precision training
colossalai run --nproc_per_node 4 finetune.py -p torch_ddp_fp16
# train with gemini
colossalai run --nproc_per_node 4 finetune.py -p gemini
# train with low level zero
colossalai run --nproc_per_node 4 finetune.py -p low_level_zero
```
Expected F1-score will be:
| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Gemini | Booster Low Level Zero |
| ----------------- | ------------------------ | --------------------- | --------------------- |--------------- | ---------------------- |
| bert-base-uncased | 0.86 | 0.88 | 0.87 | 0.88 | 0.89 |

View File

@ -0,0 +1,127 @@
import datasets
from transformers import AutoTokenizer, PreTrainedTokenizer
from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
class GLUEDataBuilder:
task_text_field_map = {
"cola": ["sentence"],
"sst2": ["sentence"],
"mrpc": ["sentence1", "sentence2"],
"qqp": ["question1", "question2"],
"stsb": ["sentence1", "sentence2"],
"mnli": ["premise", "hypothesis"],
"qnli": ["question", "sentence"],
"rte": ["sentence1", "sentence2"],
"wnli": ["sentence1", "sentence2"],
"ax": ["premise", "hypothesis"],
}
glue_task_num_labels = {
"cola": 2,
"sst2": 2,
"mrpc": 2,
"qqp": 2,
"stsb": 1,
"mnli": 3,
"qnli": 2,
"rte": 2,
"wnli": 2,
"ax": 3,
}
loader_columns = [
"datasets_idx",
"input_ids",
"token_type_ids",
"attention_mask",
"start_positions",
"end_positions",
"labels",
]
def __init__(
self,
model_name_or_path: str,
plugin: DPPluginBase,
task_name: str = "mrpc",
max_seq_length: int = 128,
train_batch_size: int = 32,
eval_batch_size: int = 32,
**kwargs,
):
super().__init__()
self.model_name_or_path = model_name_or_path
self.task_name = task_name
self.max_seq_length = max_seq_length
self.train_batch_size = train_batch_size
self.eval_batch_size = eval_batch_size
self.plugin = plugin
self.text_fields = self.task_text_field_map[task_name]
self.num_labels = self.glue_task_num_labels[task_name]
self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
self.setup()
def setup(self):
self.dataset = datasets.load_dataset("glue", self.task_name)
for split in self.dataset.keys():
self.dataset[split] = self.dataset[split].map(
self.convert_to_features,
batched=True,
remove_columns=["label"],
)
self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
self.dataset[split].set_format(type="torch", columns=self.columns)
self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]
def prepare_data(self):
datasets.load_dataset("glue", self.task_name)
AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
def train_dataloader(self):
return self.plugin.prepare_train_dataloader(self.dataset["train"],
batch_size=self.train_batch_size,
shuffle=True,
drop_last=True)
def val_dataloader(self):
if len(self.eval_splits) == 1:
return self.plugin.prepare_train_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
elif len(self.eval_splits) > 1:
return [
self.plugin.prepare_train_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
for x in self.eval_splits
]
def test_dataloader(self):
if len(self.eval_splits) == 1:
return self.plugin.prepare_train_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
elif len(self.eval_splits) > 1:
return [
self.plugin.prepare_train_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
for x in self.eval_splits
]
def convert_to_features(self, example_batch):
# Either encode single sentence or sentence pairs
if len(self.text_fields) > 1:
texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
else:
texts_or_text_pairs = example_batch[self.text_fields[0]]
# Tokenize the text/text pairs
features = self.tokenizer.batch_encode_plus(texts_or_text_pairs,
max_length=self.max_seq_length,
padding='max_length',
truncation=True)
# Rename label to labels to make it easier to pass to model forward
features["labels"] = example_batch["label"]
return features

View File

@ -0,0 +1,198 @@
import argparse
from typing import List, Union
import datasets
import torch
import torch.distributed as dist
import torch.nn as nn
from data import GLUEDataBuilder
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, BertForSequenceClassification, get_linear_schedule_with_warmup
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
# ==============================
# Prepare Hyperparameters
# ==============================
NUM_EPOCHS = 3
BATCH_SIZE = 32
LEARNING_RATE = 2.4e-5
WEIGHT_DECAY = 0.01
WARMUP_FRACTION = 0.1
def move_to_cuda(batch):
return {k: v.cuda() for k, v in batch.items()}
@torch.no_grad()
def evaluate(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str,
eval_splits: List[str], coordinator: DistCoordinator):
metric = datasets.load_metric("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
model.eval()
def evaluate_subset(dataloader: DataLoader):
accum_loss = torch.zeros(1, device=get_current_device())
for batch in dataloader:
batch = move_to_cuda(batch)
outputs = model(**batch)
val_loss, logits = outputs[:2]
accum_loss.add_(val_loss)
if num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif num_labels == 1:
preds = logits.squeeze()
labels = batch["labels"]
metric.add_batch(predictions=preds, references=labels)
results = metric.compute()
dist.all_reduce(accum_loss.div_(len(dataloader)))
if coordinator.is_master():
results['loss'] = accum_loss.item() / coordinator.world_size
return results
if isinstance(test_dataloader, DataLoader):
return evaluate_subset(test_dataloader)
else:
assert len(test_dataloader) == len(eval_splits)
final_results = {}
for split, sub_loader in zip(eval_splits, test_dataloader):
results = evaluate_subset(sub_loader)
final_results.update({f'{k}_{split}': v for k, v in results.items()})
return final_results
def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader,
booster: Booster, coordinator: DistCoordinator):
model.train()
with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar:
for batch in pbar:
# Forward pass
batch = move_to_cuda(batch)
outputs = model(**batch)
loss = outputs[0]
# Backward and optimize
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
# Print log info
pbar.set_postfix({'loss': loss.item()})
def main():
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run")
parser.add_argument('-p',
'--plugin',
type=str,
default='torch_ddp',
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
help="plugin to use")
parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
args = parser.parse_args()
# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={}, seed=42)
coordinator = DistCoordinator()
# local_batch_size = BATCH_SIZE // coordinator.world_size
lr = LEARNING_RATE * coordinator.world_size
model_name = 'bert-base-uncased'
# ==============================
# Instantiate Plugin and Booster
# ==============================
booster_kwargs = {}
if args.plugin == 'torch_ddp_fp16':
booster_kwargs['mixed_precision'] = 'fp16'
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
booster = Booster(plugin=plugin, **booster_kwargs)
# ==============================
# Prepare Dataloader
# ==============================
data_builder = GLUEDataBuilder(model_name,
plugin,
args.task,
train_batch_size=BATCH_SIZE,
eval_batch_size=BATCH_SIZE)
train_dataloader = data_builder.train_dataloader()
test_dataloader = data_builder.test_dataloader()
# ====================================
# Prepare model, optimizer
# ====================================
# bert pretrained model
config = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
model = BertForSequenceClassification.from_pretrained(model_name, config=config)
# optimizer
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": WEIGHT_DECAY,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)
# lr scheduler
total_steps = len(train_dataloader) * NUM_EPOCHS
num_warmup_steps = int(WARMUP_FRACTION * total_steps)
lr_scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_steps,
)
# ==============================
# Boost with ColossalAI
# ==============================
model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)
# ==============================
# Train model
# ==============================
for epoch in range(NUM_EPOCHS):
train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
results = evaluate(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
coordinator)
if coordinator.is_master():
print(results)
if args.target_f1 is not None and 'f1' in results:
assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}'
if __name__ == '__main__':
main()

View File

@ -0,0 +1,6 @@
#!/bin/bash
set -xe
for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin
done