# ColossalAI/colossalai/cli/benchmark/benchmark.py
#
# 95 lines
# 3.6 KiB
# Python

import colossalai
import click
import torch.multiprocessing as mp
from functools import partial
from typing import List, Dict
from colossalai.context import Config
from colossalai.context.random import reset_seeds
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import free_port, MultiTimer
from colossalai.cli.benchmark.utils import find_all_configs, profile_model, get_batch_data
from .models import MLP
def run_benchmark(args: Config) -> None:
    """
    Run benchmarking with torch.multiprocessing.

    Echoes the benchmarking parameters, collects the configurations returned by
    ``find_all_configs(args.gpus)``, reserves one free port per configuration and
    spawns ``args.gpus`` processes that each execute ``run_dist_profiling``.

    Args:
        args (Config): the hyperparameters given by the user. ``args.gpus`` is used
            both as the world size and as the number of spawned processes.

    Raises:
        SystemExit: with exit status 1 if ``args.gpus`` is None.
    """
    # local import: keeps this function independent of the module's import block
    import sys

    # sanity checks
    if args.gpus is None:
        click.echo("Error: --num_gpus is not given")
        # use sys.exit instead of the site-provided ``exit`` helper (meant for the
        # interactive shell) and report the failure with a non-zero status
        sys.exit(1)

    click.echo("=== Benchmarking Parameters ===")
    for k, v in args.items():
        click.echo(f'{k}: {v}')
    click.echo('')

    config_list = find_all_configs(args.gpus)

    # one port per configuration: run_dist_profiling pairs configs and ports
    # positionally and launches once per pair
    avail_ports = [free_port() for _ in config_list]

    run_func = partial(run_dist_profiling,
                       world_size=args.gpus,
                       port_list=avail_ports,
                       config_list=config_list,
                       hyperparams=args)
    # mp.spawn supplies the process index as the first positional argument (rank)
    mp.spawn(run_func, nprocs=args.gpus)
def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_list: List[Dict],
                       hyperparams: Config) -> None:
    """
    Profile the model under every given configuration; this function should be
    spawned by torch.multiprocessing.

    For each (configuration, port) pair the distributed environment is launched,
    the model is profiled, the global context is destroyed and the seeds are reset.
    The process whose global rank is 0 then echoes the measurements.

    Args:
        rank (int): rank of the process
        world_size (int): the number of processes
        port_list (List[int]): a list of free ports for initializing distributed networks,
            paired with ``config_list`` by position
        config_list (List[Dict]): a list of configuration
        hyperparams (Config): the hyperparameters given by the user
    """
    # keep the output clean: drop existing loggers and set the distributed logger to WARNING
    disable_existing_loggers()
    get_dist_logger().set_level('WARNING')

    for cfg, port in zip(config_list, port_list):
        colossalai.launch(config=cfg, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
        multi_timer = MultiTimer()

        # 'mlp' is the only model handled here; for anything else rank 0 reports
        # the error and every rank leaves
        if hyperparams.model != 'mlp':
            if gpc.get_global_rank() == 0:
                click.echo("Error: Invalid argument for --model")
            exit()
        model = MLP(dim=hyperparams.dimension, layers=hyperparams.layers)

        # batch generator bound to the user's sizes and this config's tensor-parallel mode
        batch_fn = partial(get_batch_data,
                           dim=hyperparams.dimension,
                           batch_size=hyperparams.batch_size,
                           seq_length=hyperparams.seq_len,
                           mode=cfg.parallel.tensor.mode)

        fwd_time, bwd_time, max_allocated, max_cached = profile_model(
            model=model,
            warmup_steps=hyperparams.warmup_steps,
            profile_steps=hyperparams.profile_steps,
            data_func=batch_fn,
            timer=multi_timer,
        )

        # tear down before the next configuration is launched
        gpc.destroy()
        reset_seeds()

        # only the process with global rank 0 prints the report
        if gpc.get_global_rank() != 0:
            continue

        config_str = ', '.join(f'{k}: {v}' for k, v in cfg.parallel.tensor.items())
        click.echo(f"=== {config_str} ===")
        click.echo(f"Average forward time: {fwd_time}")
        click.echo(f"Average backward time: {bwd_time}")
        click.echo(f"Max allocated GPU memory: {max_allocated}")
        click.echo(f"Max cached GPU memory: {max_cached}\n")