# ColossalAI/examples/language/bert/benchmark.py
#
# 173 lines
# 5.2 KiB
# Python

import argparse
import torch
from benchmark_utils import benchmark
from torch.utils.data import DataLoader, Dataset
from transformers import (
AlbertConfig,
AlbertForSequenceClassification,
BertConfig,
BertForSequenceClassification,
get_linear_schedule_with_warmup,
)
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
# ==============================
# Prepare Hyperparameters
# ==============================
# Training schedule and optimizer settings.
NUM_EPOCHS = 3
BATCH_SIZE = 32
LEARNING_RATE = 2.4e-5  # base value; main() multiplies it by the world size
WEIGHT_DECAY = 0.01
WARMUP_FRACTION = 0.1  # fraction of the total steps used as warmup steps

# Shape of the randomly generated benchmark data.
DATASET_LEN = 1000
SEQ_LEN = 512
VOCAB_SIZE = 1000
NUM_LABELS = 10
class RandintDataset(Dataset):
    """Synthetic classification dataset made of random integer tensors.

    Every sample is generated once in the constructor and kept in memory:
    a ``(dataset_length, sequence_length)`` tensor of token ids and a
    ``(dataset_length, 1)`` tensor of class labels, both ``torch.long``.

    Args:
        dataset_length: Number of samples to generate.
        sequence_length: Number of token ids per sample.
        vocab_size: Exclusive upper bound for the generated token ids.
        n_class: Exclusive upper bound for the generated labels.
    """

    def __init__(self, dataset_length: int, sequence_length: int, vocab_size: int, n_class: int):
        self._dataset_length = dataset_length
        self._sequence_length = sequence_length
        self._vocab_size = vocab_size
        self._n_class = n_class

        # Token ids are drawn from [0, vocab_size), one row per sample.
        token_shape = (dataset_length, sequence_length)
        self._datas = torch.randint(0, vocab_size, token_shape, dtype=torch.long)

        # One label in [0, n_class) per sample, stored as a column of width 1.
        label_shape = (dataset_length, 1)
        self._labels = torch.randint(0, n_class, label_shape, dtype=torch.long)

    def __len__(self):
        """Return the number of generated samples."""
        return self._dataset_length

    def __getitem__(self, idx):
        """Return the ``(token_ids, label)`` pair stored at ``idx``."""
        tokens = self._datas[idx]
        label = self._labels[idx]
        return tokens, label
def main():
    """Benchmark BERT/ALBERT sequence classification on random data.

    Parses the command line, launches the distributed environment, builds the
    selected booster plugin, a synthetic dataloader, the model, the optimizer
    and the LR scheduler, boosts them with ColossalAI, runs ``benchmark`` and
    reports its result through ``coordinator.print_on_master``.

    Raises:
        RuntimeError: If ``--model_type`` is neither ``bert`` nor ``albert``.
    """
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    # NOTE: --task is accepted for CLI compatibility but is never read below;
    # the benchmark always runs on randomly generated data.
    parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run")
    parser.add_argument(
        "-p",
        "--plugin",
        type=str,
        default="torch_ddp",
        choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"],
        help="plugin to use",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="bert",
        help="bert or albert",
    )
    args = parser.parse_args()

    # ==============================
    # Launch Distributed Environment
    # ==============================
    colossalai.launch_from_torch(config={}, seed=42)
    coordinator = DistCoordinator()

    # BATCH_SIZE is handed to the DataLoader unchanged (it is not divided by
    # the world size), and the learning rate is scaled linearly with the
    # number of processes.
    lr = LEARNING_RATE * coordinator.world_size

    # ==============================
    # Instantiate Plugin and Booster
    # ==============================
    booster_kwargs = {}
    if args.plugin == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"

    # Both "torch_ddp" and "torch_ddp_fp16" use TorchDDPPlugin; the fp16
    # variant differs only in the mixed_precision kwarg passed to Booster.
    if args.plugin.startswith("torch_ddp"):
        plugin = TorchDDPPlugin()
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5)
    elif args.plugin == "low_level_zero":
        plugin = LowLevelZeroPlugin(initial_scale=2**5)

    booster = Booster(plugin=plugin, **booster_kwargs)

    # ==============================
    # Prepare Dataloader
    # ==============================
    train_dataset = RandintDataset(
        dataset_length=DATASET_LEN, sequence_length=SEQ_LEN, vocab_size=VOCAB_SIZE, n_class=NUM_LABELS
    )
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)

    # ====================================
    # Prepare model, optimizer
    # ====================================
    # The model is built from a fresh config; no pretrained weights are loaded.
    model_classes = {
        "bert": (BertConfig, BertForSequenceClassification),
        "albert": (AlbertConfig, AlbertForSequenceClassification),
    }
    if args.model_type not in model_classes:
        supported = ", ".join(sorted(model_classes))
        raise RuntimeError(f"Unsupported model_type {args.model_type!r}; expected one of: {supported}")
    config_cls, model_cls = model_classes[args.model_type]
    cfg = config_cls(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS)
    model = model_cls(cfg)

    # optimizer: parameters whose name contains one of the ``no_decay``
    # substrings get weight_decay 0.0, all others get WEIGHT_DECAY.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": WEIGHT_DECAY,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)

    # lr scheduler: warmup covers WARMUP_FRACTION of all training steps.
    total_steps = len(train_dataloader) * NUM_EPOCHS
    num_warmup_steps = int(WARMUP_FRACTION * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps,
    )

    # criterion
    def criterion(inputs):
        """Return the first element of the model output."""
        # Presumably the loss: Hugging Face classification heads put the loss
        # first when labels are supplied -- confirm against benchmark().
        return inputs[0]

    # ==============================
    # Boost with ColossalAI
    # ==============================
    model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)

    # ==============================
    # Benchmark model
    # ==============================
    results = benchmark(
        model, booster, optimizer, lr_scheduler, train_dataloader, criterion=criterion, epoch_num=NUM_EPOCHS
    )
    coordinator.print_on_master(results)
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()