From de2f581d43ba403808c6b5eb365f7c44a375fc70 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Tue, 19 Apr 2022 12:08:28 +0800
Subject: [PATCH] [cli] added micro benchmarking for tp (#789)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [CLI] add cli benchmark feature

* fix CodeFactor issues.

* refactor the module structure.
---
 colossalai/cli/benchmark/__init__.py     |   2 +
 colossalai/cli/benchmark/run.py          |  86 +++++++++++++
 colossalai/cli/benchmark/simple_model.py |  19 +++
 colossalai/cli/benchmark/utils.py        | 146 +++++++++++++++++++++++
 colossalai/cli/cli.py                    |  28 ++++-
 5 files changed, 279 insertions(+), 2 deletions(-)
 create mode 100644 colossalai/cli/benchmark/__init__.py
 create mode 100644 colossalai/cli/benchmark/run.py
 create mode 100644 colossalai/cli/benchmark/simple_model.py
 create mode 100644 colossalai/cli/benchmark/utils.py

diff --git a/colossalai/cli/benchmark/__init__.py b/colossalai/cli/benchmark/__init__.py
new file mode 100644
index 000000000..a8afb98e3
--- /dev/null
+++ b/colossalai/cli/benchmark/__init__.py
@@ -0,0 +1,2 @@
+from .utils import *
+from .run import *
diff --git a/colossalai/cli/benchmark/run.py b/colossalai/cli/benchmark/run.py
new file mode 100644
index 000000000..825e9d8d1
--- /dev/null
+++ b/colossalai/cli/benchmark/run.py
@@ -0,0 +1,86 @@
+import torch
+import inspect
+import os
+import subprocess
+import sys
+
+from colossalai.initialize import launch_from_torch
+from colossalai.logging import disable_existing_loggers
+from colossalai.utils import print_rank_0
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+from colossalai.utils import free_port
+from colossalai.cli.benchmark import build_args_parser, build_configs, \
+    build_input_tensor, profile_1d, profile_2d, profile_2p5d, profile_3d, \
+    BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM, ITER_TIMES
+
+
+def launch(args=None):
+    train_script = inspect.getfile(inspect.currentframe())
+    assert args is not None, "args should not be None"
+    env = os.environ.copy()
+    if args.num_gpus == -1 or args.num_gpus > torch.cuda.device_count():
+        nproc_per_node = torch.cuda.device_count()
+    else:
+        nproc_per_node = args.num_gpus
+
+    train_args = [f"--num_gpus={nproc_per_node}"]
+    if args.bs != BATCH_SIZE:
+        train_args.append(f"--bs={args.bs}")
+    if args.hid_dim != HIDDEN_DIM:
+        train_args.append(f"--hid_dim={args.hid_dim}")
+    if args.num_steps != ITER_TIMES:
+        train_args.append(f"--num_steps={args.num_steps}")
+    if args.seq_len != SEQ_LENGTH:
+        train_args.append(f"--seq_len={args.seq_len}")
+
+    master_port = free_port()
+    if tuple(int(v) for v in torch.__version__.split(".")[:2]) < (1, 10):  # torchrun is only available from torch 1.10
+        cmd = [sys.executable, "-u", "-m",
+               "torch.distributed.launch",
+               f"--nproc_per_node={nproc_per_node}",
+               f"--master_port={master_port}"] + [train_script] + train_args
+    else:
+        cmd = ["torchrun",
+               f"--nproc_per_node={nproc_per_node}",
+               f"--master_port={master_port}"] + [train_script] + train_args
+
+    result = subprocess.Popen(cmd, env=env)
+    result.wait()
+    if result.returncode > 0:
+        sys.exit(result.returncode)
+
+def main():
+    parser = build_args_parser()
+    args = parser.parse_args()
+    disable_existing_loggers()
+    logger = get_dist_logger()
+    launch_from_torch(config={}, verbose=False)
+    input_tensor = build_input_tensor(args)
+    config_dict = build_configs(args)
+    if len(config_dict) == 0:
+        print_rank_0("WARNING: At least two devices are required to profile the performance of TP strategies.")
+        gpc.destroy()
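+        # nothing to profile with fewer than two devices, so tear down the context and exit early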
+        return
+    for parallel_mode, config in config_dict.items():
+        if parallel_mode == "1d":
+            result_1d = profile_1d(input_tensor, config, args)
+            print_rank_0(f"INFO: Total time cost in 1D TP is {result_1d}.")
+        if parallel_mode == "2d":
+            result_2d = profile_2d(input_tensor, config, args)
+            print_rank_0(f"INFO: Total time cost in 2D TP is {result_2d}.")
+        if parallel_mode == "2p5d":
+            result_2p5d = profile_2p5d(input_tensor, config, args)
+            print_rank_0(f"INFO: Total time cost in 2.5D TP is {result_2p5d}.")
+        if parallel_mode == "3d":
+            result_3d = profile_3d(input_tensor, config, args)
+            print_rank_0(f"INFO: Total time cost in 3D TP is {result_3d}.")
+    if "2d" not in config_dict:
+        print_rank_0("WARNING: To use 2D tensor parallelism, at least 4 computing devices are required.")
+    if "2p5d" not in config_dict:
+        print_rank_0("WARNING: To use 2.5D tensor parallelism, at least 8 computing devices are required.")
+        print_rank_0("WARNING: To use 3D tensor parallelism, at least 8 computing devices are required.")
+    gpc.destroy()
+
+if __name__ == "__main__":
+    main()
diff --git a/colossalai/cli/benchmark/simple_model.py b/colossalai/cli/benchmark/simple_model.py
new file mode 100644
index 000000000..6ae6973bf
--- /dev/null
+++ b/colossalai/cli/benchmark/simple_model.py
@@ -0,0 +1,19 @@
+import torch
+import colossalai
+import colossalai.nn as col_nn
+
+class MLP(torch.nn.Module):
+    def __init__(self, dim: int = 256):
+        super().__init__()
+        intermediate_dim = dim * 4
+        self.dense_1 = col_nn.Linear(dim, intermediate_dim)
+        self.activation = torch.nn.GELU()
+        self.dense_2 = col_nn.Linear(intermediate_dim, dim)
+        self.dropout = col_nn.Dropout(0.1)
+
+    def forward(self, x):
+        x = self.dense_1(x)
+        x = self.activation(x)
+        x = self.dense_2(x)
+        x = self.dropout(x)
+        return x
diff --git a/colossalai/cli/benchmark/utils.py b/colossalai/cli/benchmark/utils.py
new file mode 100644
index 000000000..de96a81ac
--- /dev/null
+++ b/colossalai/cli/benchmark/utils.py
@@ -0,0 +1,146 @@
+import torch
+from .simple_model import MLP
+from colossalai.utils import Timer, synchronize
+from colossalai.core import global_context as gpc
+from colossalai.context.parallel_mode import ParallelMode
+from argparse import ArgumentParser
+
+BATCH_SIZE = 8
+SEQ_LENGTH = 120
+HIDDEN_DIM = 1024
+ITER_TIMES = 2000
+
+def build_args_parser() -> ArgumentParser:
+    """Helper function parsing the command line options."""
+
+    parser = ArgumentParser(description="colossal benchmark")
+
+    parser.add_argument("--num_gpus",
+                        type=int,
+                        default=-1,
+                        help="Total number of devices to use.")
+    parser.add_argument("--bs",
+                        type=int,
+                        default=BATCH_SIZE,
+                        help="Batch size of the input tensor.")
+    parser.add_argument("--seq_len",
+                        type=int,
+                        default=SEQ_LENGTH,
+                        help="Sequence length of the input tensor.")
+    parser.add_argument("--hid_dim",
+                        type=int,
+                        default=HIDDEN_DIM,
+                        help="Hidden dimension of the input tensor.")
+    parser.add_argument("--num_steps",
+                        type=int,
+                        default=ITER_TIMES,
+                        help="The number of benchmark iterations.")
+    return parser
+
+def build_input_tensor(args):
+    return torch.rand(args.bs, args.seq_len, args.hid_dim)
+
+def build_configs_helper(device_cnt: int):
+    config_dict = {}
+
+    if device_cnt < 2:
+        return config_dict
+
+    if device_cnt < 4:
+        config_dict["1d"] = dict(parallel=dict(tensor=dict(size=2, mode='1d')))
+    elif device_cnt < 8:
+        config_dict["1d"] = dict(parallel=dict(tensor=dict(size=4, mode='1d')))
+        config_dict["2d"] = dict(parallel=dict(tensor=dict(size=4, mode='2d')))
+    else:
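+        # with 8 or more devices every TP layout fits, including 2.5D (depth=2) and 3D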
+        config_dict["1d"] = dict(parallel=dict(tensor=dict(size=8, mode='1d')))
+        config_dict["2d"] = dict(parallel=dict(data=2, tensor=dict(size=4, mode='2d')))
+        config_dict["2p5d"] = dict(parallel=dict(tensor=dict(size=8, mode='2.5d', depth=2)))
+        config_dict["3d"] = dict(parallel=dict(tensor=dict(size=8, mode='3d')))
+
+    return config_dict
+
+def build_configs(args):
+    total_device_cnt = torch.cuda.device_count()
+    if args.num_gpus == -1:
+        config_dict = build_configs_helper(total_device_cnt)
+    else:
+        valid_device_cnt = min(args.num_gpus, total_device_cnt)
+        config_dict = build_configs_helper(valid_device_cnt)
+    return config_dict
+
+def profile_1d(input_tensor, config, args):
+    gpc.load_config(config)
+    gpc.init_parallel_groups()
+    assert gpc.is_initialized(ParallelMode.PARALLEL_1D)
+    model = MLP(args.hid_dim).cuda()
+    input_tensor = input_tensor.cuda()
+    torch.distributed.broadcast(input_tensor, src=0)
+    timer = Timer()
+    iter_times = args.num_steps
+    timer.start()
+    for i in range(iter_times):
+        input_tensor = model(input_tensor)
+    synchronize()
+    result_1d = timer.stop()
+    return result_1d
+
+def profile_2d(input_tensor, config, args):
+    gpc.load_config(config)
+    gpc.init_parallel_groups()
+    assert gpc.is_initialized(ParallelMode.PARALLEL_2D_COL)
+    assert gpc.is_initialized(ParallelMode.PARALLEL_2D_ROW)
+    model = MLP(args.hid_dim).cuda()
+    input_tensor = input_tensor.cuda()
+    torch.distributed.broadcast(input_tensor, src=0)
+    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)]
+    input_tensor = torch.chunk(input_tensor, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)]
+    timer = Timer()
+    iter_times = args.num_steps
+    timer.start()
+    for i in range(iter_times):
+        input_tensor = model(input_tensor)
+    synchronize()
+    result_2d = timer.stop()
+    return result_2d
+
+def profile_2p5d(input_tensor, config, args):
+    gpc.load_config(config)
+    gpc.init_parallel_groups()
+    assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL)
+    assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW)
+    assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP)
+    model = MLP(args.hid_dim).cuda()
+    input_tensor = input_tensor.cuda()
+    torch.distributed.broadcast(input_tensor, src=0)
+    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)]
+    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)]
+    input_tensor = torch.chunk(input_tensor, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)]
+    timer = Timer()
+    iter_times = args.num_steps
+    timer.start()
+    for i in range(iter_times):
+        input_tensor = model(input_tensor)
+    synchronize()
+    result_2p5d = timer.stop()
+    return result_2p5d
+
+def profile_3d(input_tensor, config, args):
+    gpc.load_config(config)
+    gpc.init_parallel_groups()
+    assert gpc.is_initialized(ParallelMode.PARALLEL_3D_WEIGHT)
+    assert gpc.is_initialized(ParallelMode.PARALLEL_3D_INPUT)
+    assert gpc.is_initialized(ParallelMode.PARALLEL_3D_OUTPUT)
+    model = MLP(args.hid_dim).cuda()
+    input_tensor = input_tensor.cuda()
+    torch.distributed.broadcast(input_tensor, src=0)
+    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)]
+    input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)]
+    input_tensor = torch.chunk(input_tensor, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)]
+    timer = Timer()
+    iter_times = args.num_steps
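+    # time num_steps chained forward passes; synchronize() waits for outstanding CUDA work before the timer stops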
+    timer.start()
+    for i in range(iter_times):
+        input_tensor = model(input_tensor)
+    synchronize()
+    result_3d = timer.stop()
+    return result_3d
diff --git a/colossalai/cli/cli.py b/colossalai/cli/cli.py
index ee04572fa..67dbf967c 100644
--- a/colossalai/cli/cli.py
+++ b/colossalai/cli/cli.py
@@ -1,15 +1,38 @@
 import click
 from colossalai.cli.launcher.run import main as col_launch
+from colossalai.cli.benchmark.utils import BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM, ITER_TIMES
+from colossalai.cli.benchmark.run import launch as col_benchmark
 
 class Arguments():
-    def __init__(self, dict):
-        for k, v in dict.items():
+    def __init__(self, arg_dict):
+        for k, v in arg_dict.items():
             self.__dict__[k] = v
 
 @click.group()
 def cli():
     pass
 
+@click.command()
+@click.option("--num_gpus",
+              type=int,
+              default=-1)
+@click.option("--bs",
+              type=int,
+              default=BATCH_SIZE)
+@click.option("--seq_len",
+              type=int,
+              default=SEQ_LENGTH)
+@click.option("--hid_dim",
+              type=int,
+              default=HIDDEN_DIM)
+@click.option("--num_steps",
+              type=int,
+              default=ITER_TIMES)
+def benchmark(num_gpus, bs, seq_len, hid_dim, num_steps):
+    args_dict = locals()
+    args = Arguments(args_dict)
+    col_benchmark(args)
+
 @click.command()
 @click.option("--hostfile",
               type=str,
@@ -49,6 +72,7 @@ def launch(hostfile, num_nodes, num_gpus, include, exclude, master_addr, master_
     col_launch(args)
 
 cli.add_command(launch)
+cli.add_command(benchmark)
 
 if __name__ == '__main__':
     cli()
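
Usage note (illustrative; not part of the patch): once the package's console entry point picks up the new click command, the benchmark can be started from the shell as "colossalai benchmark --num_gpus 4 --num_steps 200", or driven programmatically as sketched below. The sketch only reuses names introduced by this patch (Arguments, the launch alias col_benchmark); the GPU and step counts are arbitrary example values, and it assumes the modules are installed at the paths the patch creates.

    # Minimal programmatic driver, mirroring what the benchmark click command in cli.py does.
    from colossalai.cli.cli import Arguments
    from colossalai.cli.benchmark.run import launch as col_benchmark

    # Same fields as the click options; the values here are placeholders.
    args = Arguments(dict(num_gpus=4, bs=8, seq_len=120, hid_dim=1024, num_steps=200))
    col_benchmark(args)    # spawns torchrun / torch.distributed.launch on run.py, which profiles each available TP mode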