[cli] refactored micro-benchmarking cli and added more metrics (#858)

pull/867/head
Frank Lee 2022-04-25 11:48:07 +08:00 committed by GitHub
parent ee222dfbf3
commit a82da26f7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 278 additions and 251 deletions

View File

@ -1,2 +1,28 @@
from random import choices
import click
from .utils import *
from .run import *
from .benchmark import run_benchmark
from colossalai.context import Config
__all__ = ['benchmark']
@click.command()
@click.option("-g", "--gpus", type=int, default=None, help="Total number of devices to use.")
@click.option("-b", "--batch_size", type=int, default=8, help="Batch size of the input tensor.")
@click.option("-s", "--seq_len", type=int, default=512, help="Sequence length of the input tensor.")
@click.option("-d", "--dimension", type=int, default=1024, help="Hidden dimension of the input tensor.")
@click.option("-w", "--warmup_steps", type=int, default=10, help="The number of warmup steps.")
@click.option("-p", "--profile_steps", type=int, default=50, help="The number of profiling steps.")
@click.option("-l", "--layers", type=int, default=2)
@click.option("-m",
"--model",
type=click.Choice(['mlp'], case_sensitive=False),
default='mlp',
help="Select the model to benchmark, currently only supports MLP")
def benchmark(gpus: int, batch_size: int, seq_len: int, dimension: int, warmup_steps: int, profile_steps: int,
layers: int, model: str):
args_dict = locals()
args = Config(args_dict)
run_benchmark(args)

View File

@ -0,0 +1,94 @@
import colossalai
import click
import torch.multiprocessing as mp
from functools import partial
from typing import List, Dict
from colossalai.context import Config
from colossalai.context.random import reset_seeds
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import free_port, MultiTimer
from colossalai.cli.benchmark.utils import find_all_configs, profile_model, get_batch_data
from .models import MLP
def run_benchmark(args: Config) -> None:
"""
Run benchmarking with torch.multiprocessing.
"""
# sanity checks
if args.gpus is None:
click.echo("Error: --num_gpus is not given")
exit()
click.echo("=== Benchmarking Parameters ===")
for k, v in args.items():
click.echo(f'{k}: {v}')
click.echo('')
config_list = find_all_configs(args.gpus)
avail_ports = [free_port() for _ in range(len(config_list))]
run_func = partial(run_dist_profiling,
world_size=args.gpus,
port_list=avail_ports,
config_list=config_list,
hyperparams=args)
mp.spawn(run_func, nprocs=args.gpus)
def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_list: List[Dict],
hyperparams: Config) -> None:
"""
A function executed for profiling, this function should be spawn by torch.multiprocessing.
Args:
rank (int): rank of the process
world_size (int): the number of processes
port_list (List[int]): a list of free ports for initializing distributed networks
config_list (List[Dict]): a list of configuration
hyperparams (Config): the hyperparameters given by the user
"""
# disable logging for clean output
disable_existing_loggers()
logger = get_dist_logger()
logger.set_level('WARNING')
for config, port in zip(config_list, port_list):
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
timer = MultiTimer()
if hyperparams.model == 'mlp':
model = MLP(dim=hyperparams.dimension, layers=hyperparams.layers)
else:
if gpc.get_global_rank() == 0:
click.echo("Error: Invalid argument for --model")
exit()
data_func = partial(get_batch_data,
dim=hyperparams.dimension,
batch_size=hyperparams.batch_size,
seq_length=hyperparams.seq_len,
mode=config.parallel.tensor.mode)
fwd_time, bwd_time, max_allocated, max_cached = profile_model(model=model,
warmup_steps=hyperparams.warmup_steps,
profile_steps=hyperparams.profile_steps,
data_func=data_func,
timer=timer)
gpc.destroy()
reset_seeds()
if gpc.get_global_rank() == 0:
config_str = ', '.join([f'{k}: {v}' for k, v in config.parallel.tensor.items()])
click.echo(f"=== {config_str} ===")
click.echo(f"Average forward time: {fwd_time}")
click.echo(f"Average backward time: {bwd_time}")
click.echo(f"Max allocated GPU memory: {max_allocated}")
click.echo(f"Max cached GPU memory: {max_cached}\n")

View File

@ -0,0 +1,17 @@
import torch
import colossalai.nn as col_nn
class MLP(torch.nn.Module):
def __init__(self, dim: int, layers: int):
super().__init__()
self.layers = torch.nn.ModuleList()
for _ in range(layers):
self.layers.append(col_nn.Linear(dim, dim))
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x

View File

@ -1,86 +0,0 @@
import torch
import inspect
import os
import subprocess
import sys
from colossalai.initialize import launch_from_torch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import print_rank_0
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port
from colossalai.cli.benchmark import build_args_parser, build_configs, \
build_input_tensor, profile_1d, profile_2d, profile_2p5d, profile_3d, \
BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM, ITER_TIMES
def launch(args=None):
train_script = inspect.getfile(inspect.currentframe())
assert args is not None, "args should not be None"
env = os.environ.copy()
if args.num_gpus == -1 or args.num_gpus > torch.cuda.device_count():
nproc_per_node = torch.cuda.device_count()
else:
nproc_per_node = args.num_gpus
train_args = [f"--num_gpus={nproc_per_node}"]
if args.bs != BATCH_SIZE:
train_args.append(f"--bs={args.bs}")
if args.hid_dim != HIDDEN_DIM:
train_args.append(f"--hid_dim={args.hid_dim}")
if args.num_steps != ITER_TIMES:
train_args.append(f"--num_steps={args.num_steps}")
if args.seq_len != SEQ_LENGTH:
train_args.append(f"--seq_len={args.seq_len}")
master_port = free_port()
if torch.__version__ <= "1.09":
cmd = [sys.executable, "-u", "-m",
"torch.distributed.launch",
f"--nproc_per_node={nproc_per_node}",
f"--master_port={master_port}"] + [train_script] + train_args
else:
cmd = ["torchrun",
f"--nproc_per_node={nproc_per_node}",
f"--master_port={master_port}"] + [train_script] + train_args
result = subprocess.Popen(cmd, env=env)
result.wait()
if result.returncode > 0:
sys.exit(result.returncode)
def main():
parser = build_args_parser()
args = parser.parse_args()
disable_existing_loggers()
logger = get_dist_logger()
launch_from_torch(config={}, verbose=False)
input_tensor = build_input_tensor(args)
config_dict = build_configs(args)
if len(config_dict) == 0:
print_rank_0(f"WARNING: We need at least two devices to profile TP strategies performance.")
gpc.destroy()
return
for parallel_mode, config in config_dict.items():
if parallel_mode == "1d":
result_1d = profile_1d(input_tensor, config, args)
print_rank_0(f"INFO: Totoal time cost in 1D TP is {result_1d}.")
if parallel_mode == "2d":
result_2d = profile_2d(input_tensor, config, args)
print_rank_0(f"INFO: Totoal time cost in 2D TP is {result_2d}.")
if parallel_mode == "2p5d":
result_2p5d = profile_2p5d(input_tensor, config, args)
print_rank_0(f"INFO: Totoal time cost in 2P5D TP is {result_2p5d}.")
if parallel_mode == "3d":
result_3d = profile_3d(input_tensor, config, args)
print_rank_0(f"INFO: Totoal time cost in 3D TP is {result_3d}.")
if "2d" not in config_dict:
print_rank_0(f"WARNING: To use 2D tensor parallel, you have to provide at least 4 computing devices.")
if "2p5d" not in config_dict:
print_rank_0(f"WARNING: To use 2P5D tensor parallel, you have to provide at least 8 computing devices.")
print_rank_0(f"WARNING: To use 3D tensor parallel, you have to provide at least 8 computing devices.")
gpc.destroy()
if __name__=="__main__":
main()

View File

@ -1,19 +0,0 @@
import torch
import colossalai
import colossalai.nn as col_nn
class MLP(torch.nn.Module):
def __init__(self, dim: int = 256):
super().__init__()
intermediate_dim = dim * 4
self.dense_1 = col_nn.Linear(dim, intermediate_dim)
self.activation = torch.nn.GELU()
self.dense_2 = col_nn.Linear(intermediate_dim, dim)
self.dropout = col_nn.Dropout(0.1)
def forward(self, x):
x = self.dense_1(x)
x = self.activation(x)
x = self.dense_2(x)
x = self.dropout(x)
return x

View File

@ -1,146 +1,154 @@
import math
import time
from grpc import Call
import torch
from .simple_model import MLP
from colossalai.utils import Timer, synchronize
from colossalai.utils import MultiTimer
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from argparse import ArgumentParser
from colossalai.context import ParallelMode, Config
from typing import List, Dict, Tuple, Callable
BATCH_SIZE = 8
SEQ_LENGTH = 120
HIDDEN_DIM = 1024
ITER_TIMES = 2000
def build_args_parser() -> ArgumentParser:
"""Helper function parsing the command line options."""
def get_time_stamp() -> int:
"""
Return the time stamp for profiling.
parser = ArgumentParser(description="colossal benchmark")
Returns:
time_stamp (int): the time given by time.time()
"""
parser.add_argument("--num_gpus",
type=int,
default=-1,
help="Total number of devices to use.")
parser.add_argument("--bs",
type=int,
default=BATCH_SIZE,
help="Batch size of the input tensor.")
parser.add_argument("--seq_len",
type=int,
default=SEQ_LENGTH,
help="Sequence length of the input tensor.")
parser.add_argument("--hid_dim",
type=int,
default=HIDDEN_DIM,
help="Hidden dimension of the input tensor.")
parser.add_argument("--num_steps",
type=int,
default=ITER_TIMES,
help="The number of iteration times.")
return parser
torch.cuda.synchronize()
time_stamp = time.time()
return time_stamp
def build_input_tensor(args):
return torch.rand(args.bs, args.seq_len, args.hid_dim)
def build_configs_helper(device_cnt: int):
config_dict = {}
def get_memory_states() -> Tuple[float]:
"""
Return the memory statistics.
if device_cnt < 2:
return config_dict
Returns:
max_allocated (float): the allocated CUDA memory
max_cached (float): the cached CUDA memory
"""
max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
max_cached = torch.cuda.max_memory_reserved() / (1024**3)
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
return max_allocated, max_cached
def find_all_configs(device_cnt: int) -> List[Dict]:
"""
Find all possible configurations for tensor parallelism
Args:
device_cnt (int): the number of devices
Returns:
config_list (List[Dict]): a list of configurations
"""
def _is_square(num):
return math.floor(math.sqrt(num))**2 == num
def _is_cube(num):
return math.floor(num**(1. / 3.))**3 == num
config_list = []
# add non-parallel config
config = dict(parallel=dict(tensor=dict(size=device_cnt, mode=None)))
config_list.append(config)
# add 1D config
config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='1d')))
config_list.append(config)
# add 1D config only if device_cnt is a square
if _is_square(device_cnt):
config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2d')))
config_list.append(config)
# check for 2.5D
# iterate over depth
for depth in range(1, device_cnt):
if device_cnt % depth == 0 and _is_square(device_cnt // depth):
config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2.5d', depth=depth)))
config_list.append(config)
# check for 3D if device_cnt is a cube
if _is_cube(device_cnt):
config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='3d')))
config_list.append(config)
config_list = [Config(cfg) for cfg in config_list]
return config_list
def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable,
timer: MultiTimer) -> Tuple[float]:
"""
Profile the forward and backward of a model
Args:
model (torch.nn.Module): a PyTorch model
warmup_steps (int): the number of steps for warmup
profile_steps (int): the number of steps for profiling
data_func (Callable): a function to generate random data
timer (colossalai.utils.Multitimer): a timer instance for time recording
if device_cnt < 4:
config_dict["1d"] = dict(parallel=dict(tensor=dict(size=2, mode='1d')))
elif device_cnt < 8:
config_dict["1d"] = dict(parallel=dict(tensor=dict(size=4, mode='1d')))
config_dict["2d"] = dict(parallel=dict(tensor=dict(size=4, mode='2d')))
else:
config_dict["1d"] = dict(parallel=dict(tensor=dict(size=8, mode='1d')))
config_dict["2d"] = dict(parallel=dict(data=2, tensor=dict(size=4, mode='2d')))
config_dict["2p5d"] = dict(parallel=dict(tensor=dict(size=8, mode='2.5d', depth=2)))
config_dict["3d"] = dict(parallel=dict(tensor=dict(size=8, mode='3d')))
return config_dict
Returns:
fwd_time (float): the average forward time taken by forward pass in second
bwd_time (float): the average backward time taken by forward pass in second
max_allocated (float): the maximum GPU memory allocated in GB
max_cached (float): the maximum GPU memory cached in GB
"""
def build_configs(args):
total_device_cnt = torch.cuda.device_count()
if args.num_gpus == -1:
config_dict = build_configs_helper(total_device_cnt)
else:
valid_device_cnt = min(args.num_gpus, total_device_cnt)
config_dict = build_configs_helper(valid_device_cnt)
return config_dict
def _run_step(data):
timer.start('forward')
out = model(data)
timer.stop('forward', keep_in_history=True)
timer.start('backward')
out.mean().backward()
timer.stop('backward', keep_in_history=True)
def profile_1d(input_tensor, config, args):
gpc.load_config(config)
gpc.init_parallel_groups()
assert gpc.is_initialized(ParallelMode.PARALLEL_1D)
model = MLP(args.hid_dim).cuda()
input_tensor = input_tensor.cuda()
torch.distributed.broadcast(input_tensor, src=0)
timer = Timer()
iter_times = args.num_steps
timer.start()
for i in range(iter_times):
input_tensor = model(input_tensor)
synchronize()
result_1d = timer.stop()
return result_1d
data_list = [data_func() for _ in range(warmup_steps)]
for data in data_list:
_run_step(data)
timer.reset('forward')
timer.reset('backward')
def profile_2d(input_tensor, config, args):
gpc.load_config(config)
gpc.init_parallel_groups()
assert gpc.is_initialized(ParallelMode.PARALLEL_2D_COL)
assert gpc.is_initialized(ParallelMode.PARALLEL_2D_ROW)
model = MLP(args.hid_dim).cuda()
input_tensor = input_tensor.cuda()
torch.distributed.broadcast(input_tensor, src=0)
input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)]
input_tensor = torch.chunk(input_tensor, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)]
timer = Timer()
iter_times = args.num_steps
timer.start()
for i in range(iter_times):
input_tensor = model(input_tensor)
synchronize()
result_2d = timer.stop()
return result_2d
for _ in range(profile_steps):
data = data_func()
_run_step(data)
def profile_2p5d(input_tensor, config, args):
gpc.load_config(config)
gpc.init_parallel_groups()
assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL)
assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW)
assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP)
model = MLP(args.hid_dim).cuda()
input_tensor = input_tensor.cuda()
torch.distributed.broadcast(input_tensor, src=0)
input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)]
input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)]
input_tensor = torch.chunk(input_tensor, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)]
timer = Timer()
iter_times = args.num_steps
timer.start()
for i in range(iter_times):
input_tensor = model(input_tensor)
synchronize()
result_2p5d = timer.stop()
return result_2p5d
max_allocated, max_cached = get_memory_states()
fwd_time = timer.get_timer('forward').get_history_mean()
bwd_time = timer.get_timer('backward').get_history_mean()
return fwd_time, bwd_time, max_allocated, max_cached
def profile_3d(input_tensor, config, args):
gpc.load_config(config)
gpc.init_parallel_groups()
assert gpc.is_initialized(ParallelMode.PARALLEL_3D_WEIGHT)
assert gpc.is_initialized(ParallelMode.PARALLEL_3D_INPUT)
assert gpc.is_initialized(ParallelMode.PARALLEL_3D_OUTPUT)
model = MLP(args.hid_dim).cuda()
input_tensor = input_tensor.cuda()
torch.distributed.broadcast(input_tensor, src=0)
input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)]
input_tensor = torch.chunk(input_tensor, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)]
input_tensor = torch.chunk(input_tensor, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)]
timer = Timer()
iter_times = args.num_steps
timer.start()
for i in range(iter_times):
input_tensor = model(input_tensor)
synchronize()
result_3d = timer.stop()
return result_3d
def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: ParallelMode) -> torch.Tensor:
"""
Return a random data of shape (batch_size, seq_length, dim) for profiling.
Args:
dim (int): hidden size
batch_size (int): the number of data samples
seq_length (int): the number of tokens
mode (ParallelMode): Colossal-AI ParallelMode enum
Returns:
data (torch.Tensor): random data
"""
if mode in ['2d', '2.5d']:
batch_size = batch_size // 2
dim = dim // 2
elif mode == '3d':
batch_size = batch_size // 4
dim = dim // 2
data = torch.rand(batch_size, seq_length, dim).cuda()
return data

View File

@ -1,8 +1,7 @@
import click
from .launcher import run
from .check import check
from colossalai.cli.benchmark.utils import BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM, ITER_TIMES
from colossalai.cli.benchmark.run import launch as col_benchmark
from .benchmark import benchmark
class Arguments():
@ -17,18 +16,6 @@ def cli():
pass
@click.command()
@click.option("--num_gpus", type=int, default=-1)
@click.option("--bs", type=int, default=BATCH_SIZE)
@click.option("--seq_len", type=int, default=SEQ_LENGTH)
@click.option("--hid_dim", type=int, default=HIDDEN_DIM)
@click.option("--num_steps", type=int, default=ITER_TIMES)
def benchmark(num_gpus, bs, seq_len, hid_dim, num_steps):
args_dict = locals()
args = Arguments(args_dict)
col_benchmark(args)
cli.add_command(run)
cli.add_command(check)
cli.add_command(benchmark)