[cli] fixed single-node process launching

pull/812/head
FrankLeeeee 2022-04-20 10:38:21 +08:00
parent 61c20b44bc
commit d522cb704e
3 changed files with 80 additions and 129 deletions

View File

@@ -25,7 +25,7 @@ from colossalai.context import Config
"Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include."
)
@click.option("--num_nodes", type=int, default=-1, help="Total number of worker nodes to use.")
@click.option("--nprocs_per_node", type=int, default=-1, help="Number of GPUs to use on each node.")
@click.option("--nproc_per_node", type=int, default=-1, help="Number of GPUs to use on each node.")
@click.option("--master_port",
type=int,
default=29500,
@@ -45,7 +45,7 @@ from colossalai.context import Config
help="(optional) pass launcher specific arguments as a single quoted argument.")
@click.argument("user_script", type=str)
@click.argument('user_args', nargs=-1)
def run(host: str, hostfile: str, num_nodes: int, nprocs_per_node: int, include: str, exclude: str, master_addr: str,
def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str,
master_port: int, launcher: str, launcher_args: str, user_script: str, user_args: str):
"""
Launch multiple processes on a single node or across multiple nodes via the command line.
@@ -66,5 +66,4 @@ def run(host: str, hostfile: str, num_nodes: int, nprocs_per_node: int, include:
args_dict = locals()
args = Config(args_dict)
args.user_args = list(args.user_args)
# (lsg) TODO: fix this function
# launch_multi_processes(args)
launch_multi_processes(args)
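
With the TODO removed, the command now calls launch_multi_processes directly. A quick way to confirm that the renamed flag is exposed is Click's test runner (a minimal sketch; the import path colossalai.cli.launcher.run is an assumption):

    from click.testing import CliRunner
    # hypothetical import path for the command defined above
    from colossalai.cli.launcher.run import run

    # --help exercises option parsing without launching any processes
    result = CliRunner().invoke(run, ['--help'])
    assert '--nproc_per_node' in result.output
    assert '--nprocs_per_node' not in result.output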

View File

@@ -1,16 +1,14 @@
import os
import sys
import shutil
import subprocess
import warnings
from shlex import quote
from abc import ABC, abstractmethod
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
class MultiNodeRunner(ABC):
def __init__(self, args):
self.args = args
self.user_arguments = self.args.user_args
@@ -35,6 +33,7 @@ class MultiNodeRunner(ABC):
class PDSHRunner(MultiNodeRunner):
def __init__(self, args):
super().__init__(args)
@@ -46,15 +45,13 @@ class PDSHRunner(MultiNodeRunner):
return "pdsh"
def parse_user_args(self):
return list(
map(lambda x: x if x.startswith("-") else f"'{x}'",
self.args.user_args))
return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args))
def get_cmd(self, environment, active_devices, args):
environment['PDSH_RCMD_TYPE'] = 'ssh'
active_workers = ",".join(active_devices.keys())
logger.info("Running on the following workers: %s" % active_workers)
print("Running on the following workers: %s" % active_workers)
pdsh_cmd_args = ['pdsh', '-f', str(1024), '-w', active_workers]
@@ -65,82 +62,8 @@ class PDSHRunner(MultiNodeRunner):
# https://linux.die.net/man/1/pdsh
# %n will be replaced by pdsh command
colossal_launch = [
exports,
f"cd {os.path.abspath('.')};",
sys.executable, "-u", "-m",
"torch.distributed.launch",
f"--nproc_per_node={args.num_gpus}",
f"--master_addr={args.master_addr}",
exports, f"cd {os.path.abspath('.')};", sys.executable, "-u", "-m", "torch.distributed.launch",
f"--nproc_per_node={args.nproc_per_node}", f"--master_addr={args.master_addr}",
f"--master_port={args.master_port}"
]
return pdsh_cmd_args + colossal_launch + [self.user_script] + self.user_arguments
class OpenMPIRunner(MultiNodeRunner):
def __init__(self, args, device_pool):
super().__init__(args)
self.device_pool = device_pool
def backend_exists(self):
return shutil.which('ompi_info')
@property
def name(self):
return "openmpi"
def get_cmd(self, environment, active_devices):
total_process_count = sum(self.device_pool.values())
mpirun_cmd = [
'mpirun',
'-n',
f'{total_process_count}',
'-hostfile',
f'{self.args.hostfile}'
]
export_cmd = []
for k, v in self.exports.items():
export_cmd += ['-x', f'{k}={quote(v)}']
python_exec = []
python_exec = [sys.executable, "-u", "-m"]
return mpirun_cmd + export_cmd + python_exec + [self.user_script
] + self.user_arguments
class SLURMRunner(MultiNodeRunner):
def __init__(self, args):
super().__init__(args)
def backend_exists(self):
return shutil.which('slurm_info')
@property
def name(self):
return "slurm"
def get_cmd(self, environment, active_devices, args):
assert "-p" in args.launcher_args
srun_args = args.launcher_args.strip().split()
assert len(srun_args) >= 2, "we need more info about which partition to use."
partition_name = srun_args[srun_args.index("-p") + 1]
slurm_cmd = [
'srun',
"-p",
f"{partition_name}",
"--nodes",
f"{args.num_nodes}",
"--tasks",
f"{args.num_gpus}"
]
export_cmd = []
for k, v in self.exports.items():
export_cmd += ['-x', f'{k}={quote(v)}']
python_exec = []
python_exec = [sys.executable, "-u", "-m"]
return slurm_cmd + export_cmd + python_exec + [self.user_script
] + self.user_arguments
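
For concreteness, with two workers and the defaults above, PDSHRunner.get_cmd assembles an argument list of roughly the following shape (hostnames, paths, and the export string are illustrative stand-ins):

    # sketch of the assembled pdsh command; pdsh runs the trailing launch
    # command on every worker listed after -w
    ['pdsh', '-f', '1024', '-w', 'worker-0,worker-1',
     'export MASTER_ADDR=worker-0;',  # stand-in for the exports string built above
     'cd /home/user/project;',
     '/usr/bin/python', '-u', '-m', 'torch.distributed.launch',
     '--nproc_per_node=8', '--master_addr=worker-0', '--master_port=29500',
     'train.py']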

View File

@@ -1,22 +1,18 @@
import argparse
from argparse import ArgumentParser, REMAINDER
import click
import subprocess
import collections
import sys
import os
import torch
from colossalai.logging import get_dist_logger
from colossalai.context import Config
from .multinode_runner import PDSHRunner, OpenMPIRunner, SLURMRunner
from .multinode_runner import PDSHRunner
from copy import deepcopy
def fetch_hostfile(hostfile_path):
logger = get_dist_logger()
if not os.path.isfile(hostfile_path):
logger.warning("Unable to find hostfile, will proceed with training "
"with local resources only")
return None
click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
exit()
# e.g., worker-0:16
with open(hostfile_path, 'r') as fd:
@@ -30,11 +26,13 @@ def fetch_hostfile(hostfile_path):
hostname, slot_count = line.split(":")
slot_count = int(slot_count)
except ValueError as err:
logger.error("Hostfile is not formatted correctly, unable to "
"proceed with training.")
raise err
device_pool[hostname] = slot_count
click.echo(f"Error: Hostfile is not formatted correctly, expected <hostname>:<slot>, but found {line}")
exit()
if hostname in device_pool:
click.echo(f"Error: found duplicate host {hostname} in the hostfile")
exit()
device_pool[hostname] = slot_count
return device_pool
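
As a quick illustration of the parsing above (hostnames and the path are made up), a two-line hostfile yields a mapping from hostname to slot count:

    # hostfile contents, one <hostname>:<slot_count> entry per line:
    #   worker-0:16
    #   worker-1:16
    device_pool = fetch_hostfile('./hostfile')
    # -> {'worker-0': 16, 'worker-1': 16}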
@@ -48,7 +46,7 @@ def _stable_remove_duplicates(data):
return new_list
def parse_device_filter(host_info, include_str="", exclude_str=""):
def parse_device_filter(host_info, include_str=None, exclude_str=None):
'''Parse an inclusion or exclusion string and filter a hostfile dictionary.
Examples:
@@ -58,26 +56,25 @@ def parse_device_filter(host_info, include_str="", exclude_str=""):
slot 0 on worker-1.
'''
logger = get_dist_logger()
# Constants that define our syntax
NODE_SEP = '@'
SLOT_LIST_START = ':'
SLOT_SEP = ','
# Ensure include/exclude are mutually exclusive
if (include_str != "") and (exclude_str != ""):
raise ValueError('include_str and exclude_str are mutually exclusive.')
if include_str and exclude_str:
click.echo("--include and --exclude are mutually exclusive, only one can be used")
exit()
# no-op
if (include_str == "") and (exclude_str == ""):
if include_str is None and exclude_str is None:
return host_info
# Either build from scratch or remove items
filtered_hosts = dict()
if include_str:
parse_str = include_str
if exclude_str != "":
elif exclude_str:
filtered_hosts = deepcopy(host_info)
parse_str = exclude_str
@@ -90,17 +87,18 @@ def parse_device_filter(host_info, include_str="", exclude_str=""):
# sanity checks
if hostname not in host_info:
raise ValueError(f"Hostname '{hostname}' not found in hostfile")
click.echo(f"Hostname '{hostname}' not found in hostfile")
exit()
for slot in slots:
if slot not in host_info[hostname]:
raise ValueError(f"No slot '{slot}' specified on host '{hostname}'")
click.echo(f"No slot '{slot}' specified on host '{hostname}'")
# If include string, build the list from here
if include_str:
filtered_hosts[hostname] = slots
elif exclude_str:
for slot in slots:
logger.info(f'removing {slot} from {hostname}')
click.echo(f'- removing {slot} from {hostname}')
filtered_hosts[hostname].remove(slot)
# User just specified the whole node
@@ -108,7 +106,8 @@ def parse_device_filter(host_info, include_str="", exclude_str=""):
hostname = node_config
# sanity check hostname
if hostname not in host_info:
raise ValueError(f"Hostname '{hostname}' not found in hostfile")
click.echo(f"Hostname '{hostname}' not found in hostfile")
exit()
if include_str:
filtered_hosts[hostname] = host_info[hostname]
@@ -123,6 +122,8 @@ def parse_device_filter(host_info, include_str="", exclude_str=""):
# Remove empty hosts
if len(filtered_hosts[hostname]) == 0:
del_keys.append(hostname)
# remove unneeded hosts
for name in del_keys:
del filtered_hosts[name]
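
Putting the include and exclude branches together, the filter behaves roughly as follows (hostnames and slot lists are illustrative):

    # host_info maps hostname -> list of slot ids
    host_info = {'worker-0': [0, 1, 2, 3], 'worker-1': [0, 1, 2, 3]}

    # include keeps only the named hosts/slots
    parse_device_filter(host_info, include_str='worker-0')
    # -> {'worker-0': [0, 1, 2, 3]}

    # exclude removes the named slots, then drops hosts left empty
    parse_device_filter(host_info, exclude_str='worker-1:0,1,2,3')
    # -> {'worker-0': [0, 1, 2, 3]}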
@@ -145,18 +146,40 @@ def parse_inclusion_exclusion(device_pool, inclusion, exclusion):
def launch_multi_processes(args):
assert isinstance(args, Config), f'expected args to be of type Config, but got {type(args)}'
"""
Launch multiple processes on a single node or multiple nodes.
The overall logic can be summarized as the pseudo code below:
if hostfile given:
hostinfo = parse_hostfile(hostfile)
hostinfo = include_or_exclude_hosts(hostinfo)
launch_on_multi_nodes(hostinfo)
elif hosts given:
hostinfo = parse_hosts(hosts)
launch_on_multi_nodes(hostinfo)
else:
launch_on_current_node()
"""
assert isinstance(args, Config)
# cannot accept hosts and hostfile at the same time
if args.host and args.hostfile:
click.echo("Error: hostfile and hosts are mutually exclusive, provide only one of them")
exit()
# check if hostfile is given
if args.hostfile:
device_pool = fetch_hostfile(args.hostfile)
else:
device_pool = None
# filter and only keep the ones needed
active_devices = None
if device_pool:
active_devices = parse_inclusion_exclusion(device_pool, args.include, args.exclude)
if args.num_nodes > 0:
# only keep the first num_nodes nodes to execute jobs
updated_active_devices = collections.OrderedDict()
for count, hostname in enumerate(active_devices.keys()):
if args.num_nodes == count:
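
The single line elided between this hunk and the next presumably breaks out of the loop once num_nodes hosts have been copied; the effect is a stable trim that keeps the first num_nodes entries (a minimal, self-contained sketch with made-up hostnames):

    import collections

    active_devices = collections.OrderedDict(
        [('worker-0', 16), ('worker-1', 16), ('worker-2', 16)])
    num_nodes = 2

    updated = collections.OrderedDict()
    for count, hostname in enumerate(active_devices.keys()):
        if count == num_nodes:
            break
        updated[hostname] = active_devices[hostname]

    assert list(updated) == ['worker-0', 'worker-1']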
@@ -164,20 +187,36 @@ def launch_multi_processes(args):
updated_active_devices[hostname] = active_devices[hostname]
active_devices = updated_active_devices
if args.num_gpus > 0:
if args.nproc_per_node > 0:
# only keep the first nproc_per_node GPUs on each node
updated_active_devices = collections.OrderedDict()
for hostname in active_devices.keys():
updated_active_devices[hostname] = list(range(args.num_gpus))
for hostname, active_devices in active_devices.items():
if len(active_devices) < args.nproc_per_node:
click.echo(
f"Error: The number of available GPUs on {hostname} is smaller than the argument nproc_per_node"
)
exit()
updated_active_devices[hostname] = active_devices[:args.nproc_per_node]
active_devices = updated_active_devices
env = os.environ.copy()
# use hosts if hostfile is not given
if args.host and active_devices is None:
hostinfo = collections.OrderedDict()
host_list = args.host.strip().split(',')
for hostname in host_list:
hostinfo[hostname] = args.nproc_per_node
active_devices = hostinfo
# run on the local node if neither hosts nor hostfile is given
if not active_devices:
if args.num_gpus == -1 or args.num_gpus > torch.cuda.device_count():
if args.nproc_per_node == -1 or args.nproc_per_node > torch.cuda.device_count():
nproc_per_node = torch.cuda.device_count()
else:
nproc_per_node = args.num_gpus
if torch.__version__ <= "1.09":
nproc_per_node = args.nproc_per_node
if torch.__version__ <= "1.9":
cmd = [
sys.executable, "-u", "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
f"--master_addr={args.master_addr}", f"--master_port={args.master_port}"
@@ -188,17 +227,7 @@ def launch_multi_processes(args):
f"--master_port={args.master_port}"
] + [args.user_script] + args.user_args
else:
if args.launcher == "torch":
runner = PDSHRunner(args)
elif args.launcher == "mpi":
runner = OpenMPIRunner(args, device_pool)
elif args.launcher == "slurm":
runner = SLURMRunner(args, device_pool)
else:
raise NotImplementedError(f"Unknown launcher {args.launcher}")
if not runner.backend_exists():
raise RuntimeError(f"launcher '{args.launcher}' not installed.")
runner = PDSHRunner(args)
curr_path = os.path.abspath('.')
if 'PYTHONPATH' in env: