mirror of https://github.com/hpcaitech/ColossalAI
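"""Launcher for the ColossalAI benchmark CLI.

Sketch of a module docstring, inferred from the code below: the script spawns a
distributed run of itself (via ``torchrun`` or ``torch.distributed.launch``) and
profiles a synthetic input tensor under the 1D, 2D, 2.5D and 3D tensor-parallel
strategies, printing the total time cost of each.
"""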
import torch
import inspect
import os
import subprocess
import sys

from colossalai.initialize import launch_from_torch
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import print_rank_0, free_port
from colossalai.core import global_context as gpc
from colossalai.cli.benchmark import build_args_parser, build_configs, \
    build_input_tensor, profile_1d, profile_2d, profile_2p5d, profile_3d, \
    BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM, ITER_TIMES


def launch(args=None):
    """Spawn a distributed run of this script on the requested number of GPUs."""
    train_script = inspect.getfile(inspect.currentframe())
    assert args is not None, "args should not be None"
    env = os.environ.copy()

    # Fall back to all visible GPUs if --num_gpus is -1 or exceeds what is available.
    if args.num_gpus == -1 or args.num_gpus > torch.cuda.device_count():
        nproc_per_node = torch.cuda.device_count()
    else:
        nproc_per_node = args.num_gpus

    # Forward only the flags that differ from their defaults to the child processes.
    train_args = [f"--num_gpus={nproc_per_node}"]
    if args.bs != BATCH_SIZE:
        train_args.append(f"--bs={args.bs}")
    if args.hid_dim != HIDDEN_DIM:
        train_args.append(f"--hid_dim={args.hid_dim}")
    if args.num_steps != ITER_TIMES:
        train_args.append(f"--num_steps={args.num_steps}")
    if args.seq_len != SEQ_LENGTH:
        train_args.append(f"--seq_len={args.seq_len}")

    master_port = free_port()
    # ``torchrun`` only ships with PyTorch >= 1.10; older versions fall back to the
    # ``torch.distributed.launch`` module. Compare numeric (major, minor) components
    # rather than raw version strings.
    torch_version = tuple(int(part) for part in torch.__version__.split(".")[:2])
    if torch_version < (1, 10):
        cmd = [sys.executable, "-u", "-m",
               "torch.distributed.launch",
               f"--nproc_per_node={nproc_per_node}",
               f"--master_port={master_port}"] + [train_script] + train_args
    else:
        cmd = ["torchrun",
               f"--nproc_per_node={nproc_per_node}",
               f"--master_port={master_port}"] + [train_script] + train_args

    result = subprocess.Popen(cmd, env=env)
    result.wait()
    if result.returncode > 0:
        sys.exit(result.returncode)


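# For reference, the subprocess call assembled in ``launch`` above is roughly
# equivalent to running the following by hand (values are illustrative: 4 GPUs,
# a free port of 29500, and PyTorch >= 1.10 so that ``torchrun`` is available):
#
#   torchrun --nproc_per_node=4 --master_port=29500 <this_script.py> --num_gpus=4
#
# Non-default --bs/--hid_dim/--num_steps/--seq_len values would be appended as well.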
def main():
    parser = build_args_parser()
    args = parser.parse_args()
    disable_existing_loggers()
    logger = get_dist_logger()
    launch_from_torch(config={}, verbose=False)

    input_tensor = build_input_tensor(args)
    config_dict = build_configs(args)
    if len(config_dict) == 0:
        print_rank_0("WARNING: We need at least two devices to profile the performance of TP strategies.")
        gpc.destroy()
        return

    # Profile every tensor-parallel (TP) strategy that the current device count supports.
    for parallel_mode, config in config_dict.items():
        if parallel_mode == "1d":
            result_1d = profile_1d(input_tensor, config, args)
            print_rank_0(f"INFO: Total time cost in 1D TP is {result_1d}.")
        if parallel_mode == "2d":
            result_2d = profile_2d(input_tensor, config, args)
            print_rank_0(f"INFO: Total time cost in 2D TP is {result_2d}.")
        if parallel_mode == "2p5d":
            result_2p5d = profile_2p5d(input_tensor, config, args)
            print_rank_0(f"INFO: Total time cost in 2P5D TP is {result_2p5d}.")
        if parallel_mode == "3d":
            result_3d = profile_3d(input_tensor, config, args)
            print_rank_0(f"INFO: Total time cost in 3D TP is {result_3d}.")

    if "2d" not in config_dict:
        print_rank_0("WARNING: To use 2D tensor parallel, you have to provide at least 4 computing devices.")
    if "2p5d" not in config_dict:
        print_rank_0("WARNING: To use 2P5D tensor parallel, you have to provide at least 8 computing devices.")
    if "3d" not in config_dict:
        print_rank_0("WARNING: To use 3D tensor parallel, you have to provide at least 8 computing devices.")
    gpc.destroy()


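# Usage sketch, assuming this module backs the ``colossalai benchmark`` CLI
# subcommand (the flag names are the ones forwarded by ``launch`` above; the
# values are purely illustrative):
#
#   colossalai benchmark --num_gpus 4 --bs 8 --seq_len 512 --hid_dim 1024 --num_steps 5
#
# At least 2 devices are needed to profile any TP strategy, at least 4 for 2D,
# and at least 8 for 2P5D and 3D, as the warnings in ``main`` indicate.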
if __name__ == "__main__":
    main()