ColossalAI/examples/images/vit/vit_benchmark.py

153 lines
5.2 KiB
Python

import time
import torch
import transformers
from args import parse_benchmark_args
from tqdm import tqdm
from transformers import ViTConfig, ViTForImageClassification
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
def format_num(num: float, bytes=False):
    """Scale a number to a human-readable string with a unit prefix.

    Args:
        num: The value to format (e.g. a byte count or a parameter count).
        bytes: If True, scale by powers of 1024 and append ``"B"``
            (e.g. ``1253656 => '1.20 MB'``); otherwise scale by powers of
            1000 with no suffix (e.g. ``1500 => '1.50 K'``).

    Returns:
        The scaled value with two decimals, followed by a space-prefixed unit
        (no unit below one ``factor``) and the optional ``"B"`` suffix.
        Magnitudes beyond the largest unit (``P``) are expressed in ``P``.
    """
    # NOTE: the parameter name shadows the ``bytes`` builtin; it is kept because
    # callers pass it by keyword (``bytes=True``).
    factor = 1024 if bytes else 1000
    suffix = "B" if bytes else ""
    units = ["", " K", " M", " G", " T", " P"]
    # Divide down through every unit except the last; the last unit absorbs any
    # remaining magnitude so a string is always returned (otherwise values
    # >= factor**6 would fall off the loop and yield None).
    for unit in units[:-1]:
        # abs() so negative values are scaled the same way as positive ones.
        if abs(num) < factor:
            return f"{num:.2f}{unit}{suffix}"
        num /= factor
    return f"{num:.2f}{units[-1]}{suffix}"
def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224):
    """Build one random synthetic classification batch on the current CUDA device.

    Args:
        batch_size: Number of samples in the batch.
        num_labels: Labels are drawn uniformly from ``[0, num_labels)``.
        num_channels, height, width: Image dimensions of ``pixel_values``.

    Returns:
        A dict with ``pixel_values`` (float32 tensor of shape
        ``(batch_size, num_channels, height, width)``, standard-normal values)
        and ``labels`` (int64 tensor of shape ``(batch_size,)``).
    """
    device = torch.cuda.current_device()
    image_shape = (batch_size, num_channels, height, width)
    # Dict-literal values are evaluated in order: images are sampled before labels.
    return {
        "pixel_values": torch.randn(*image_shape, device=device, dtype=torch.float),
        "labels": torch.randint(0, num_labels, (batch_size,), device=device, dtype=torch.int64),
    }
def colo_memory_cap(size_in_GB):
    """Limit this process's GPU memory usage to ``size_in_GB`` (units of 1024**3 bytes).

    The limit is applied as a fraction of the capacity reported by
    ``colo_device_memory_capacity`` for the current device, and only when the
    requested size is below that capacity; otherwise nothing is changed.
    """
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device

    total_bytes = colo_device_memory_capacity(get_current_device())
    requested_bytes = size_in_GB * (1024**3)
    # Nothing to do when the request does not fit under the device capacity.
    if not requested_bytes < total_bytes:
        return
    colo_set_process_memory_fraction(requested_bytes / total_bytes)
    print(f"Limiting GPU memory usage to {size_in_GB} GB")
def _build_plugin(plugin_name):
    """Return ``(plugin, booster_kwargs)`` for the booster plugin named ``plugin_name``.

    Raises:
        ValueError: if ``plugin_name`` is not a supported plugin name.
    """
    booster_kwargs = {}
    if plugin_name == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"
    if plugin_name.startswith("torch_ddp"):
        plugin = TorchDDPPlugin()
    elif plugin_name == "gemini":
        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
    elif plugin_name == "low_level_zero":
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    elif plugin_name == "hybrid_parallel":
        # NOTE(review): tp_size * pp_size == 4, so presumably the world size must
        # be a multiple of 4 — confirm against HybridParallelPlugin.
        plugin = HybridParallelPlugin(
            tp_size=2,
            pp_size=2,
            num_microbatches=None,
            microbatch_size=1,
            enable_all_optimization=True,
            precision="fp16",
            initial_scale=1,
        )
    else:
        # Without this branch an unknown name would leave ``plugin`` unset and
        # fail later with an UnboundLocalError far from the actual cause.
        raise ValueError(f"Unknown plugin: {plugin_name!r}")
    return plugin, booster_kwargs


def main():
    """Benchmark ViT image-classification training with a ColossalAI booster plugin.

    Reads options from ``parse_benchmark_args()``, constructs a ViT model from
    the config at ``args.model_name_or_path`` (only the config is loaded here,
    not pretrained weights), boosts it with the plugin selected by
    ``args.plugin``, runs ``args.max_train_steps`` steps on synthetic data, and
    logs throughput (samples per second summed over all ranks) and peak CUDA
    memory on rank 0.

    Raises:
        ValueError: if ``args.plugin`` is not a supported plugin name.
    """
    args = parse_benchmark_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size

    # Manage loggers: only the master rank emits transformers info-level logs.
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Optionally cap GPU memory usage.
    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # Build ViT model from its config.
    # NOTE(review): labels are drawn from [0, args.num_labels) while the
    # classifier size comes from the config; presumably args.num_labels must not
    # exceed config.num_labels — confirm.
    config = ViTConfig.from_pretrained(args.model_name_or_path)
    model = ViTForImageClassification(config)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()

    plugin, booster_kwargs = _build_plugin(args.plugin)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Learning rate is scaled linearly with the number of ranks.
    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))

    # The model output already carries the loss (labels are part of the batch).
    def criterion(outputs, inputs):
        return outputs.loss

    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)

    # Pipeline-parallel plugins expose a non-None ``stage_manager``. The check is
    # invariant across steps, so it is evaluated once.
    use_pipeline = getattr(booster.plugin, "stage_manager", None) is not None

    logger.info("Start testing", ranks=[0])
    torch.cuda.synchronize()
    model.train()
    # perf_counter is monotonic and high-resolution, which suits benchmarking.
    start_time = time.perf_counter()
    with tqdm(range(args.max_train_steps), desc="Training Step", disable=not coordinator.is_master()) as pbar:
        for _ in pbar:
            optimizer.zero_grad()
            batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224)
            if use_pipeline:
                # Run pipeline forward + backward; it takes an iterator over batches.
                booster.execute_pipeline(
                    iter([batch]), model, criterion, optimizer, return_loss=True, return_outputs=True
                )
            else:
                outputs = model(**batch)
                loss = criterion(outputs, None)
                booster.backward(loss, optimizer)
            optimizer.step()
    # Wait for all queued CUDA work before stopping the clock.
    torch.cuda.synchronize()

    # Compute statistics
    end_time = time.perf_counter()
    elapsed = end_time - start_time
    total_samples = world_size * args.max_train_steps * args.batch_size
    throughput = f"{total_samples / elapsed:.4f}"
    max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)
    logger.info(
        f"Testing finished, "
        f"batch size per gpu: {args.batch_size}, "
        f"plugin: {args.plugin}, "
        f"throughput: {throughput}, "
        f"maximum memory usage per gpu: {max_mem}.",
        ranks=[0],
    )
    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()