diff --git a/colossalai/cli/benchmark/__init__.py b/colossalai/cli/benchmark/__init__.py
index a8afb98e3..3653f1ffa 100644
--- a/colossalai/cli/benchmark/__init__.py
+++ b/colossalai/cli/benchmark/__init__.py
@@ -1,2 +1,27 @@
+import click
+
 from .utils import *
-from .run import *
+from .benchmark import run_benchmark
+from colossalai.context import Config
+
+__all__ = ['benchmark']
+
+
+@click.command()
+@click.option("-g", "--gpus", type=int, default=None, help="Total number of devices to use.")
+@click.option("-b", "--batch_size", type=int, default=8, help="Batch size of the input tensor.")
+@click.option("-s", "--seq_len", type=int, default=512, help="Sequence length of the input tensor.")
+@click.option("-d", "--dimension", type=int, default=1024, help="Hidden dimension of the input tensor.")
+@click.option("-w", "--warmup_steps", type=int, default=10, help="The number of warmup steps.")
+@click.option("-p", "--profile_steps", type=int, default=50, help="The number of profiling steps.")
+@click.option("-l", "--layers", type=int, default=2, help="The number of layers in the model.")
+@click.option("-m",
+              "--model",
+              type=click.Choice(['mlp'], case_sensitive=False),
+              default='mlp',
+              help="Select the model to benchmark; currently only MLP is supported.")
+def benchmark(gpus: int, batch_size: int, seq_len: int, dimension: int, warmup_steps: int, profile_steps: int,
+              layers: int, model: str):
+    args_dict = locals()
+    args = Config(args_dict)
+    run_benchmark(args)
diff --git a/colossalai/cli/benchmark/benchmark.py b/colossalai/cli/benchmark/benchmark.py
new file mode 100644
index 000000000..5bf09aa4e
--- /dev/null
+++ b/colossalai/cli/benchmark/benchmark.py
@@ -0,0 +1,94 @@
+import colossalai
+import click
+import torch.multiprocessing as mp
+
+from functools import partial
+from typing import List, Dict
+
+from colossalai.context import Config
+from colossalai.context.random import reset_seeds
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.utils import free_port, MultiTimer
+from colossalai.cli.benchmark.utils import find_all_configs, profile_model, get_batch_data
+from .models import MLP
+
+
+def run_benchmark(args: Config) -> None:
+    """
+    Run benchmarking with torch.multiprocessing.
+    """
+
+    # sanity checks
+    if args.gpus is None:
+        click.echo("Error: --gpus is not given")
+        exit()
+
+    click.echo("=== Benchmarking Parameters ===")
+    for k, v in args.items():
+        click.echo(f'{k}: {v}')
+    click.echo('')
+
+    config_list = find_all_configs(args.gpus)
+
+    avail_ports = [free_port() for _ in range(len(config_list))]
+    run_func = partial(run_dist_profiling,
+                       world_size=args.gpus,
+                       port_list=avail_ports,
+                       config_list=config_list,
+                       hyperparams=args)
+    mp.spawn(run_func, nprocs=args.gpus)
+
+
+def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_list: List[Dict],
+                       hyperparams: Config) -> None:
+    """
+    A function executed for profiling; this function should be spawned by torch.multiprocessing.
+
+    Args:
+        rank (int): rank of the process
+        world_size (int): the number of processes
+        port_list (List[int]): a list of free ports for initializing distributed networks
+        config_list (List[Dict]): a list of configurations
+        hyperparams (Config): the hyperparameters given by the user
+
+    """
+
+    # disable logging for clean output
+    disable_existing_loggers()
+    logger = get_dist_logger()
+    logger.set_level('WARNING')
+
+    for config, port in zip(config_list, port_list):
+        colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+        timer = MultiTimer()
+
+        if hyperparams.model == 'mlp':
+            model = MLP(dim=hyperparams.dimension, layers=hyperparams.layers)
+        else:
+            if gpc.get_global_rank() == 0:
+                click.echo("Error: Invalid argument for --model")
+            exit()
+
+        data_func = partial(get_batch_data,
+                            dim=hyperparams.dimension,
+                            batch_size=hyperparams.batch_size,
+                            seq_length=hyperparams.seq_len,
+                            mode=config.parallel.tensor.mode)
+
+        fwd_time, bwd_time, max_allocated, max_cached = profile_model(model=model,
+                                                                      warmup_steps=hyperparams.warmup_steps,
+                                                                      profile_steps=hyperparams.profile_steps,
+                                                                      data_func=data_func,
+                                                                      timer=timer)
+
+        gpc.destroy()
+        reset_seeds()
+
+        if gpc.get_global_rank() == 0:
+            config_str = ', '.join([f'{k}: {v}' for k, v in config.parallel.tensor.items()])
+            click.echo(f"=== {config_str} ===")
+            click.echo(f"Average forward time: {fwd_time}")
+            click.echo(f"Average backward time: {bwd_time}")
+            click.echo(f"Max allocated GPU memory: {max_allocated}")
+            click.echo(f"Max cached GPU memory: {max_cached}\n")
diff --git a/colossalai/cli/benchmark/models.py b/colossalai/cli/benchmark/models.py
new file mode 100644
index 000000000..38ea54188
--- /dev/null
+++ b/colossalai/cli/benchmark/models.py
@@ -0,0 +1,17 @@
+import torch
+import colossalai.nn as col_nn
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, dim: int, layers: int):
+        super().__init__()
+        self.layers = torch.nn.ModuleList()
+
+        for _ in range(layers):
+            self.layers.append(col_nn.Linear(dim, dim))
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x
diff --git a/colossalai/cli/benchmark/run.py b/colossalai/cli/benchmark/run.py
deleted file mode 100644
index 825e9d8d1..000000000
--- a/colossalai/cli/benchmark/run.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import torch
-import inspect
-import os
-import subprocess
-import sys
-
-from colossalai.initialize import launch_from_torch
-from colossalai.logging import disable_existing_loggers
-from colossalai.utils import print_rank_0
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_dist_logger
-from colossalai.utils import free_port
-from colossalai.cli.benchmark import build_args_parser, build_configs, \
-    build_input_tensor, profile_1d, profile_2d, profile_2p5d, profile_3d, \
-    BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM, ITER_TIMES
-
-
-def launch(args=None):
-    train_script = inspect.getfile(inspect.currentframe())
-    assert args is not None, "args should not be None"
-    env = os.environ.copy()
-    if args.num_gpus == -1 or args.num_gpus > torch.cuda.device_count():
-        nproc_per_node = torch.cuda.device_count()
-    else:
-        nproc_per_node = args.num_gpus
-
-    train_args = [f"--num_gpus={nproc_per_node}"]
-    if args.bs != BATCH_SIZE:
-        train_args.append(f"--bs={args.bs}")
-    if args.hid_dim != HIDDEN_DIM:
-        train_args.append(f"--hid_dim={args.hid_dim}")
-    if args.num_steps != ITER_TIMES:
-        train_args.append(f"--num_steps={args.num_steps}")
-    if args.seq_len != SEQ_LENGTH:
-        train_args.append(f"--seq_len={args.seq_len}")
-
-    master_port = free_port()
-    if torch.__version__ <= "1.09":
-        cmd = [sys.executable, "-u", "-m",
-               "torch.distributed.launch",
-               f"--nproc_per_node={nproc_per_node}",
-               f"--master_port={master_port}"] + [train_script] + train_args
-    else:
-        cmd = ["torchrun",
-               f"--nproc_per_node={nproc_per_node}",
-               f"--master_port={master_port}"] + [train_script] + train_args
-
-    result = subprocess.Popen(cmd, env=env)
-    result.wait()
-    if result.returncode > 0:
-        sys.exit(result.returncode)
-
-def main():
-    parser = build_args_parser()
-    args = parser.parse_args()
-    disable_existing_loggers()
-    logger = get_dist_logger()
-    launch_from_torch(config={}, verbose=False)
-    input_tensor = build_input_tensor(args)
-    config_dict = build_configs(args)
-    if len(config_dict) == 0:
-        print_rank_0(f"WARNING: We need at least two devices to profile TP strategies performance.")
-        gpc.destroy()
-        return
-    for parallel_mode, config in config_dict.items():
-        if parallel_mode == "1d":
-            result_1d = profile_1d(input_tensor, config, args)
-            print_rank_0(f"INFO: Totoal time cost in 1D TP is {result_1d}.")
-        if parallel_mode == "2d":
-            result_2d = profile_2d(input_tensor, config, args)
-            print_rank_0(f"INFO: Totoal time cost in 2D TP is {result_2d}.")
-        if parallel_mode == "2p5d":
-            result_2p5d = profile_2p5d(input_tensor, config, args)
-            print_rank_0(f"INFO: Totoal time cost in 2P5D TP is {result_2p5d}.")
-        if parallel_mode == "3d":
-            result_3d = profile_3d(input_tensor, config, args)
-            print_rank_0(f"INFO: Totoal time cost in 3D TP is {result_3d}.")
-    if "2d" not in config_dict:
-        print_rank_0(f"WARNING: To use 2D tensor parallel, you have to provide at least 4 computing devices.")
-    if "2p5d" not in config_dict:
-        print_rank_0(f"WARNING: To use 2P5D tensor parallel, you have to provide at least 8 computing devices.")
-        print_rank_0(f"WARNING: To use 3D tensor parallel, you have to provide at least 8 computing devices.")
-    gpc.destroy()
-
-if __name__=="__main__":
-    main()
diff --git a/colossalai/cli/benchmark/simple_model.py b/colossalai/cli/benchmark/simple_model.py
deleted file mode 100644
index 6ae6973bf..000000000
--- a/colossalai/cli/benchmark/simple_model.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import torch
-import colossalai
-import colossalai.nn as col_nn
-
-class MLP(torch.nn.Module):
-    def __init__(self, dim: int = 256):
-        super().__init__()
-        intermediate_dim = dim * 4
-        self.dense_1 = col_nn.Linear(dim, intermediate_dim)
-        self.activation = torch.nn.GELU()
-        self.dense_2 = col_nn.Linear(intermediate_dim, dim)
-        self.dropout = col_nn.Dropout(0.1)
-
-    def forward(self, x):
-        x = self.dense_1(x)
-        x = self.activation(x)
-        x = self.dense_2(x)
-        x = self.dropout(x)
-        return x
diff --git a/colossalai/cli/benchmark/utils.py b/colossalai/cli/benchmark/utils.py
index de96a81ac..0317f2a41 100644
--- a/colossalai/cli/benchmark/utils.py
+++ b/colossalai/cli/benchmark/utils.py
@@ -1,146 +1,153 @@
+import math
+import time
 import torch
-from .simple_model import MLP
-from colossalai.utils import Timer, synchronize
+
+from colossalai.utils import MultiTimer
 from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode
-from argparse import ArgumentParser
+from colossalai.context import ParallelMode, Config
+from typing import List, Dict, Tuple, Callable
-BATCH_SIZE = 8
-SEQ_LENGTH = 120
-HIDDEN_DIM = 1024
-ITER_TIMES = 2000
-def build_args_parser() -> ArgumentParser:
-    """Helper function parsing the command line options."""
+def get_time_stamp() -> float:
+    """
+    Return the time stamp for profiling.
-    parser = ArgumentParser(description="colossal benchmark")
+    Returns:
+        time_stamp (float): the time given by time.time()
+    """
-    parser.add_argument("--num_gpus",
-                        type=int,
-                        default=-1,
-                        help="Total number of devices to use.")
-    parser.add_argument("--bs",
-                        type=int,
-                        default=BATCH_SIZE,
-                        help="Batch size of the input tensor.")
-    parser.add_argument("--seq_len",
-                        type=int,
-                        default=SEQ_LENGTH,
-                        help="Sequence length of the input tensor.")
-    parser.add_argument("--hid_dim",
-                        type=int,
-                        default=HIDDEN_DIM,
-                        help="Hidden dimension of the input tensor.")
-    parser.add_argument("--num_steps",
-                        type=int,
-                        default=ITER_TIMES,
-                        help="The number of iteration times.")
-    return parser
+    torch.cuda.synchronize()
+    time_stamp = time.time()
+    return time_stamp
-def build_input_tensor(args):
-    return torch.rand(args.bs, args.seq_len, args.hid_dim)
-def build_configs_helper(device_cnt: int):
-    config_dict = {}
+def get_memory_states() -> Tuple[float, float]:
+    """
+    Return the memory statistics.
-    if device_cnt < 2:
-        return config_dict
+    Returns:
+        max_allocated (float): the peak allocated CUDA memory in GB
+        max_cached (float): the peak cached (reserved) CUDA memory in GB
+    """
+
+    max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
+    max_cached = torch.cuda.max_memory_reserved() / (1024**3)
+    torch.cuda.reset_peak_memory_stats()
+    torch.cuda.empty_cache()
+    return max_allocated, max_cached
+
+
+def find_all_configs(device_cnt: int) -> List[Dict]:
+    """
+    Find all possible configurations for tensor parallelism.
+
+    Args:
+        device_cnt (int): the number of devices
+
+    Returns:
+        config_list (List[Dict]): a list of configurations
+    """
+
+    def _is_square(num):
+        return math.floor(math.sqrt(num))**2 == num
+
+    def _is_cube(num):
+        return math.floor(num**(1. / 3.))**3 == num
+
+    config_list = []
+
+    # add non-parallel config
+    config = dict(parallel=dict(tensor=dict(size=device_cnt, mode=None)))
+    config_list.append(config)
+
+    # add 1D config
+    config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='1d')))
+    config_list.append(config)
+
+    # add 2D config only if device_cnt is a square
+    if _is_square(device_cnt):
+        config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2d')))
+        config_list.append(config)
+
+    # check for 2.5D
+    # iterate over depth
+    for depth in range(1, device_cnt):
+        if device_cnt % depth == 0 and _is_square(device_cnt // depth):
+            config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2.5d', depth=depth)))
+            config_list.append(config)
+
+    # check for 3D if device_cnt is a cube
+    if _is_cube(device_cnt):
+        config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='3d')))
+        config_list.append(config)
+
+    config_list = [Config(cfg) for cfg in config_list]
+    return config_list
+
+
+def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable,
+                  timer: MultiTimer) -> Tuple[float, float, float, float]:
+    """
+    Profile the forward and backward passes of a model.
+
+    Args:
+        model (torch.nn.Module): a PyTorch model
+        warmup_steps (int): the number of steps for warmup
+        profile_steps (int): the number of steps for profiling
+        data_func (Callable): a function to generate random data
+        timer (colossalai.utils.MultiTimer): a timer instance for time recording
-    if device_cnt < 4:
-        config_dict["1d"] = dict(parallel=dict(tensor=dict(size=2, mode='1d')))
-    elif device_cnt < 8:
-        config_dict["1d"] = dict(parallel=dict(tensor=dict(size=4, mode='1d')))
-        config_dict["2d"] = dict(parallel=dict(tensor=dict(size=4, mode='2d')))
-    else:
-        config_dict["1d"] = dict(parallel=dict(tensor=dict(size=8, mode='1d')))
-        config_dict["2d"] = dict(parallel=dict(data=2, tensor=dict(size=4, mode='2d')))
-        config_dict["2p5d"] = dict(parallel=dict(tensor=dict(size=8, mode='2.5d', depth=2)))
-        config_dict["3d"] = dict(parallel=dict(tensor=dict(size=8, mode='3d')))
-
-    return config_dict
+    Returns:
+        fwd_time (float): the average time taken by the forward pass, in seconds
+        bwd_time (float): the average time taken by the backward pass, in seconds
+        max_allocated (float): the maximum GPU memory allocated in GB
+        max_cached (float): the maximum GPU memory cached in GB
+    """
-def build_configs(args):
-    total_device_cnt = torch.cuda.device_count()
-    if args.num_gpus == -1:
-        config_dict = build_configs_helper(total_device_cnt)
-    else:
-        valid_device_cnt = min(args.num_gpus, total_device_cnt)
-        config_dict = build_configs_helper(valid_device_cnt)
-    return config_dict
+    def _run_step(data):
+        timer.start('forward')
+        out = model(data)
+        timer.stop('forward', keep_in_history=True)
+        timer.start('backward')
+        out.mean().backward()
+        timer.stop('backward', keep_in_history=True)
-def profile_1d(input_tensor, config, args):
-    gpc.load_config(config)
-    gpc.init_parallel_groups()
-    assert gpc.is_initialized(ParallelMode.PARALLEL_1D)
-    model = MLP(args.hid_dim).cuda()
-    input_tensor = input_tensor.cuda()
-    torch.distributed.broadcast(input_tensor, src=0)
-    timer = Timer()
-    iter_times = args.num_steps
-    timer.start()
-    for i in range(iter_times):
-        input_tensor = model(input_tensor)
-    synchronize()
-    result_1d = timer.stop()
-    return result_1d
+    data_list = [data_func() for _ in range(warmup_steps)]
+    for data in data_list:
+        _run_step(data)
+    timer.reset('forward')
+    timer.reset('backward')
-def profile_2d(input_tensor, config, args):
-    gpc.load_config(config)
-    gpc.init_parallel_groups()
-    assert gpc.is_initialized(ParallelMode.PARALLEL_2D_COL)
-    assert gpc.is_initialized(ParallelMode.PARALLEL_2D_ROW)
-    model = MLP(args.hid_dim).cuda()
-    input_tensor = input_tensor.cuda()
-    torch.distributed.broadcast(input_tensor, src=0)
-    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)]
-    input_tensor = torch.chunk(input_tensor, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)]
-    timer = Timer()
-    iter_times = args.num_steps
-    timer.start()
-    for i in range(iter_times):
-        input_tensor = model(input_tensor)
-    synchronize()
-    result_2d = timer.stop()
-    return result_2d
+    for _ in range(profile_steps):
+        data = data_func()
+        _run_step(data)
-def profile_2p5d(input_tensor, config, args):
-    gpc.load_config(config)
-    gpc.init_parallel_groups()
-    assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL)
-    assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW)
-    assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP)
-    model = MLP(args.hid_dim).cuda()
-    input_tensor = input_tensor.cuda()
-    torch.distributed.broadcast(input_tensor, src=0)
-    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)]
-    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)]
-    input_tensor = torch.chunk(input_tensor, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)]
-    timer = Timer()
-    iter_times = args.num_steps
-    timer.start()
-    for i in range(iter_times):
-        input_tensor = model(input_tensor)
-    synchronize()
-    result_2p5d = timer.stop()
-    return result_2p5d
+    max_allocated, max_cached = get_memory_states()
+    fwd_time = timer.get_timer('forward').get_history_mean()
+    bwd_time = timer.get_timer('backward').get_history_mean()
+    return fwd_time, bwd_time, max_allocated, max_cached
-def profile_3d(input_tensor, config, args):
-    gpc.load_config(config)
-    gpc.init_parallel_groups()
-    assert gpc.is_initialized(ParallelMode.PARALLEL_3D_WEIGHT)
-    assert gpc.is_initialized(ParallelMode.PARALLEL_3D_INPUT)
-    assert gpc.is_initialized(ParallelMode.PARALLEL_3D_OUTPUT)
-    model = MLP(args.hid_dim).cuda()
-    input_tensor = input_tensor.cuda()
-    torch.distributed.broadcast(input_tensor, src=0)
-    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)]
-    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)]
-    input_tensor = torch.chunk(input_tensor, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)]
-    timer = Timer()
-    iter_times = args.num_steps
-    timer.start()
-    for i in range(iter_times):
-        input_tensor = model(input_tensor)
-    synchronize()
-    result_3d = timer.stop()
-    return result_3d
+
+def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: ParallelMode) -> torch.Tensor:
+    """
+    Return random data of shape (batch_size, seq_length, dim) for profiling.
+
+    Args:
+        dim (int): hidden size
+        batch_size (int): the number of data samples
+        seq_length (int): the number of tokens
+        mode (ParallelMode): Colossal-AI ParallelMode enum
+
+    Returns:
+        data (torch.Tensor): random data
+    """
+
+    if mode in ['2d', '2.5d']:
+        batch_size = batch_size // 2
+        dim = dim // 2
+    elif mode == '3d':
+        batch_size = batch_size // 4
+        dim = dim // 2
+
+    data = torch.rand(batch_size, seq_length, dim).cuda()
+    return data
diff --git a/colossalai/cli/cli.py b/colossalai/cli/cli.py
index 439ffa9a9..3e5b9ae63 100644
--- a/colossalai/cli/cli.py
+++ b/colossalai/cli/cli.py
@@ -1,8 +1,7 @@
 import click
 from .launcher import run
 from .check import check
-from colossalai.cli.benchmark.utils import BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM, ITER_TIMES
-from colossalai.cli.benchmark.run import launch as col_benchmark
+from .benchmark import benchmark


 class Arguments():
@@ -17,18 +16,6 @@ def cli():
     pass


-@click.command()
-@click.option("--num_gpus", type=int, default=-1)
-@click.option("--bs", type=int, default=BATCH_SIZE)
-@click.option("--seq_len", type=int, default=SEQ_LENGTH)
-@click.option("--hid_dim", type=int, default=HIDDEN_DIM)
-@click.option("--num_steps", type=int, default=ITER_TIMES)
-def benchmark(num_gpus, bs, seq_len, hid_dim, num_steps):
-    args_dict = locals()
-    args = Arguments(args_dict)
-    col_benchmark(args)
-
-
 cli.add_command(run)
 cli.add_command(check)
 cli.add_command(benchmark)
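For reference, the new `benchmark` command only packs its click options into a `Config` and passes it to `run_benchmark`, so an equivalent profiling run can be launched from Python directly. A minimal sketch, assuming Colossal-AI with this patch is installed and CUDA devices are available; the option values below are illustrative only:

# Minimal sketch: drive the new benchmark entry point without going through
# the `colossalai benchmark` CLI. The field names match the click options
# defined in colossalai/cli/benchmark/__init__.py above.
from colossalai.context import Config
from colossalai.cli.benchmark.benchmark import run_benchmark

args = Config(
    dict(
        gpus=4,              # required; run_benchmark exits early if this is None
        batch_size=8,
        seq_len=512,
        dimension=1024,
        warmup_steps=10,
        profile_steps=50,
        layers=2,
        model='mlp',         # the only model currently supported
    ))
run_benchmark(args)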