mirror of https://github.com/hpcaitech/ColossalAI
[cli] added distributed launcher command (#791)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
  This reverts commit df7e6506d4.
* [CLI] add CLI launcher feature
* remove testing messages used during development
* refactor the module structure

pull/789/head^2
parent 97cd9b03b3
commit cfadc9df8e
colossalai/cli/cli.py
@@ -0,0 +1,54 @@
import click

from colossalai.cli.launcher.run import main as col_launch


class Arguments():

    def __init__(self, dict):
        for k, v in dict.items():
            self.__dict__[k] = v


@click.group()
def cli():
    pass


@click.command()
@click.option("--hostfile",
              type=str,
              default="")
@click.option("--include",
              type=str,
              default="")
@click.option("--exclude",
              type=str,
              default="")
@click.option("--num_nodes",
              type=int,
              default=-1)
@click.option("--num_gpus",
              type=int,
              default=-1)
@click.option("--master_port",
              type=int,
              default=29500)
@click.option("--master_addr",
              type=str,
              default="127.0.0.1")
@click.option("--launcher",
              type=str,
              default="torch")
@click.option("--launcher_args",
              type=str,
              default="")
@click.argument("user_script",
                type=str)
@click.argument('user_args', nargs=-1)
def launch(hostfile, num_nodes, num_gpus, include, exclude, master_addr, master_port,
           launcher, launcher_args, user_script, user_args):
    args_dict = locals()
    args = Arguments(args_dict)
    args.user_args = list(args.user_args)
    col_launch(args)


cli.add_command(launch)

if __name__ == '__main__':
    cli()
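For context, a minimal sketch of how this click group could be exercised. The import path colossalai.cli.cli is inferred from the setup.py entry point further down in this diff; the script name and option values are placeholders, not part of this change.

# Sketch only: drive the new `launch` sub-command through click's test runner.
# Equivalent shell usage once the package is installed:
#   colossal launch --num_gpus 2 train.py config.py
from click.testing import CliRunner

from colossalai.cli.cli import cli  # assumed import path for the group above

runner = CliRunner()
result = runner.invoke(cli, [
    "launch",
    "--num_gpus", "2",
    "--master_port", "29500",
    "train.py",        # user_script (placeholder)
    "config.py",       # extra positional args collected into user_args
])
print(result.exit_code, result.output)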
colossalai/cli/launcher/multinode_runner.py
@@ -0,0 +1,146 @@
import os
import sys
import shutil
import subprocess
import warnings
from shlex import quote
from abc import ABC, abstractmethod

from colossalai.logging import get_dist_logger

logger = get_dist_logger()


class MultiNodeRunner(ABC):

    def __init__(self, args):
        self.args = args
        self.user_arguments = self.args.user_args
        self.user_script = args.user_script
        self.exports = {}

    @abstractmethod
    def backend_exists(self):
        """Return whether the corresponding backend exists"""

    @abstractmethod
    def get_cmd(self, environment, active_devices, args):
        """Return the command to execute on the node"""

    def add_export(self, key, var):
        self.exports[key.strip()] = var.strip()

    @property
    def name(self):
        """Return the name of the backend"""
        return self.__class__.__name__


class PDSHRunner(MultiNodeRunner):

    def __init__(self, args):
        super().__init__(args)

    def backend_exists(self):
        return shutil.which('pdsh')

    @property
    def name(self):
        return "pdsh"

    def parse_user_args(self):
        return list(
            map(lambda x: x if x.startswith("-") else f"'{x}'",
                self.args.user_args))

    def get_cmd(self, environment, active_devices, args):
        environment['PDSH_RCMD_TYPE'] = 'ssh'

        active_workers = ",".join(active_devices.keys())
        logger.info("Running on the following workers: %s" % active_workers)

        pdsh_cmd_args = ['pdsh', '-f', str(1024), '-w', active_workers]

        exports = ""
        for key, val in self.exports.items():
            exports += f"export {key}={quote(val)}; "

        # https://linux.die.net/man/1/pdsh
        # %n will be replaced by pdsh command
        colossal_launch = [
            exports,
            f"cd {os.path.abspath('.')};",
            sys.executable, "-u", "-m",
            "torch.distributed.launch",
            f"--nproc_per_node={args.num_gpus}",
            f"--master_addr={args.master_addr}",
            f"--master_port={args.master_port}"
        ]
        return pdsh_cmd_args + colossal_launch + [self.user_script] + self.user_arguments


class OpenMPIRunner(MultiNodeRunner):

    def __init__(self, args, device_pool):
        super().__init__(args)
        self.device_pool = device_pool

    def backend_exists(self):
        return shutil.which('ompi_info')

    @property
    def name(self):
        return "openmpi"

    def get_cmd(self, environment, active_devices, args):
        # `args` is accepted here as well so that every runner matches the call
        # in run.py: runner.get_cmd(env, active_devices, args)
        total_process_count = sum(self.device_pool.values())

        mpirun_cmd = [
            'mpirun',
            '-n',
            f'{total_process_count}',
            '-hostfile',
            f'{self.args.hostfile}'
        ]

        export_cmd = []
        for k, v in self.exports.items():
            export_cmd += ['-x', f'{k}={quote(v)}']

        python_exec = [sys.executable, "-u", "-m"]

        return mpirun_cmd + export_cmd + python_exec + [self.user_script] + self.user_arguments


class SLURMRunner(MultiNodeRunner):

    def __init__(self, args, device_pool=None):
        # device_pool is accepted for consistency with how run.py constructs
        # this runner: SLURMRunner(args, device_pool)
        super().__init__(args)
        self.device_pool = device_pool

    def backend_exists(self):
        return shutil.which('slurm_info')

    @property
    def name(self):
        return "slurm"

    def get_cmd(self, environment, active_devices, args):
        assert "-p" in args.launcher_args
        srun_args = args.launcher_args.strip().split()
        assert len(srun_args) >= 2, "we need more info about which partition to use."
        partition_name = srun_args[srun_args.index("-p") + 1]
        slurm_cmd = [
            'srun',
            "-p",
            f"{partition_name}",
            "--nodes",
            f"{args.num_nodes}",
            "--tasks",
            f"{args.num_gpus}"
        ]

        export_cmd = []
        for k, v in self.exports.items():
            export_cmd += ['-x', f'{k}={quote(v)}']

        python_exec = [sys.executable, "-u", "-m"]

        return slurm_cmd + export_cmd + python_exec + [self.user_script] + self.user_arguments
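As an illustration, here is a hedged sketch of the command list that PDSHRunner.get_cmd assembles. The import path is assumed from the package layout of this diff, and the host names, script, and option values are invented for the example.

# Sketch only: build a fake args namespace and inspect the resulting pdsh command.
import types

from colossalai.cli.launcher.multinode_runner import PDSHRunner  # assumed path

args = types.SimpleNamespace(user_script="train.py",      # placeholder script
                             user_args=["--lr", "0.1"],
                             num_gpus=8,
                             master_addr="10.0.0.1",
                             master_port=29500)
runner = PDSHRunner(args)
runner.add_export("NCCL_DEBUG", "INFO")   # emitted as `export NCCL_DEBUG=INFO;` on each worker

active_devices = {"worker-0": list(range(8)), "worker-1": list(range(8))}
cmd = runner.get_cmd(environment={}, active_devices=active_devices, args=args)
# Roughly: pdsh -f 1024 -w worker-0,worker-1 export NCCL_DEBUG=INFO; cd <cwd>;
#          <python> -u -m torch.distributed.launch --nproc_per_node=8 ... train.py --lr 0.1
print(" ".join(cmd))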
colossalai/cli/launcher/run.py
@@ -0,0 +1,285 @@
import argparse
from argparse import ArgumentParser, REMAINDER
import subprocess
import collections
import sys
import os
from copy import deepcopy  # needed by parse_device_filter when handling --exclude

import torch

from colossalai.logging import get_dist_logger
from .multinode_runner import PDSHRunner, OpenMPIRunner, SLURMRunner


def build_args_parser() -> ArgumentParser:
    """Helper function parsing the command line options."""

    parser = ArgumentParser(description="colossal distributed training launcher")

    parser.add_argument("-H",
                        "--hostfile",
                        type=str,
                        default="",
                        help="Hostfile path that defines the "
                        "device pool available to the job (e.g., "
                        "worker-name:number of slots)")

    parser.add_argument("-i",
                        "--include",
                        type=str,
                        default="",
                        help="Specify computing devices to use during execution. "
                        "String format is NODE_SPEC@NODE_SPEC "
                        "where NODE_SPEC=<worker-name>:<list-of-slots>")

    parser.add_argument("-e",
                        "--exclude",
                        type=str,
                        default="",
                        help="Specify computing devices to NOT use during execution. "
                        "Mutually exclusive with --include. Formatting "
                        "is the same as --include.")

    parser.add_argument("--num_nodes",
                        type=int,
                        default=-1,
                        help="Total number of worker nodes to use.")

    parser.add_argument("--num_gpus",
                        type=int,
                        default=-1,
                        help="Number of GPUs to use on each node.")

    parser.add_argument("--master_port",
                        default=29500,
                        type=int,
                        help="(optional) Port used by PyTorch distributed for "
                        "communication during distributed training.")

    parser.add_argument("--master_addr",
                        default="127.0.0.1",
                        type=str,
                        help="(optional) IP address of node 0, will be "
                        "inferred via 'hostname -I' if not specified.")

    parser.add_argument("--launcher",
                        default="torch",
                        type=str,
                        help="(optional) choose launcher backend for multi-node "
                        "training. Options currently include PDSH, OpenMPI, SLURM.")

    parser.add_argument("--launcher_args",
                        default="",
                        type=str,
                        help="(optional) pass launcher-specific arguments as a "
                        "single quoted argument.")

    parser.add_argument("user_script",
                        type=str,
                        help="User script to launch, followed by any required "
                        "arguments.")

    parser.add_argument('user_args', nargs=argparse.REMAINDER)

    return parser


def fetch_hostfile(hostfile_path):
    logger = get_dist_logger()
    if not os.path.isfile(hostfile_path):
        logger.warning("Unable to find hostfile, will proceed with training "
                       "with local resources only")
        return None

    # e.g., worker-0:16
    with open(hostfile_path, 'r') as fd:
        device_pool = collections.OrderedDict()
        for line in fd.readlines():
            line = line.strip()
            if line == '':
                # skip empty lines
                continue
            try:
                hostname, slot_count = line.split(":")
                slot_count = int(slot_count)
            except ValueError as err:
                logger.error("Hostfile is not formatted correctly, unable to "
                             "proceed with training.")
                raise err
            device_pool[hostname] = slot_count

    return device_pool


def _stable_remove_duplicates(data):
    # Create a new list in the same order as original but with duplicates
    # removed, should never be more than ~16 elements so simple is best
    new_list = []
    for x in data:
        if x not in new_list:
            new_list.append(x)
    return new_list


def parse_device_filter(host_info, include_str="", exclude_str=""):
    '''Parse an inclusion or exclusion string and filter a hostfile dictionary.

    Examples:
        include_str="worker-0@worker-1:0,2" will use all slots on worker-0 and
        slots [0, 2] on worker-1.
        exclude_str="worker-1:0" will use all available devices except
        slot 0 on worker-1.
    '''

    logger = get_dist_logger()

    # Constants that define our syntax
    NODE_SEP = '@'
    SLOT_LIST_START = ':'
    SLOT_SEP = ','

    # Ensure include/exclude are mutually exclusive
    if (include_str != "") and (exclude_str != ""):
        raise ValueError('include_str and exclude_str are mutually exclusive.')

    # no-op
    if (include_str == "") and (exclude_str == ""):
        return host_info

    # Either build from scratch or remove items
    filtered_hosts = dict()
    if include_str:
        parse_str = include_str
    if exclude_str != "":
        filtered_hosts = deepcopy(host_info)
        parse_str = exclude_str

    # foreach node in the list
    for node_config in parse_str.split(NODE_SEP):
        # Node can either be alone or node:slot,slot,slot
        if SLOT_LIST_START in node_config:
            hostname, slots = node_config.split(SLOT_LIST_START)
            slots = [int(x) for x in slots.split(SLOT_SEP)]

            # sanity checks
            if hostname not in host_info:
                raise ValueError(f"Hostname '{hostname}' not found in hostfile")
            for slot in slots:
                if slot not in host_info[hostname]:
                    raise ValueError(f"No slot '{slot}' specified on host '{hostname}'")

            # If include string, build the list from here
            if include_str:
                filtered_hosts[hostname] = slots
            elif exclude_str:
                for slot in slots:
                    logger.info(f'removing {slot} from {hostname}')
                    filtered_hosts[hostname].remove(slot)

        # User just specified the whole node
        else:
            hostname = node_config
            # sanity check hostname
            if hostname not in host_info:
                raise ValueError(f"Hostname '{hostname}' not found in hostfile")

            if include_str:
                filtered_hosts[hostname] = host_info[hostname]
            elif exclude_str:
                filtered_hosts[hostname] = []

    # Post-processing to remove duplicates and empty nodes
    del_keys = []
    for hostname in filtered_hosts:
        # Remove duplicates
        filtered_hosts[hostname] = _stable_remove_duplicates(filtered_hosts[hostname])
        # Remove empty hosts
        if len(filtered_hosts[hostname]) == 0:
            del_keys.append(hostname)
    for name in del_keys:
        del filtered_hosts[name]

    # Lastly, go over filtered_hosts and convert to an OrderedDict() to ensure
    # we map ranks to nodes correctly by maintaining host_info ordering.
    ordered_hosts = collections.OrderedDict()
    for host in host_info:
        if host in filtered_hosts:
            ordered_hosts[host] = filtered_hosts[host]

    return ordered_hosts


def parse_inclusion_exclusion(device_pool, inclusion, exclusion):
    active_devices = collections.OrderedDict()
    for hostname, slots in device_pool.items():
        active_devices[hostname] = list(range(slots))

    return parse_device_filter(active_devices,
                               include_str=inclusion,
                               exclude_str=exclusion)


def main(args=None):
    logger = get_dist_logger()
    assert args is not None, "args should not be None."

    device_pool = fetch_hostfile(args.hostfile)

    active_devices = None
    if device_pool:
        active_devices = parse_inclusion_exclusion(device_pool,
                                                   args.include,
                                                   args.exclude)
        if args.num_nodes > 0:
            updated_active_devices = collections.OrderedDict()
            for count, hostname in enumerate(active_devices.keys()):
                if args.num_nodes == count:
                    break
                updated_active_devices[hostname] = active_devices[hostname]
            active_devices = updated_active_devices

        if args.num_gpus > 0:
            updated_active_devices = collections.OrderedDict()
            for hostname in active_devices.keys():
                updated_active_devices[hostname] = list(range(args.num_gpus))
            active_devices = updated_active_devices

    env = os.environ.copy()

    if not active_devices:
        # single-node fallback: launch locally with torch.distributed.launch
        # (older PyTorch) or torchrun (newer PyTorch)
        if args.num_gpus == -1 or args.num_gpus > torch.cuda.device_count():
            nproc_per_node = torch.cuda.device_count()
        else:
            nproc_per_node = args.num_gpus
        if torch.__version__ <= "1.09":
            cmd = [sys.executable, "-u", "-m",
                   "torch.distributed.launch",
                   f"--nproc_per_node={nproc_per_node}",
                   f"--master_addr={args.master_addr}",
                   f"--master_port={args.master_port}"] + [args.user_script] + args.user_args
        else:
            cmd = ["torchrun",
                   f"--nproc_per_node={nproc_per_node}",
                   f"--master_addr={args.master_addr}",
                   f"--master_port={args.master_port}"] + [args.user_script] + args.user_args
    else:
        if args.launcher == "torch":
            runner = PDSHRunner(args)
        elif args.launcher == "mpi":
            runner = OpenMPIRunner(args, device_pool)
        elif args.launcher == "slurm":
            runner = SLURMRunner(args, device_pool)
        else:
            raise NotImplementedError(f"Unknown launcher {args.launcher}")

        if not runner.backend_exists():
            raise RuntimeError(f"launcher '{args.launcher}' not installed.")

        curr_path = os.path.abspath('.')
        if 'PYTHONPATH' in env:
            env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH']
        else:
            env['PYTHONPATH'] = curr_path

        cmd = runner.get_cmd(env, active_devices, args)

    result = subprocess.Popen(cmd, env=env)
    result.wait()
    if result.returncode > 0:
        sys.exit(result.returncode)


if __name__ == "__main__":
    main()
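To make the hostfile and --include/--exclude semantics above concrete, here is a small sketch. The import path is assumed from this diff, and the worker names are invented for the example.

# Sketch only: a hostfile mapping "hostname:slot_count" is expanded to per-host
# slot lists, which --include / --exclude then filter.
import collections

from colossalai.cli.launcher.run import parse_inclusion_exclusion  # assumed path

# Equivalent of a hostfile containing:
#   worker-0:4
#   worker-1:4
device_pool = collections.OrderedDict([("worker-0", 4), ("worker-1", 4)])

# Keep every slot on worker-0 but only slots 0 and 2 on worker-1.
included = parse_inclusion_exclusion(device_pool, "worker-0@worker-1:0,2", "")
assert included == {"worker-0": [0, 1, 2, 3], "worker-1": [0, 2]}

# Drop slot 0 on worker-1, keep everything else.
excluded = parse_inclusion_exclusion(device_pool, "", "worker-1:0")
assert excluded == {"worker-0": [0, 1, 2, 3], "worker-1": [1, 2, 3]}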
setup.py
@@ -213,6 +213,10 @@ setup(
    ext_modules=ext_modules,
    cmdclass={'build_ext': BuildExtension} if ext_modules else {},
    install_requires=fetch_requirements('requirements/requirements.txt'),
    entry_points='''
        [console_scripts]
        colossal=colossalai.cli.cli:cli
    ''',
    python_requires='>=3.7',
    classifiers=[
        'Programming Language :: Python :: 3',
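For reference, a rough sketch of what the console_scripts entry point added above means at install time: once the package is installed, the `colossal` command resolves to the click group colossalai.cli.cli:cli. The snippet below uses importlib.metadata, which is in the standard library on Python 3.8+; on 3.7 the importlib_metadata backport would be needed.

# Sketch only: resolve the installed `colossal` console script programmatically.
from importlib.metadata import entry_points

eps = entry_points()
# entry_points() returns a selectable object on Python 3.10+ and a dict-like
# mapping on 3.8/3.9; handle both.
console = eps.select(group="console_scripts") if hasattr(eps, "select") else eps["console_scripts"]
colossal = next(ep for ep in console if ep.name == "colossal")
print(colossal.value)         # -> "colossalai.cli.cli:cli"
launch_cli = colossal.load()  # the click group from the first file in this diff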