[shardformer] write an shardformer example with bert finetuning (#4126)

* [shardformer] add benchmark of shardformer

* [shardformer] add benchmark of shardformer
pull/4157/head
jiangmingyan 2023-06-30 16:48:29 +08:00 committed by Frank Lee
parent ae035d305d
commit 7f9b30335b
5 changed files with 323 additions and 1 deletions

View File

@ -15,6 +15,7 @@
- [Policy](#policy)
- [Model Sharder](#model-sharder)
- [User-facing API](#user-facing-api)
- [Shardformer Convergence](#shardformer-convergence)
## 🔗 Introduction
@ -324,3 +325,15 @@ class ShardFormer:
"""
...
```
### Shardformer Convergence
To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/shardformer_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results.
| accuracy | f1 | loss | GPU number | model shard |
| :-----: | :----: | :----: | :----: | :----: |
| 0.82594 | 0.87441 | 0.09913 | 4 | True |
| 0.81884 | 0.87299 | 0.10120 | 2 | True |
| 0.81855 | 0.87124 | 0.10357 | 1 | False |
Overall, the results demonstrate that using shardformers during model training does not affect the convergence.

View File

@ -0,0 +1,146 @@
import datasets
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, PreTrainedTokenizer
from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
class GLUEDataBuilder:
task_text_field_map = {
"cola": ["sentence"],
"sst2": ["sentence"],
"mrpc": ["sentence1", "sentence2"],
"qqp": ["question1", "question2"],
"stsb": ["sentence1", "sentence2"],
"mnli": ["premise", "hypothesis"],
"qnli": ["question", "sentence"],
"rte": ["sentence1", "sentence2"],
"wnli": ["sentence1", "sentence2"],
"ax": ["premise", "hypothesis"],
}
glue_task_num_labels = {
"cola": 2,
"sst2": 2,
"mrpc": 2,
"qqp": 2,
"stsb": 1,
"mnli": 3,
"qnli": 2,
"rte": 2,
"wnli": 2,
"ax": 3,
}
loader_columns = [
"datasets_idx",
"input_ids",
"token_type_ids",
"attention_mask",
"start_positions",
"end_positions",
"labels",
]
def __init__(
self,
model_name_or_path: str,
plugin: DPPluginBase = None,
task_name: str = "mrpc",
max_seq_length: int = 128,
train_batch_size: int = 32,
eval_batch_size: int = 32,
**kwargs,
):
super().__init__()
self.model_name_or_path = model_name_or_path
self.task_name = task_name
self.max_seq_length = max_seq_length
self.train_batch_size = train_batch_size
self.eval_batch_size = eval_batch_size
self.plugin = plugin
self.text_fields = self.task_text_field_map[task_name]
self.num_labels = self.glue_task_num_labels[task_name]
self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
self.setup()
def setup(self):
self.dataset = datasets.load_dataset("glue", self.task_name)
for split in self.dataset.keys():
self.dataset[split] = self.dataset[split].map(
self.convert_to_features,
batched=True,
remove_columns=["label"],
)
self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
self.dataset[split].set_format(type="torch", columns=self.columns)
self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]
def prepare_data(self):
datasets.load_dataset("glue", self.task_name)
AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
def train_dataloader(self):
if self.plugin == None:
return self.native_prepare_dataloader(self.dataset["train"],
batch_size=self.train_batch_size,
shuffle=True,
drop_last=True)
return self.plugin.prepare_dataloader(self.dataset["train"],
batch_size=self.train_batch_size,
shuffle=True,
drop_last=True)
def val_dataloader(self):
if self.plugin == None:
return self.native_prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
if len(self.eval_splits) == 1:
return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
elif len(self.eval_splits) > 1:
return [
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
for x in self.eval_splits
]
def test_dataloader(self):
if self.plugin == None:
return self.native_prepare_dataloader(self.dataset["test"], batch_size=self.train_batch_size)
if len(self.eval_splits) == 1:
return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
elif len(self.eval_splits) > 1:
return [
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
for x in self.eval_splits
]
def convert_to_features(self, example_batch):
# Either encode single sentence or sentence pairs
if len(self.text_fields) > 1:
texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
else:
texts_or_text_pairs = example_batch[self.text_fields[0]]
# Tokenize the text/text pairs
features = self.tokenizer.batch_encode_plus(texts_or_text_pairs,
max_length=self.max_seq_length,
padding='max_length',
truncation=True)
# Rename label to labels to make it easier to pass to model forward
features["labels"] = example_batch["label"]
return features
def native_prepare_dataloader(self, dataset, batch_size, shuffle=False, drop_last=False, pin_memory=False):
return DataLoader(dataset,
batch_size=batch_size,
sampler=None,
shuffle=shuffle,
drop_last=drop_last,
pin_memory=pin_memory)

View File

@ -0,0 +1,154 @@
import argparse
import math
from typing import Any, List, Union
import evaluate
import torch
import torch.distributed as dist
from data import GLUEDataBuilder
from torch import nn
from torch.optim import Adam, AdamW, Optimizer
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertConfig, BertForSequenceClassification, get_linear_schedule_with_warmup
import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer import ShardConfig, ShardFormer
def to_device(x: Any, device: torch.device) -> Any:
def _to(t: Any):
if isinstance(t, torch.Tensor):
return t.to(device)
return t
return tree_map(_to, x)
def train(args):
colossalai.launch_from_torch(config={}, seed=42)
coordinator = DistCoordinator()
# prepare for data and dataset
data_builder = GLUEDataBuilder(model_name_or_path=args.pretrain,
task_name=args.task,
train_batch_size=args.batch_size,
eval_batch_size=args.batch_size)
train_dataloader = data_builder.train_dataloader()
test_dataloader = data_builder.test_dataloader()
if args.model == "bert":
cfg = BertConfig.from_pretrained(args.pretrain, num_labels=data_builder.num_labels)
model = BertForSequenceClassification.from_pretrained(args.pretrain, config=cfg)
model.to(torch.cuda.current_device())
# if multiple GPUs, shard the model
if dist.get_world_size() > 1:
shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm)
shard_former = ShardFormer(shard_config=shard_config)
model = shard_former.shard_model(model)
optim = Adam(model.parameters(), lr=args.lr)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
lr_scheduler = get_linear_schedule_with_warmup(
optim,
num_warmup_steps=math.ceil(max_steps * args.warmup_fraction),
num_training_steps=max_steps,
)
fit(model, optim, lr_scheduler, train_dataloader, args.max_epochs, args.accumulation_steps, args.batch_size,
coordinator)
results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
coordinator)
if coordinator.is_master():
print(results)
if args.target_f1 is not None and 'f1' in results:
assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}'
def fit(model: nn.Module, optimizer: Optimizer, scheduler, train_dataloader, max_epochs, accumulation_steps, batch_size,
coordinator):
step_bar = tqdm(range(len(train_dataloader) // accumulation_steps * max_epochs),
desc=f'steps',
disable=not coordinator.is_master())
total_loss = 0
for epoch in range(max_epochs):
model.train()
for batch_id, batch in enumerate(train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
outputs = model(**batch)
loss = outputs.loss
loss = loss / accumulation_steps
loss.backward()
total_loss += loss.item()
if (batch_id + 1) % accumulation_steps == 0:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
step_bar.set_postfix({
'epoch': epoch,
'loss': total_loss / batch_size,
'lr': scheduler.get_last_lr()[0]
})
total_loss = 0
step_bar.update()
# evaluate
@torch.no_grad()
def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int,
task_name: str, eval_splits: List[str], coordinator: DistCoordinator):
metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
model.eval()
def evaluate_subset(dataloader: DataLoader):
accum_loss = torch.zeros(1, device=torch.cuda.current_device())
for batch in dataloader:
batch = to_device(batch, torch.cuda.current_device())
outputs = model(**batch)
val_loss, logits = outputs[:2]
accum_loss.add_(val_loss)
if num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif num_labels == 1:
preds = logits.squeeze()
labels = batch["labels"]
metric.add_batch(predictions=preds, references=labels)
results = metric.compute()
if coordinator.is_master():
results['loss'] = accum_loss.item() / (len(dataloader) * dataloader.batch_size)
return results
if isinstance(test_dataloader, DataLoader):
return evaluate_subset(test_dataloader)
else:
assert len(test_dataloader) == len(eval_splits)
final_results = {}
for split, sub_loader in zip(eval_splits, test_dataloader):
results = evaluate_subset(sub_loader)
final_results.update({f'{k}_{split}': v for k, v in results.items()})
return final_results
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run")
parser.add_argument('--model', type=str, default="bert")
parser.add_argument('--pretrain', type=str, default="bert-base-uncased")
parser.add_argument('--max_epochs', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--lr', type=float, default=2.4e-5)
parser.add_argument('--fused_layernorm', type=bool, default=False)
parser.add_argument('--accumulation_steps', type=int, default=8)
parser.add_argument('--warmup_fraction', type=float, default=0.03)
parser.add_argument('--target_f1', type=float, default=None)
args = parser.parse_args()
train(args)

View File

@ -0,0 +1,9 @@
torchrun --standalone --nproc_per_node=4 shardformer_benchmark.py \
--model "bert" \
--pretrain "bert-base-uncased" \
--max_epochs 1 \
--batch_size 2 \
--lr 2.4e-5 \
--fused_layernorm False \
--accumulation_steps 8 \
--warmup_fraction 0.03

View File

@ -17,7 +17,7 @@ class ShardConfig:
tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
enable_fused_normalization (bool): Whether to use fused layernorm, default is False
"""
tensor_parallel_process_group: int = None
tensor_parallel_process_group: ProcessGroup = None
enable_fused_normalization: bool = False
enable_all_optimization: bool = False