mirror of https://github.com/hpcaitech/ColossalAI
[booster] update bert example, using booster api (#3885)
parent 5e2132dcff
commit a55fb00c18
README.md
@ -0,0 +1,34 @@
## Overview

This directory contains two parts: finetuning Hugging Face BERT and ALBERT models with the Booster API, and benchmarking BERT and ALBERT models under different Booster plugins.
## Finetune

The CI script below finetunes BERT on the MRPC task with each Booster plugin in turn:

```
bash test_ci.sh
```
## Benchmark

The script below benchmarks BERT and ALBERT training with each plugin on synthetic data:

```
bash benchmark.sh
```

The benchmark currently reports these metrics: peak CUDA memory usage, throughput (samples/s), and the number of model parameters. If you need custom metrics, you can add them in `benchmark_utils.py`.
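For example, here is a minimal sketch of a custom metric (the helper name is hypothetical; it just mirrors the pattern of the existing `measure_params` helper):

```
# Hypothetical addition to benchmark_utils.py, following the same
# NaN-sentinel pattern as measure_params().
import torch

def measure_max_memory_reserved(device=None):
    """Peak CUDA memory reserved by the caching allocator, in bytes."""
    try:
        return torch.cuda.max_memory_reserved(device=device)
    except Exception:
        return float("nan")    # same invalid-value sentinel as the other metrics

# Inside benchmark(), the value could then be recorded with:
#   results["max_mem_reserved"] = format_num(
#       measure_max_memory_reserved(model_device), bytes=True)
```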
## Results

### BERT

|                | max cuda mem | throughput (sample/s) | params |
| :------------- | -----------: | :-------------------: | :----: |
| ddp            |     21.44 GB |                   3.0 |    82M |
| ddp_fp16       |     16.26 GB |                  11.3 |    82M |
| gemini         |      11.0 GB |                  12.9 |    82M |
| low_level_zero |     11.29 GB |                  14.7 |    82M |

### ALBERT

|                | max cuda mem | throughput (sample/s) | params |
| :------------- | -----------: | :-------------------: | :----: |
| ddp            |          OOM |                       |        |
| ddp_fp16       |          OOM |                       |        |
| gemini         |     69.39 GB |                   1.3 |   208M |
| low_level_zero |     56.89 GB |                   1.4 |   208M |
benchmark.py
@ -0,0 +1,174 @@
import argparse

import torch
from benchmark_utils import benchmark
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AlbertConfig,
    AlbertForSequenceClassification,
    BertConfig,
    BertForSequenceClassification,
    get_linear_schedule_with_warmup,
)

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam

# ==============================
# Prepare Hyperparameters
# ==============================
NUM_EPOCHS = 3
BATCH_SIZE = 32
LEARNING_RATE = 2.4e-5
WEIGHT_DECAY = 0.01
WARMUP_FRACTION = 0.1
SEQ_LEN = 512
VOCAB_SIZE = 1000
NUM_LABELS = 10
DATASET_LEN = 1000

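# Synthetic dataset of random token ids and labels, so the benchmark measures
# raw training speed without tokenization or real-data loading overhead.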
class RandintDataset(Dataset):

    def __init__(self, dataset_length: int, sequence_length: int, vocab_size: int, n_class: int):
        self._sequence_length = sequence_length
        self._vocab_size = vocab_size
        self._n_class = n_class
        self._dataset_length = dataset_length
        self._datas = torch.randint(
            low=0,
            high=self._vocab_size,
            size=(self._dataset_length, self._sequence_length),
            dtype=torch.long,
        )
        self._labels = torch.randint(low=0, high=self._n_class, size=(self._dataset_length, 1), dtype=torch.long)

    def __len__(self):
        return self._dataset_length

    def __getitem__(self, idx):
        return self._datas[idx], self._labels[idx]

def main():
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run")
    parser.add_argument('-p',
                        '--plugin',
                        type=str,
                        default='torch_ddp',
                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
                        help="plugin to use")
    parser.add_argument(
        "--model_type",
        type=str,
        default="bert",
        help="bert or albert",
    )

    args = parser.parse_args()

    # ==============================
    # Launch Distributed Environment
    # ==============================
    colossalai.launch_from_torch(config={}, seed=42)
    coordinator = DistCoordinator()

    # local_batch_size = BATCH_SIZE // coordinator.world_size
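    # Scale the learning rate linearly with the number of ranks (linear scaling rule).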
    lr = LEARNING_RATE * coordinator.world_size

    # ==============================
    # Instantiate Plugin and Booster
    # ==============================
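    # Plugin selection: torch_ddp wraps the model in PyTorch DDP,
    # torch_ddp_fp16 additionally enables fp16 mixed precision via the Booster,
    # gemini uses ColossalAI's chunk-based heterogeneous memory management,
    # and low_level_zero applies ZeRO-style optimizer-state sharding.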
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)

    booster = Booster(plugin=plugin, **booster_kwargs)

    # ==============================
    # Prepare Dataloader
    # ==============================

    train_dataset = RandintDataset(dataset_length=DATASET_LEN,
                                   sequence_length=SEQ_LEN,
                                   vocab_size=VOCAB_SIZE,
                                   n_class=NUM_LABELS)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)

    # ====================================
    # Prepare model, optimizer
    # ====================================
    # build the model from a config (randomly initialized weights)

    if args.model_type == "bert":
        cfg = BertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS)
        model = BertForSequenceClassification(cfg)
    elif args.model_type == "albert":
        cfg = AlbertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS)
        model = AlbertForSequenceClassification(cfg)
    else:
        raise RuntimeError(f"Unknown model type: {args.model_type}")

    # optimizer
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": WEIGHT_DECAY,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)

    # lr scheduler
    total_steps = len(train_dataloader) * NUM_EPOCHS
    num_warmup_steps = int(WARMUP_FRACTION * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps,
    )

    # criterion: with labels supplied, HF sequence-classification models return
    # the loss as the first element of their outputs
    criterion = lambda inputs: inputs[0]

    # ==============================
    # Boost with ColossalAI
    # ==============================
    model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)

    # ==============================
    # Benchmark model
    # ==============================

    results = benchmark(model,
                        booster,
                        optimizer,
                        lr_scheduler,
                        train_dataloader,
                        criterion=criterion,
                        epoch_num=NUM_EPOCHS)

    coordinator.print_on_master(results)


if __name__ == '__main__':
    main()
benchmark.sh
@ -0,0 +1,9 @@
#!/bin/bash
set -xe

pip install -r requirements.txt

for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
    torchrun --standalone --nproc_per_node 2 benchmark.py --plugin $plugin --model_type "bert"
    torchrun --standalone --nproc_per_node 2 benchmark.py --plugin $plugin --model_type "albert"
done
benchmark_utils.py
@ -0,0 +1,146 @@
import inspect
from logging import getLogger
from time import time
from typing import Callable

import torch
import yaml
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm

from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator

logger = getLogger("colossalai-booster-benchmark")
_INVALID = float("nan")


def format_num(num: int, bytes=False):
    """Scale a number to a human-readable string, e.g. 1253656 with bytes=True => '1.20 MB'."""
    factor = 1024 if bytes else 1000
    suffix = "B" if bytes else ""
    for unit in ["", " K", " M", " G", " T", " P"]:
        if num < factor:
            return f"{num:.2f}{unit}{suffix}"
        num /= factor

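# NaN is the only value that compares unequal to itself, so this check
# filters out the _INVALID sentinel above.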
def _is_valid(val):
    return val == val


def get_call_arg_names(module_or_fn):
    if isinstance(module_or_fn, torch.nn.Module):
        return inspect.getfullargspec(module_or_fn.forward)[0][1:]
    return inspect.getfullargspec(module_or_fn)[0]


def measure_params(model):
    num_params = _INVALID

    try:
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    except AttributeError as e:
        logger.error(f"Unable to measure model params due to error: {e}")

    return num_params

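# Run a few training steps before timing so one-off costs (CUDA context and
# allocator warm-up, lazy plugin initialization) are excluded from the
# measurement.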
def warm_up(
    model,
    booster,
    dataloader,
    criterion,
    optimizer,
    lr_scheduler,
    num_runs=10,
):
    for i, data in enumerate(dataloader):
        if i > num_runs:
            break
        inputs, labels = data[0].cuda(), data[1].cuda()
        outputs = model(inputs, labels=labels)
        loss = criterion(outputs)
        booster.backward(loss, optimizer)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()


def fmt(d: dict):
    return yaml.dump(d)

def benchmark(
    model: torch.nn.Module,
    booster: Booster,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: LRScheduler,
    dataloader: DataLoader,
    criterion: Callable = None,
    warm_up_fn=warm_up,
    epoch_num: int = 3,
    batch_size: int = 32,
    warm_up_steps: int = 3,
):
    results = {}
    model_device = torch.cuda.current_device()

    # Warm up
    warm_up_fn(
        model,
        booster,
        dataloader,
        criterion,
        optimizer,
        lr_scheduler,
        num_runs=warm_up_steps,
    )
    # Measure params
    params = measure_params(model)
    if _is_valid(params):
        results["params"] = format_num(params)
        logger.info(f"Model parameters: {params} ({format_num(params)})")

    # Measure Allocated Memory and Throughput
    memory = {}
    throughput = {}
    torch.cuda.reset_peak_memory_stats(device=model_device)
    pre_mem = torch.cuda.memory_allocated(device=model_device)

    start_time = time()

    for epoch in range(epoch_num):
        with tqdm(dataloader, desc=f'Epoch [{epoch + 1}/{epoch_num}]',
                  disable=not DistCoordinator().is_master()) as pbar:
            for data in pbar:
                inputs, labels = data[0].cuda(), data[1].cuda()
                outputs = model(inputs, labels=labels)
                loss = criterion(outputs)
                booster.backward(loss, optimizer)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

    end_time = time()
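    # Note: len(dataloader) counts batches per rank, so the throughput figure
    # below is batches processed per second summed across all ranks.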
    all_sample = epoch_num * len(dataloader)

    post_mem = torch.cuda.memory_allocated(device=model_device)
    max_mem = torch.cuda.max_memory_allocated(device=model_device)

    memory[f"batch_size_{batch_size}"] = {
        "cuda_pre_training_bytes": format_num(pre_mem, bytes=True),
        "cuda_max_training_bytes": format_num(max_mem, bytes=True),
        "cuda_post_training_bytes": format_num(post_mem, bytes=True),
    }
    logger.info(fmt({f"Memory results (batch_size={batch_size})": memory[f"batch_size_{batch_size}"]}))

    throughput[f"batch_size_{batch_size}"] = {
        "throughput": "{:.1f}".format(all_sample * DistCoordinator().world_size / (end_time - start_time))
    }
    logger.info(fmt({f"Throughput results (batch_size={batch_size})": throughput[f"batch_size_{batch_size}"]}))

    results["throughput"] = throughput
    results["memory"] = memory

    return results
data.py
@ -0,0 +1,127 @@
import datasets
from transformers import AutoTokenizer, PreTrainedTokenizer

from colossalai.booster.plugin.dp_plugin_base import DPPluginBase


class GLUEDataBuilder:

    task_text_field_map = {
        "cola": ["sentence"],
        "sst2": ["sentence"],
        "mrpc": ["sentence1", "sentence2"],
        "qqp": ["question1", "question2"],
        "stsb": ["sentence1", "sentence2"],
        "mnli": ["premise", "hypothesis"],
        "qnli": ["question", "sentence"],
        "rte": ["sentence1", "sentence2"],
        "wnli": ["sentence1", "sentence2"],
        "ax": ["premise", "hypothesis"],
    }

    glue_task_num_labels = {
        "cola": 2,
        "sst2": 2,
        "mrpc": 2,
        "qqp": 2,
        "stsb": 1,
        "mnli": 3,
        "qnli": 2,
        "rte": 2,
        "wnli": 2,
        "ax": 3,
    }

    loader_columns = [
        "datasets_idx",
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_positions",
        "end_positions",
        "labels",
    ]

    def __init__(
        self,
        model_name_or_path: str,
        plugin: DPPluginBase,
        task_name: str = "mrpc",
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs,
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.task_name = task_name
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.plugin = plugin

        self.text_fields = self.task_text_field_map[task_name]
        self.num_labels = self.glue_task_num_labels[task_name]
        self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
        self.setup()
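    # Tokenize every split up front and keep only the columns the model's
    # forward() accepts, so a batch can be passed directly as model(**batch).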
    def setup(self):
        self.dataset = datasets.load_dataset("glue", self.task_name)

        for split in self.dataset.keys():
            self.dataset[split] = self.dataset[split].map(
                self.convert_to_features,
                batched=True,
                remove_columns=["label"],
            )
            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
            self.dataset[split].set_format(type="torch", columns=self.columns)

        self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]

    def prepare_data(self):
        datasets.load_dataset("glue", self.task_name)
        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def train_dataloader(self):
        return self.plugin.prepare_dataloader(self.dataset["train"],
                                              batch_size=self.train_batch_size,
                                              shuffle=True,
                                              drop_last=True)

    def val_dataloader(self):
        if len(self.eval_splits) == 1:
            return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [
                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                for x in self.eval_splits
            ]

    def test_dataloader(self):
        if len(self.eval_splits) == 1:
            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [
                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                for x in self.eval_splits
            ]

    def convert_to_features(self, example_batch):

        # Either encode single sentence or sentence pairs
        if len(self.text_fields) > 1:
            texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
        else:
            texts_or_text_pairs = example_batch[self.text_fields[0]]

        # Tokenize the text/text pairs
        features = self.tokenizer.batch_encode_plus(texts_or_text_pairs,
                                                    max_length=self.max_seq_length,
                                                    padding='max_length',
                                                    truncation=True)

        # Rename label to labels to make it easier to pass to model forward
        features["labels"] = example_batch["label"]

        return features
finetune.py
@ -0,0 +1,220 @@
import argparse
from typing import List, Union

import evaluate
import torch
import torch.distributed as dist
import torch.nn as nn
from data import GLUEDataBuilder
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AlbertForSequenceClassification,
    AutoConfig,
    BertForSequenceClassification,
    get_linear_schedule_with_warmup,
)

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

# ==============================
# Prepare Hyperparameters
# ==============================
NUM_EPOCHS = 3
BATCH_SIZE = 32
LEARNING_RATE = 2.4e-5
WEIGHT_DECAY = 0.01
WARMUP_FRACTION = 0.1


def move_to_cuda(batch):
    return {k: v.cuda() for k, v in batch.items()}

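# evaluate.load() with process_id/num_process sets up distributed metric
# computation: each rank adds its own shard of predictions via add_batch()
# and compute() aggregates them across processes.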
@torch.no_grad()
def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int,
                   task_name: str, eval_splits: List[str], coordinator: DistCoordinator):
    metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
    model.eval()

    def evaluate_subset(dataloader: DataLoader):
        accum_loss = torch.zeros(1, device=get_current_device())
        for batch in dataloader:
            batch = move_to_cuda(batch)
            outputs = model(**batch)
            val_loss, logits = outputs[:2]
            accum_loss.add_(val_loss)

            if num_labels > 1:
                preds = torch.argmax(logits, axis=1)
            elif num_labels == 1:
                preds = logits.squeeze()

            labels = batch["labels"]

            metric.add_batch(predictions=preds, references=labels)

        results = metric.compute()
        dist.all_reduce(accum_loss.div_(len(dataloader)))
        if coordinator.is_master():
            results['loss'] = accum_loss.item() / coordinator.world_size
        return results

    if isinstance(test_dataloader, DataLoader):
        return evaluate_subset(test_dataloader)
    else:
        assert len(test_dataloader) == len(eval_splits)
        final_results = {}
        for split, sub_loader in zip(eval_splits, test_dataloader):
            results = evaluate_subset(sub_loader)
            final_results.update({f'{k}_{split}': v for k, v in results.items()})
        return final_results

def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader,
                booster: Booster, coordinator: DistCoordinator):
    model.train()
    with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar:
        for batch in pbar:
            # Forward pass
            batch = move_to_cuda(batch)
            outputs = model(**batch)
            loss = outputs[0]

            # Backward and optimize
            booster.backward(loss, optimizer)
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()

            # Print log info
            pbar.set_postfix({'loss': loss.item()})


def main():
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run")
    parser.add_argument('-p',
                        '--plugin',
                        type=str,
                        default='torch_ddp',
                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
                        help="plugin to use")
    parser.add_argument(
        "--model_type",
        type=str,
        default="bert",
        help="bert or albert",
    )
    parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
    args = parser.parse_args()

    if args.model_type == 'bert':
        model_name = "bert-base-uncased"
    elif args.model_type == 'albert':
        model_name = "albert-xxlarge-v2"
    else:
        raise RuntimeError(f"Unknown model type: {args.model_type}")
    # ==============================
    # Launch Distributed Environment
    # ==============================
    colossalai.launch_from_torch(config={}, seed=42)
    coordinator = DistCoordinator()

    # local_batch_size = BATCH_SIZE // coordinator.world_size
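    # Scale the learning rate linearly with the number of ranks (linear scaling rule).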
    lr = LEARNING_RATE * coordinator.world_size

    # ==============================
    # Instantiate Plugin and Booster
    # ==============================
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)

    booster = Booster(plugin=plugin, **booster_kwargs)

    # ==============================
    # Prepare Dataloader
    # ==============================
    data_builder = GLUEDataBuilder(model_name,
                                   plugin,
                                   args.task,
                                   train_batch_size=BATCH_SIZE,
                                   eval_batch_size=BATCH_SIZE)
    train_dataloader = data_builder.train_dataloader()
    test_dataloader = data_builder.test_dataloader()

    # ====================================
    # Prepare model, optimizer
    # ====================================
    # bert pretrained model

    cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
    if model_name == "bert-base-uncased":
        model = BertForSequenceClassification.from_pretrained(model_name, config=cfg)
    elif model_name == "albert-xxlarge-v2":
        model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
    else:
        raise RuntimeError(f"Unknown model name: {model_name}")

    # optimizer
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": WEIGHT_DECAY,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)

    # lr scheduler
    total_steps = len(train_dataloader) * NUM_EPOCHS
    num_warmup_steps = int(WARMUP_FRACTION * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps,
    )

    # ==============================
    # Boost with ColossalAI
    # ==============================
    model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)

    # ==============================
    # Train model
    # ==============================
    for epoch in range(NUM_EPOCHS):
        train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)

    results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
                             coordinator)

    if coordinator.is_master():
        print(results)
        if args.target_f1 is not None and 'f1' in results:
            assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}'


if __name__ == '__main__':
    main()
requirements.txt
@ -0,0 +1,9 @@
colossalai
evaluate
datasets
torch
tqdm
transformers
scipy
scikit-learn
ptflops
run_gemini.sh (deleted)
@ -1,22 +0,0 @@
set -x
# distplan in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]
export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}

# The following options only valid when DISTPLAN="colossalai"
export GPUNUM=${GPUNUM:-1}
export PLACEMENT=${PLACEMENT:-"cpu"}
export BATCH_SIZE=${BATCH_SIZE:-16}

# bert | albert
export MODEL_TYPE=${MODEL_TYPE:-"bert"}
export TRAIN_STEP=${TRAIN_STEP:-10}

mkdir -p gemini_logs

env CUDA_LAUNCH_BLOCKING=1 torchrun --standalone --nproc_per_node=${GPUNUM} ./train_bert_demo.py \
    --model_type=${MODEL_TYPE} \
    --batch_size=${BATCH_SIZE} \
    --placement=${PLACEMENT} \
    --distplan=${DISTPLAN} \
    --train_step=${TRAIN_STEP} \
    2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_${PLACEMENT}.log
test_ci.sh
@ -1,2 +1,8 @@
-set -x
-env GPUNUM=1 bash run_gemini.sh
+#!/bin/bash
+set -xe
+
+pip install -r requirements.txt
+
+for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
+    torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert"
+done
train_bert_demo.py (deleted)
@ -1,331 +0,0 @@
import os
from functools import partial
from time import time

import psutil
import torch
from packaging import version
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AlbertConfig, AlbertForSequenceClassification, BertConfig, BertForSequenceClassification

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper

CAI_VERSION = colossalai.__version__


def get_tflops(model_numel, batch_size, seq_len, step_time):
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)


def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
    from contextlib import nullcontext

    from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
    if enable_flag:
        return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                       schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
                       on_trace_ready=tensorboard_trace_handler(save_dir),
                       record_shapes=True,
                       profile_memory=True)
    else:

        class DummyProfiler:

            def __init__(self):
                self.step_number = 0

            def step(self):
                self.step_number += 1

        return nullcontext(DummyProfiler())


def get_time_stamp():
    import time
    cur_time = time.strftime("%d-%H:%M", time.localtime())
    return cur_time


def get_bert_data(batch_size: int, sequence_length: int, vacob_size: int, n_class: int, device: torch.device):
    input = torch.randint(
        low=0,
        high=vacob_size,
        size=(batch_size, sequence_length),
        device=device,
        dtype=torch.long,
    )
    label = torch.randint(low=0, high=n_class, size=(batch_size,), device=device, dtype=torch.long)
    return input, label


def parse_args():
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--distplan",
        type=str,
        default='CAI_Gemini',
        help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size per DP group of training.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="bert",
        help="bert or albert",
    )
    parser.add_argument(
        "--train_step",
        type=int,
        default=10,
        help="training iterations for test",
    )

    args = parser.parse_args()
    return args


SEQ_LEN = 512
VOCAB_SIZE = 1000
NUM_LABELS = 10


# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)


def get_cpu_mem():
    return psutil.Process().memory_info().rss / 1024**2


def get_gpu_mem():
    return torch.cuda.memory_allocated() / 1024**2


def get_mem_info(prefix=''):
    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'


def get_model_size(model: nn.Module):
    total_numel = 0
    for module in model.modules():
        for p in module.parameters(recurse=False):
            total_numel += p.numel()
    return total_numel


def model_builder(args):
    if args.model_type == "bert":
        cfg = BertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS)
        return BertForSequenceClassification(cfg)
    elif args.model_type == "albert":
        cfg = AlbertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS)
        return AlbertForSequenceClassification(cfg)
    else:
        raise RuntimeError


def model_size_formatter(numel: int) -> str:
    GB_SIZE = 10**9
    MB_SIZE = 10**6
    KB_SIZE = 10**3
    if numel >= GB_SIZE:
        return f'{numel / GB_SIZE:.1f}B'
    elif numel >= MB_SIZE:
        return f'{numel / MB_SIZE:.1f}M'
    elif numel >= KB_SIZE:
        return f'{numel / KB_SIZE:.1f}K'
    else:
        return str(numel)


def set_cpu_maximum_parallelism():
    conf_str = torch.__config__.parallel_info()
    inter_str = conf_str.split("hardware_concurrency() : ")[1]
    max_concurrency = inter_str.split('\n')[0]
    os.environ["OMP_NUM_THREADS"] = max_concurrency
    print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")


def main():
    # version check
    # this example is supposed to work for versions greater than 0.2.0
    assert version.parse(CAI_VERSION) >= version.parse("0.2.0")

    set_cpu_maximum_parallelism()
    args = parse_args()

    # if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
    if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]:
        raise TypeError(f"{args.distplan} is error")

    # batch size per DP degree
    BATCH_SIZE = args.batch_size

    NUM_STEPS = args.train_step

    WARMUP_STEPS = 1
    assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
    PROF_FLAG = False    # The flag of profiling, False by default

    disable_existing_loggers()
    colossalai.launch_from_torch(config={})

    logger = get_dist_logger()
    logger.info(f" {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])

    torch.manual_seed(123)
    if args.distplan.startswith("CAI"):
        # all param must use the same process group.
        world_size = torch.distributed.get_world_size()

        # build a base-bert model
        with ColoInitContext(device=get_current_device(), dtype=torch.half):
            model = model_builder(args)
        # model = BertForSequenceClassification(BertConfig(vocal_size = VOCAB_SIZE))

        # asign running configurations
        gemini_config = None
        if args.distplan.startswith("CAI_ZeRO"):
            optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
        elif args.distplan == "CAI_Gemini":
            gemini_config = dict(strict_ddp_mode=True,
                                 device=get_current_device(),
                                 placement_policy=args.placement,
                                 pin_memory=True,
                                 hidden_dim=model.config.hidden_size,
                                 search_range_mb=128)
            optim_config = dict(gpu_margin_mem_ratio=0.)
        else:
            raise RuntimeError

        # build a highly optimized gpu/cpu optimizer
        optimizer = HybridAdam(model.parameters(), lr=1e-3)

        if args.distplan == "CAI_ZeRO1":
            zero_stage = 1
        elif args.distplan == "CAI_ZeRO2":
            zero_stage = 2
        elif args.distplan == "CAI_Gemini":
            zero_stage = 3
        else:
            raise RuntimeError

        # wrap your model and optimizer
        model = zero_model_wrapper(model, zero_stage, gemini_config)
        optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)

        logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
    elif args.distplan.startswith("Pytorch"):
        model = model_builder(args).cuda()
        model = DDP(model)
        if args.distplan.endswith("DDP"):
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        elif args.distplan.endswith("ZeRO"):
            from torch.distributed.optim import ZeroRedundancyOptimizer
            optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
    else:
        raise RuntimeError

    # model is shared after TP
    numel = get_model_size(model)
    logger.info(f"the size of testing model size is {model_size_formatter(numel)}.")
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
    # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
    # = batch_per_DP_group * numel * seq_len * 8
    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)

    torch.cuda.synchronize()
    model.train()
    tflops_list = []

    def train_step():
        # we just use randomly generated data here
        input_ids, labels = get_bert_data(BATCH_SIZE,
                                          SEQ_LEN,
                                          VOCAB_SIZE,
                                          NUM_LABELS,
                                          device=torch.cuda.current_device())
        optimizer.zero_grad()

        start = time()
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]
        torch.cuda.synchronize()
        fwd_end = time()
        fwd_time = fwd_end - start
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])

        if args.distplan.startswith("CAI"):
            optimizer.backward(loss)
        elif args.distplan.startswith("Pytorch"):
            loss.backward()
        else:
            raise RuntimeError

        torch.cuda.synchronize()
        bwd_end = time()
        bwd_time = bwd_end - fwd_end
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])

        optimizer.step()
        torch.cuda.synchronize()
        optim_time = time() - bwd_end
        step_time = time() - start
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])

        step_tflops = get_tflops_func(step_time)
        logger.info(
            f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
            ranks=[0],
        )
        if n >= WARMUP_STEPS:
            tflops_list.append(step_tflops)

    demo_profiler = get_profile_context(PROF_FLAG,
                                        WARMUP_STEPS,
                                        NUM_STEPS - WARMUP_STEPS,
                                        save_dir=f"profile/{get_time_stamp()}-demo")

    with demo_profiler as prof:
        for n in range(NUM_STEPS):
            train_step()
            prof.step()

    tflops_list.sort()
    median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
    torch.cuda.synchronize()


if __name__ == '__main__':
    main()