[cli] added distributed launcher command (#791)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [CLI] add CLI launcher feature

* remove testing messages used during development

* refactor the module structure.
YuliangLiu0306 2022-04-19 10:59:44 +08:00 committed by GitHub
parent 97cd9b03b3
commit cfadc9df8e
6 changed files with 489 additions and 0 deletions

colossalai/cli/cli.py
@@ -0,0 +1,54 @@
import click

from colossalai.cli.launcher.run import main as col_launch


class Arguments():
    """Wrap a plain dict so its entries can be accessed as attributes."""

    def __init__(self, arg_dict):
        for k, v in arg_dict.items():
            self.__dict__[k] = v


@click.group()
def cli():
    pass


@click.command()
@click.option("--hostfile", type=str, default="")
@click.option("--include", type=str, default="")
@click.option("--exclude", type=str, default="")
@click.option("--num_nodes", type=int, default=-1)
@click.option("--num_gpus", type=int, default=-1)
@click.option("--master_port", type=int, default=29500)
@click.option("--master_addr", type=str, default="127.0.0.1")
@click.option("--launcher", type=str, default="torch")
@click.option("--launcher_args", type=str, default="")
@click.argument("user_script", type=str)
@click.argument("user_args", nargs=-1)
def launch(hostfile, num_nodes, num_gpus, include, exclude, master_addr, master_port,
           launcher, launcher_args, user_script, user_args):
    # Collect the click-parsed options into an argparse-like namespace
    # so they can be handed to the launcher's main().
    args_dict = locals()
    args = Arguments(args_dict)
    args.user_args = list(args.user_args)
    col_launch(args)


cli.add_command(launch)

if __name__ == '__main__':
    cli()
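
A quick in-process sanity check of the command wiring (a minimal sketch, assuming the package is importable; invoking `--help` avoids actually spawning any training processes):

from click.testing import CliRunner
from colossalai.cli.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ["launch", "--help"])
assert result.exit_code == 0
print(result.output)  # lists the options registered above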

colossalai/cli/launcher/multinode_runner.py
@@ -0,0 +1,146 @@
import os
import shutil
import sys
from abc import ABC, abstractmethod
from shlex import quote

from colossalai.logging import get_dist_logger

logger = get_dist_logger()


class MultiNodeRunner(ABC):

    def __init__(self, args):
        self.args = args
        self.user_arguments = self.args.user_args
        self.user_script = args.user_script
        self.exports = {}

    @abstractmethod
    def backend_exists(self):
        """Return whether the corresponding backend exists."""

    @abstractmethod
    def get_cmd(self, environment, active_devices, args):
        """Return the command to execute on the node."""

    def add_export(self, key, var):
        self.exports[key.strip()] = var.strip()

    @property
    def name(self):
        """Return the name of the backend."""
        return self.__class__.__name__


class PDSHRunner(MultiNodeRunner):

    def __init__(self, args):
        super().__init__(args)

    def backend_exists(self):
        return shutil.which('pdsh')

    @property
    def name(self):
        return "pdsh"

    def parse_user_args(self):
        return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args))

    def get_cmd(self, environment, active_devices, args):
        environment['PDSH_RCMD_TYPE'] = 'ssh'

        active_workers = ",".join(active_devices.keys())
        logger.info(f"Running on the following workers: {active_workers}")

        # see https://linux.die.net/man/1/pdsh for the flags used here
        pdsh_cmd_args = ['pdsh', '-f', str(1024), '-w', active_workers]

        exports = ""
        for key, val in self.exports.items():
            exports += f"export {key}={quote(val)}; "

        colossal_launch = [
            exports,
            f"cd {os.path.abspath('.')};",
            sys.executable, "-u", "-m",
            "torch.distributed.launch",
            f"--nproc_per_node={args.num_gpus}",
            f"--master_addr={args.master_addr}",
            f"--master_port={args.master_port}"
        ]
        return pdsh_cmd_args + colossal_launch + [self.user_script] + self.user_arguments


class OpenMPIRunner(MultiNodeRunner):

    def __init__(self, args, device_pool):
        super().__init__(args)
        self.device_pool = device_pool

    def backend_exists(self):
        return shutil.which('ompi_info')

    @property
    def name(self):
        return "openmpi"

    def get_cmd(self, environment, active_devices, args):
        total_process_count = sum(self.device_pool.values())
        mpirun_cmd = ['mpirun', '-n', f'{total_process_count}', '-hostfile', f'{self.args.hostfile}']

        export_cmd = []
        for k, v in self.exports.items():
            export_cmd += ['-x', f'{k}={quote(v)}']

        python_exec = [sys.executable, "-u", "-m"]
        return mpirun_cmd + export_cmd + python_exec + [self.user_script] + self.user_arguments


class SLURMRunner(MultiNodeRunner):

    def __init__(self, args):
        super().__init__(args)

    def backend_exists(self):
        return shutil.which('slurm_info')

    @property
    def name(self):
        return "slurm"

    def get_cmd(self, environment, active_devices, args):
        # The partition must be passed through --launcher_args, e.g. "-p gpu".
        assert "-p" in args.launcher_args
        srun_args = args.launcher_args.strip().split()
        assert len(srun_args) >= 2, "we need more info about which partition to use."
        partition_name = srun_args[srun_args.index("-p") + 1]

        slurm_cmd = [
            'srun',
            "-p",
            f"{partition_name}",
            "--nodes",
            f"{args.num_nodes}",
            "--tasks",
            f"{args.num_gpus}"
        ]

        export_cmd = []
        for k, v in self.exports.items():
            export_cmd += ['-x', f'{k}={quote(v)}']

        python_exec = [sys.executable, "-u", "-m"]
        return slurm_cmd + export_cmd + python_exec + [self.user_script] + self.user_arguments
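
To inspect the command line a runner would emit without launching anything, the runner can be driven with a stand-in namespace. A minimal sketch; the SimpleNamespace fields mirror the CLI options above, all values are illustrative, and it assumes get_dist_logger works outside a distributed context:

import os
from types import SimpleNamespace

from colossalai.cli.launcher.multinode_runner import PDSHRunner

# Illustrative values only; the fields mirror the CLI options above.
args = SimpleNamespace(user_script="train.py",
                       user_args=["--lr", "0.1"],
                       num_gpus=8,
                       master_addr="10.0.0.1",
                       master_port=29500)
runner = PDSHRunner(args)
env = os.environ.copy()  # get_cmd only sets PDSH_RCMD_TYPE in this copy
cmd = runner.get_cmd(env, {"worker-0": [0, 1], "worker-1": [0, 1]}, args)
print(" ".join(cmd))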

colossalai/cli/launcher/run.py
@@ -0,0 +1,285 @@
import collections
import os
import subprocess
import sys
from argparse import ArgumentParser, REMAINDER
from copy import deepcopy

import torch

from colossalai.logging import get_dist_logger
from .multinode_runner import PDSHRunner, OpenMPIRunner, SLURMRunner
def build_args_parser() -> ArgumentParser:
"""Helper function parsing the command line options."""
parser = ArgumentParser(description="colossal distributed training launcher")
parser.add_argument("-H",
"--hostfile",
type=str,
default="",
help="Hostfile path that defines the "
"device pool available to the job (e.g., "
"worker-name:number of slots)")
parser.add_argument("-i",
"--include",
type=str,
default="",
help="Specify computing devices to use during execution."
"String format is NODE_SPEC@NODE_SPEC"
"where NODE_SPEC=<worker-name>:<list-of-slots>")
parser.add_argument("-e",
"--exclude",
type=str,
default="",
help="Specify computing devices to NOT use during execution."
"Mutually exclusive with --include. Formatting"
"is the same as --include.")
parser.add_argument("--num_nodes",
type=int,
default=-1,
help="Total number of worker nodes to use.")
parser.add_argument("--num_gpus",
type=int,
default=-1,
help="Number of GPUs to use on each node.")
parser.add_argument("--master_port",
default=29500,
type=int,
help="(optional) Port used by PyTorch distributed for "
"communication during distributed training.")
parser.add_argument("--master_addr",
default="127.0.0.1",
type=str,
help="(optional) IP address of node 0, will be "
"inferred via 'hostname -I' if not specified.")
parser.add_argument("--launcher",
default="torch",
type=str,
help="(optional) choose launcher backend for multi-node "
"training. Options currently include PDSH, OpenMPI, SLURM.")
parser.add_argument("--launcher_args",
default="",
type=str,
help="(optional) pass launcher specific arguments as a "
"single quoted argument.")
parser.add_argument("user_script",
type=str,
help="User script to launch, followed by any required "
"arguments.")
    parser.add_argument('user_args', nargs=REMAINDER)
return parser
def fetch_hostfile(hostfile_path):
logger = get_dist_logger()
if not os.path.isfile(hostfile_path):
logger.warning("Unable to find hostfile, will proceed with training "
"with local resources only")
return None
# e.g., worker-0:16
with open(hostfile_path, 'r') as fd:
device_pool = collections.OrderedDict()
for line in fd.readlines():
line = line.strip()
if line == '':
# skip empty lines
continue
try:
hostname, slot_count = line.split(":")
slot_count = int(slot_count)
except ValueError as err:
logger.error("Hostfile is not formatted correctly, unable to "
"proceed with training.")
raise err
device_pool[hostname] = slot_count
return device_pool
def _stable_remove_duplicates(data):
# Create a new list in the same order as original but with duplicates
# removed, should never be more than ~16 elements so simple is best
new_list = []
for x in data:
if x not in new_list:
new_list.append(x)
return new_list
def parse_device_filter(host_info, include_str="", exclude_str=""):
'''Parse an inclusion or exclusion string and filter a hostfile dictionary.
Examples:
include_str="worker-0@worker-1:0,2" will use all slots on worker-0 and
slots [0, 2] on worker-1.
exclude_str="worker-1:0" will use all available devices except
slot 0 on worker-1.
'''
logger = get_dist_logger()
# Constants that define our syntax
NODE_SEP = '@'
SLOT_LIST_START = ':'
SLOT_SEP = ','
# Ensure include/exclude are mutually exclusive
if (include_str != "") and (exclude_str != ""):
raise ValueError('include_str and exclude_str are mutually exclusive.')
# no-op
if (include_str == "") and (exclude_str == ""):
return host_info
# Either build from scratch or remove items
filtered_hosts = dict()
if include_str:
parse_str = include_str
if exclude_str != "":
filtered_hosts = deepcopy(host_info)
parse_str = exclude_str
# foreach node in the list
for node_config in parse_str.split(NODE_SEP):
# Node can either be alone or node:slot,slot,slot
if SLOT_LIST_START in node_config:
hostname, slots = node_config.split(SLOT_LIST_START)
slots = [int(x) for x in slots.split(SLOT_SEP)]
# sanity checks
if hostname not in host_info:
raise ValueError(f"Hostname '{hostname}' not found in hostfile")
for slot in slots:
if slot not in host_info[hostname]:
raise ValueError(f"No slot '{slot}' specified on host '{hostname}'")
# If include string, build the list from here
if include_str:
filtered_hosts[hostname] = slots
elif exclude_str:
for slot in slots:
logger.info(f'removing {slot} from {hostname}')
filtered_hosts[hostname].remove(slot)
# User just specified the whole node
else:
hostname = node_config
# sanity check hostname
if hostname not in host_info:
raise ValueError(f"Hostname '{hostname}' not found in hostfile")
if include_str:
filtered_hosts[hostname] = host_info[hostname]
elif exclude_str:
filtered_hosts[hostname] = []
# Post-processing to remove duplicates and empty nodes
del_keys = []
for hostname in filtered_hosts:
# Remove duplicates
filtered_hosts[hostname] = _stable_remove_duplicates(filtered_hosts[hostname])
# Remove empty hosts
if len(filtered_hosts[hostname]) == 0:
del_keys.append(hostname)
for name in del_keys:
del filtered_hosts[name]
# Lastly, go over filtered_hosts and convert to a OrderedDict() to ensure
# we map ranks to nodes correctly by maintaining host_info ordering.
ordered_hosts = collections.OrderedDict()
for host in host_info:
if host in filtered_hosts:
ordered_hosts[host] = filtered_hosts[host]
return ordered_hosts
def parse_inclusion_exclusion(device_pool, inclusion, exclusion):
active_devices = collections.OrderedDict()
for hostname, slots in device_pool.items():
active_devices[hostname] = list(range(slots))
return parse_device_filter(active_devices,
include_str=inclusion,
exclude_str=exclusion)
def main(args=None):
logger = get_dist_logger()
assert args is not None, "args should not be None."
device_pool = fetch_hostfile(args.hostfile)
active_devices = None
if device_pool:
active_devices = parse_inclusion_exclusion(device_pool,
args.include,
args.exclude)
if args.num_nodes > 0:
updated_active_devices = collections.OrderedDict()
for count, hostname in enumerate(active_devices.keys()):
if args.num_nodes == count:
break
updated_active_devices[hostname] = active_devices[hostname]
active_devices = updated_active_devices
if args.num_gpus > 0:
updated_active_devices = collections.OrderedDict()
for hostname in active_devices.keys():
updated_active_devices[hostname] = list(range(args.num_gpus))
active_devices = updated_active_devices
env = os.environ.copy()
if not active_devices:
if args.num_gpus == -1 or args.num_gpus > torch.cuda.device_count():
nproc_per_node = torch.cuda.device_count()
else:
nproc_per_node = args.num_gpus
if torch.__version__ <= "1.09":
cmd = [sys.executable, "-u", "-m",
"torch.distributed.launch",
f"--nproc_per_node={nproc_per_node}",
f"--master_addr={args.master_addr}",
f"--master_port={args.master_port}"] + [args.user_script] + args.user_args
else:
cmd = ["torchrun",
f"--nproc_per_node={nproc_per_node}",
f"--master_addr={args.master_addr}",
f"--master_port={args.master_port}"] + [args.user_script] + args.user_args
else:
if args.launcher == "torch":
runner = PDSHRunner(args)
elif args.launcher == "mpi":
runner = OpenMPIRunner(args, device_pool)
elif args.launcher == "slurm":
            runner = SLURMRunner(args)
else:
raise NotImplementedError(f"Unknown launcher {args.launcher}")
if not runner.backend_exists():
raise RuntimeError(f"launcher '{args.launcher}' not installed.")
curr_path = os.path.abspath('.')
if 'PYTHONPATH' in env:
env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH']
else:
env['PYTHONPATH'] = curr_path
cmd = runner.get_cmd(env, active_devices, args)
result = subprocess.Popen(cmd, env=env)
result.wait()
if result.returncode > 0:
sys.exit(result.returncode)
if __name__ == "__main__":
main()
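
The include/exclude semantics documented in parse_device_filter can be checked directly. A minimal sketch, assuming hostfile entries of the form `worker-0:4` have already been expanded into slot lists (as parse_inclusion_exclusion does):

import collections

from colossalai.cli.launcher.run import parse_device_filter

# Slot lists as produced for a hostfile with "worker-0:4" and "worker-1:4".
hosts = collections.OrderedDict([("worker-0", [0, 1, 2, 3]),
                                 ("worker-1", [0, 1, 2, 3])])

# Keep all of worker-0 and only slots 0 and 2 of worker-1.
print(parse_device_filter(hosts, include_str="worker-0@worker-1:0,2"))
# Use everything except slot 0 on worker-1.
print(parse_device_filter(hosts, exclude_str="worker-1:0"))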

setup.py
@@ -213,6 +213,10 @@ setup(
ext_modules=ext_modules,
cmdclass={'build_ext': BuildExtension} if ext_modules else {},
install_requires=fetch_requirements('requirements/requirements.txt'),
entry_points='''
[console_scripts]
colossal=colossalai.cli.cli:cli
''',
python_requires='>=3.7',
classifiers=[
'Programming Language :: Python :: 3',