mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #807 from FrankLeeeee/feature/cli
[cli] fixed a bug in user args and refactored the module structure (pull/810/head)
commit f6dcd23fb9
@@ -1,77 +1,34 @@
 import click
-from colossalai.cli.launcher.run import main as col_launch
+from .launcher import run
 from colossalai.cli.benchmark.utils import BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM, ITER_TIMES
 from colossalai.cli.benchmark.run import launch as col_benchmark


 class Arguments():

     def __init__(self, arg_dict):
         for k, v in arg_dict.items():
             self.__dict__[k] = v


 @click.group()
 def cli():
     pass


 @click.command()
-@click.option("--num_gpus",
-              type=int,
-              default=-1)
-@click.option("--bs",
-              type=int,
-              default=BATCH_SIZE)
-@click.option("--seq_len",
-              type=int,
-              default=SEQ_LENGTH)
-@click.option("--hid_dim",
-              type=int,
-              default=HIDDEN_DIM)
-@click.option("--num_steps",
-              type=int,
-              default=ITER_TIMES)
+@click.option("--num_gpus", type=int, default=-1)
+@click.option("--bs", type=int, default=BATCH_SIZE)
+@click.option("--seq_len", type=int, default=SEQ_LENGTH)
+@click.option("--hid_dim", type=int, default=HIDDEN_DIM)
+@click.option("--num_steps", type=int, default=ITER_TIMES)
 def benchmark(num_gpus, bs, seq_len, hid_dim, num_steps):
     args_dict = locals()
     args = Arguments(args_dict)
     col_benchmark(args)


-@click.command()
-@click.option("--hostfile",
-              type=str,
-              default="")
-@click.option("--include",
-              type=str,
-              default="")
-@click.option("--exclude",
-              type=str,
-              default="")
-@click.option("--num_nodes",
-              type=int,
-              default=-1)
-@click.option("--num_gpus",
-              type=int,
-              default=-1)
-@click.option("--master_port",
-              type=int,
-              default=29500)
-@click.option("--master_addr",
-              type=str,
-              default="127.0.0.1")
-@click.option("--launcher",
-              type=str,
-              default="torch")
-@click.option("--launcher_args",
-              type=str,
-              default="")
-@click.argument("user_script",
-                type=str)
-@click.argument('user_args', nargs=-1)
-def launch(hostfile, num_nodes, num_gpus, include, exclude, master_addr, master_port,
-           launcher, launcher_args, user_script, user_args):
-    args_dict = locals()
-    args = Arguments(args_dict)
-    args.user_args = list(args.user_args)
-    col_launch(args)
-
-
-cli.add_command(launch)
+cli.add_command(run)
 cli.add_command(benchmark)

 if __name__ == '__main__':
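A note on the Arguments + locals() pattern kept by the benchmark command above: click passes each option to the function as a keyword argument, locals() snapshots them into a dict, and Arguments re-exposes that dict as attributes so downstream code can read args.bs the way it would an argparse namespace. Below is a minimal self-contained sketch of the same flow; the option names mirror the source, but the default values are placeholders rather than ColossalAI's real BATCH_SIZE/SEQ_LENGTH constants.

import click


class Arguments():
    """Re-expose a dict of parsed options as attributes (mirrors the class in the diff above)."""

    def __init__(self, arg_dict):
        for k, v in arg_dict.items():
            self.__dict__[k] = v


@click.command()
@click.option("--bs", type=int, default=32)        # placeholder default, not BATCH_SIZE
@click.option("--seq_len", type=int, default=128)  # placeholder default, not SEQ_LENGTH
def benchmark(bs, seq_len):
    # At this point locals() is {'bs': ..., 'seq_len': ...}, so wrapping it
    # lets downstream code read args.bs / args.seq_len like a namespace object.
    args = Arguments(locals())
    click.echo(f"bs={args.bs}, seq_len={args.seq_len}")


if __name__ == '__main__':
    benchmark()

Running this sketch as, say, python sketch.py --bs 8 would print bs=8, seq_len=128.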
@@ -0,0 +1,70 @@
+import click
+from .run import launch_multi_processes
+from colossalai.context import Config
+
+
+@click.command(help="Launch distributed training on a single node or multiple nodes",
+               context_settings=dict(ignore_unknown_options=True))
+@click.option("-H", "-host", "--host", type=str, default=None, help="the list of machines to launch")
+@click.option("--hostfile",
+              type=str,
+              default=None,
+              help="Hostfile path that defines the device pool available to the job (e.g. worker-name:number of slots)")
+@click.option(
+    "--include",
+    type=str,
+    default=None,
+    help=
+    "Specify computing devices to use during execution. String format is NODE_SPEC@NODE_SPEC where NODE_SPEC=<worker-name>:<list-of-slots>"
+)
+@click.option(
+    "--exclude",
+    type=str,
+    default=None,
+    help=
+    "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include."
+)
+@click.option("--num_nodes", type=int, default=-1, help="Total number of worker nodes to use.")
+@click.option("--nprocs_per_node", type=int, default=-1, help="Number of GPUs to use on each node.")
+@click.option("--master_port",
+              type=int,
+              default=29500,
+              help="(optional) Port used by PyTorch distributed for communication during distributed training.")
+@click.option("--master_addr",
+              type=str,
+              default="127.0.0.1",
+              help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.")
+@click.option(
+    "--launcher",
+    type=click.Choice(['torch', 'openmpi', 'slurm'], case_sensitive=False),
+    default="torch",
+    help="(optional) choose launcher backend for multi-node training. Options currently include PDSH, OpenMPI, SLURM.")
+@click.option("--launcher_args",
+              type=str,
+              default=None,
+              help="(optional) pass launcher specific arguments as a single quoted argument.")
+@click.argument("user_script", type=str)
+@click.argument('user_args', nargs=-1)
+def run(host: str, hostfile: str, num_nodes: int, nprocs_per_node: int, include: str, exclude: str, master_addr: str,
+        master_port: int, launcher: str, launcher_args: str, user_script: str, user_args: str):
+    """
+    To launch multiple processes on a single node or multiple nodes via command line.
+
+    Usage::
+        # run on the current node with all available GPUs
+        colossalai run train.py
+
+        # run with only 2 GPUs on the current node
+        colossalai run --nprocs_per_node 2 train.py
+
+        # run on two nodes
+        colossalai run --host <host1>,<host2> train.py
+
+        # run with hostfile
+        colossalai run --hostfile <file_path> train.py
+    """
+    args_dict = locals()
+    args = Config(args_dict)
+    args.user_args = list(args.user_args)
+    # (lsg) TODO: fix this function
+    # launch_multi_processes(args)
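Two details of the new run command make argument forwarding work: context_settings=dict(ignore_unknown_options=True) stops click from rejecting flags it does not recognise, and @click.argument('user_args', nargs=-1) collects everything after the script path so it can later be appended to the launched command. Here is a minimal sketch of that mechanism in isolation; the option subset, the script name, and the --lr flag are made up for illustration.

import click


@click.command(context_settings=dict(ignore_unknown_options=True))
@click.option("--nprocs_per_node", type=int, default=1)
@click.argument("user_script", type=str)
@click.argument("user_args", nargs=-1)
def run(nprocs_per_node, user_script, user_args):
    # Flags that click does not know (e.g. --lr 0.1) fall through into
    # user_args untouched, ready to be appended to the launch command.
    click.echo(f"launching {user_script} on {nprocs_per_node} process(es) with extra args {list(user_args)}")


if __name__ == "__main__":
    run()

Invoked as python sketch.py --nprocs_per_node 2 train.py --lr 0.1, the unknown pair --lr 0.1 lands in user_args rather than triggering a usage error.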
@@ -6,79 +6,10 @@ import sys
 import os
 import torch
 from colossalai.logging import get_dist_logger
+from colossalai.context import Config
 from .multinode_runner import PDSHRunner, OpenMPIRunner, SLURMRunner
+from copy import deepcopy


-def build_args_parser() -> ArgumentParser:
-    """Helper function parsing the command line options."""
-
-    parser = ArgumentParser(description="colossal distributed training launcher")
-
-    parser.add_argument("-H",
-                        "--hostfile",
-                        type=str,
-                        default="",
-                        help="Hostfile path that defines the "
-                        "device pool available to the job (e.g., "
-                        "worker-name:number of slots)")
-
-    parser.add_argument("-i",
-                        "--include",
-                        type=str,
-                        default="",
-                        help="Specify computing devices to use during execution."
-                        "String format is NODE_SPEC@NODE_SPEC"
-                        "where NODE_SPEC=<worker-name>:<list-of-slots>")
-
-    parser.add_argument("-e",
-                        "--exclude",
-                        type=str,
-                        default="",
-                        help="Specify computing devices to NOT use during execution."
-                        "Mutually exclusive with --include. Formatting"
-                        "is the same as --include.")
-
-    parser.add_argument("--num_nodes",
-                        type=int,
-                        default=-1,
-                        help="Total number of worker nodes to use.")
-
-    parser.add_argument("--num_gpus",
-                        type=int,
-                        default=-1,
-                        help="Number of GPUs to use on each node.")
-
-    parser.add_argument("--master_port",
-                        default=29500,
-                        type=int,
-                        help="(optional) Port used by PyTorch distributed for "
-                        "communication during distributed training.")
-
-    parser.add_argument("--master_addr",
-                        default="127.0.0.1",
-                        type=str,
-                        help="(optional) IP address of node 0, will be "
-                        "inferred via 'hostname -I' if not specified.")
-
-    parser.add_argument("--launcher",
-                        default="torch",
-                        type=str,
-                        help="(optional) choose launcher backend for multi-node "
-                        "training. Options currently include PDSH, OpenMPI, SLURM.")
-
-    parser.add_argument("--launcher_args",
-                        default="",
-                        type=str,
-                        help="(optional) pass launcher specific arguments as a "
-                        "single quoted argument.")
-
-    parser.add_argument("user_script",
-                        type=str,
-                        help="User script to launch, followed by any required "
-                        "arguments.")
-
-    parser.add_argument('user_args', nargs=argparse.REMAINDER)
-
-    return parser
-
-
 def fetch_hostfile(hostfile_path):
     logger = get_dist_logger()
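For readers unfamiliar with the hostfile idea that fetch_hostfile builds on: the --hostfile help text describes entries as "worker-name: number of slots". The actual parsing body of fetch_hostfile is not part of this hunk, so the following is only a hypothetical sketch of the input and output shapes, assuming a simple "hostname: slots" line syntax; it is not ColossalAI's real parser.

import collections

# Hypothetical hostfile contents (syntax assumed from the help text above):
#
#   worker-0: 4
#   worker-1: 4

def sketch_fetch_hostfile(lines):
    """Turn 'hostname: slots' lines into an ordered hostname -> slot-count map."""
    device_pool = collections.OrderedDict()
    for line in lines:
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        hostname, slots = line.split(":")
        device_pool[hostname.strip()] = int(slots)
    return device_pool


print(sketch_fetch_hostfile(["worker-0: 4", "worker-1: 4"]))
# OrderedDict([('worker-0', 4), ('worker-1', 4)])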
@@ -106,6 +37,7 @@ def fetch_hostfile(hostfile_path):

     return device_pool

+
 def _stable_remove_duplicates(data):
     # Create a new list in the same order as original but with duplicates
     # removed, should never be more than ~16 elements so simple is best

@@ -115,6 +47,7 @@ def _stable_remove_duplicates(data):
             new_list.append(x)
     return new_list

+

 def parse_device_filter(host_info, include_str="", exclude_str=""):
     '''Parse an inclusion or exclusion string and filter a hostfile dictionary.
@@ -202,26 +135,27 @@ def parse_device_filter(host_info, include_str="", exclude_str=""):
     return ordered_hosts


 def parse_inclusion_exclusion(device_pool, inclusion, exclusion):
     active_devices = collections.OrderedDict()
     for hostname, slots in device_pool.items():
         active_devices[hostname] = list(range(slots))

-    return parse_device_filter(active_devices,
-                               include_str=inclusion,
-                               exclude_str=exclusion)
+    return parse_device_filter(active_devices, include_str=inclusion, exclude_str=exclusion)
+

-def main(args=None):
-    logger = get_dist_logger()
-    assert args is not None, "args should not be None."
-
-    device_pool = fetch_hostfile(args.hostfile)
+def launch_multi_processes(args):
+    assert isinstance(args, Config), f'expected args to be of type Config, but got {type(args)}'
+
+    # check
+    if args.hostfile:
+        device_pool = fetch_hostfile(args.hostfile)
+    else:
+        device_pool = None

     active_devices = None
     if device_pool:
-        active_devices = parse_inclusion_exclusion(device_pool,
-                                                   args.include,
-                                                   args.exclude)
+        active_devices = parse_inclusion_exclusion(device_pool, args.include, args.exclude)
+
     if args.num_nodes > 0:
         updated_active_devices = collections.OrderedDict()
         for count, hostname in enumerate(active_devices.keys()):
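To make the include/exclude plumbing above concrete: parse_inclusion_exclusion first expands the hostname -> slot-count pool into explicit slot lists, then hands them to parse_device_filter together with the NODE_SPEC-style strings described in the --include help text. The sketch below only shows the data shapes involved; the host names, slot counts, and the example filter string are assumptions for illustration, and parse_device_filter itself is not reproduced here.

import collections

# Device pool as produced from a hostfile: hostname -> number of slots.
device_pool = collections.OrderedDict([("worker-0", 4), ("worker-1", 4)])

# Expansion step mirrored from parse_inclusion_exclusion above:
# each host gets an explicit list of slot indices.
active_devices = collections.OrderedDict(
    (hostname, list(range(slots))) for hostname, slots in device_pool.items())
# OrderedDict([('worker-0', [0, 1, 2, 3]), ('worker-1', [0, 1, 2, 3])])

# Per the --include help text, a filter string uses NODE_SPEC@NODE_SPEC with
# NODE_SPEC=<worker-name>:<list-of-slots>; a hypothetical example:
include_str = "worker-0:0,1@worker-1:2,3"
# parse_device_filter (not shown in full in this diff) would narrow
# active_devices down to just those hosts and slots.
print(active_devices, include_str)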
@@ -244,16 +178,15 @@
         else:
             nproc_per_node = args.num_gpus
         if torch.__version__ <= "1.09":
-            cmd = [sys.executable, "-u", "-m",
-                   "torch.distributed.launch",
-                   f"--nproc_per_node={nproc_per_node}",
-                   f"--master_addr={args.master_addr}",
-                   f"--master_port={args.master_port}"] + [args.user_script] + args.user_args
+            cmd = [
+                sys.executable, "-u", "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
+                f"--master_addr={args.master_addr}", f"--master_port={args.master_port}"
+            ] + [args.user_script] + args.user_args
         else:
-            cmd = ["torchrun",
-                   f"--nproc_per_node={nproc_per_node}",
-                   f"--master_addr={args.master_addr}",
-                   f"--master_port={args.master_port}"] + [args.user_script] + args.user_args
+            cmd = [
+                "torchrun", f"--nproc_per_node={nproc_per_node}", f"--master_addr={args.master_addr}",
+                f"--master_port={args.master_port}"
+            ] + [args.user_script] + args.user_args
     else:
         if args.launcher == "torch":
             runner = PDSHRunner(args)
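The single-node branch above boils down to building one command list and handing it to subprocess.Popen: older PyTorch goes through python -m torch.distributed.launch, newer PyTorch through torchrun. A standalone sketch of that selection logic, kept faithful to the diff (including its plain string comparison of torch.__version__, which is how the code above is written):

import sys
import torch


def build_single_node_cmd(nproc_per_node, master_addr, master_port, user_script, user_args):
    """Sketch of the branch above: pick a launcher module based on the torch version string."""
    if torch.__version__ <= "1.09":
        cmd = [
            sys.executable, "-u", "-m", "torch.distributed.launch",
            f"--nproc_per_node={nproc_per_node}",
            f"--master_addr={master_addr}", f"--master_port={master_port}"
        ]
    else:
        cmd = [
            "torchrun", f"--nproc_per_node={nproc_per_node}",
            f"--master_addr={master_addr}", f"--master_port={master_port}"
        ]
    # The user script and its forwarded arguments are appended unchanged.
    return cmd + [user_script] + list(user_args)


print(build_single_node_cmd(2, "127.0.0.1", 29500, "train.py", ["--lr", "0.1"]))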
@@ -272,14 +205,10 @@
             env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH']
         else:
             env['PYTHONPATH'] = curr_path

         cmd = runner.get_cmd(env, active_devices, args)

     result = subprocess.Popen(cmd, env=env)
     result.wait()
     if result.returncode > 0:
         sys.exit(result.returncode)
-
-
-if __name__ == "__main__":
-    main()
setup.py

@@ -215,7 +215,7 @@ setup(
     install_requires=fetch_requirements('requirements/requirements.txt'),
     entry_points='''
     [console_scripts]
-    colossal=colossalai.cli:cli
+    colossalai=colossalai.cli:cli
     ''',
     python_requires='>=3.7',
     classifiers=[
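The entry-point change above is what renames the installed shell command from colossal to colossalai: setuptools generates a console script that imports the click group and calls it. Roughly, the generated wrapper behaves like the sketch below, which is a simplification and not the literal script setuptools writes.

# Approximately what the installed `colossalai` console script does:
import sys

from colossalai.cli import cli

if __name__ == '__main__':
    # The click group's return value becomes the process exit code.
    sys.exit(cli())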