import sys
from functools import partial
from typing import Dict, List

import click
import torch.multiprocessing as mp

import colossalai
from colossalai.cli.benchmark.utils import find_all_configs, get_batch_data, profile_model
from colossalai.context import Config
from colossalai.context.random import reset_seeds
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import MultiTimer, free_port

from .models import MLP


def run_benchmark(args: Config) -> None:
    """
    Run benchmarking with torch.multiprocessing.

    Echoes the given parameters, obtains the list of configurations from
    ``find_all_configs(args.gpus)``, reserves one free port per configuration
    and spawns ``args.gpus`` processes that each run ``run_dist_profiling``.

    Args:
        args (Config): the hyperparameters given by the user

    Raises:
        SystemExit: with a non-zero status if ``args.gpus`` is None
    """
    # sanity checks
    if args.gpus is None:
        click.echo("Error: --num_gpus is not given")
        # sys.exit instead of the site-provided ``exit`` builtin, and a non-zero
        # status so that callers can tell the invocation failed
        sys.exit(1)

    click.echo("=== Benchmarking Parameters ===")
    for k, v in args.items():
        click.echo(f'{k}: {v}')
    click.echo('')

    config_list = find_all_configs(args.gpus)

    # one port per configuration: every configuration is launched separately by the workers
    avail_ports = [free_port() for _ in range(len(config_list))]
    run_func = partial(run_dist_profiling,
                       world_size=args.gpus,
                       port_list=avail_ports,
                       config_list=config_list,
                       hyperparams=args)
    mp.spawn(run_func, nprocs=args.gpus)


def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_list: List[Dict],
                       hyperparams: Config) -> None:
    """
    A function executed for profiling, this function should be spawned by torch.multiprocessing.

    For every (configuration, port) pair the distributed context is launched, the
    model is built and profiled, the context is destroyed and the seeds are reset;
    global rank 0 then echoes the measurements.

    Args:
        rank (int): rank of the process
        world_size (int): the number of processes
        port_list (List[int]): a list of free ports for initializing distributed networks
        config_list (List[Dict]): a list of configuration
        hyperparams (Config): the hyperparameters given by the user

    Raises:
        SystemExit: with a non-zero status if ``hyperparams.model`` is not supported
    """
    # disable logging for clean output
    disable_existing_loggers()
    logger = get_dist_logger()
    logger.set_level('WARNING')

    # Validate the model name once, before any distributed context is launched, and
    # stop on every rank. Exiting on rank 0 only would let the remaining ranks carry
    # on without a model. ``rank`` is the value handed to ``colossalai.launch`` below,
    # so it stands in for the global rank here.
    if hyperparams.model != 'mlp':
        if rank == 0:
            click.echo("Error: Invalid argument for --model")
        sys.exit(1)

    for config, port in zip(config_list, port_list):
        colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
        timer = MultiTimer()

        # the model is rebuilt for every configuration, after the context has been launched
        model = MLP(dim=hyperparams.dimension, layers=hyperparams.layers)

        data_func = partial(get_batch_data,
                            dim=hyperparams.dimension,
                            batch_size=hyperparams.batch_size,
                            seq_length=hyperparams.seq_len,
                            mode=config.parallel.tensor.mode)

        fwd_time, bwd_time, max_allocated, max_cached = profile_model(model=model,
                                                                      warmup_steps=hyperparams.warmup_steps,
                                                                      profile_steps=hyperparams.profile_steps,
                                                                      data_func=data_func,
                                                                      timer=timer)

        # query the rank while the context is still alive so that reporting does not
        # depend on what survives ``gpc.destroy()``
        is_master = gpc.get_global_rank() == 0

        gpc.destroy()
        reset_seeds()

        if is_master:
            config_str = ', '.join([f'{k}: {v}' for k, v in config.parallel.tensor.items()])
            click.echo(f"=== {config_str} ===")
            click.echo(f"Average forward time: {fwd_time}")
            click.echo(f"Average backward time: {bwd_time}")
            click.echo(f"Max allocated GPU memory: {max_allocated}")
            click.echo(f"Max cached GPU memory: {max_cached}\n")