[cli] fixed single-node process launching

pull/812/head
FrankLeeeee 2022-04-20 10:38:21 +08:00
parent 61c20b44bc
commit d522cb704e
3 changed files with 80 additions and 129 deletions

View File

@@ -25,7 +25,7 @@ from colossalai.context import Config
     "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include."
 )
 @click.option("--num_nodes", type=int, default=-1, help="Total number of worker nodes to use.")
-@click.option("--nprocs_per_node", type=int, default=-1, help="Number of GPUs to use on each node.")
+@click.option("--nproc_per_node", type=int, default=-1, help="Number of GPUs to use on each node.")
 @click.option("--master_port",
               type=int,
               default=29500,
@@ -45,7 +45,7 @@ from colossalai.context import Config
               help="(optional) pass launcher specific arguments as a single quoted argument.")
 @click.argument("user_script", type=str)
 @click.argument('user_args', nargs=-1)
-def run(host: str, hostfile: str, num_nodes: int, nprocs_per_node: int, include: str, exclude: str, master_addr: str,
+def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str,
         master_port: int, launcher: str, launcher_args: str, user_script: str, user_args: str):
     """
     To launch multiple processes on a single node or multiple nodes via command line.
@@ -66,5 +66,4 @@ def run(host: str, hostfile: str, num_nodes: int, nprocs_per_node: int, include:
     args_dict = locals()
     args = Config(args_dict)
     args.user_args = list(args.user_args)
-    # (lsg) TODO: fix this function
-    # launch_multi_processes(args)
+    launch_multi_processes(args)

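For context, this change makes the previously dead code path live: the click options are collected into a Config and handed to launch_multi_processes. A minimal sketch of the equivalent programmatic call (the module path and all argument values are hypothetical, chosen to mirror the options above):

from colossalai.context import Config
from colossalai.cli.launcher.run import launch_multi_processes  # assumed module path

# hypothetical values mirroring the click options above
args = Config(
    dict(host=None, hostfile='./hostfile', num_nodes=-1, nproc_per_node=8,
         include=None, exclude=None, master_addr='localhost', master_port=29500,
         launcher='torch', launcher_args=None,
         user_script='train.py', user_args=['--config', 'config.py']))
launch_multi_processes(args)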
View File

@@ -1,16 +1,14 @@
 import os
 import sys
 import shutil
-import subprocess
-import warnings
 from shlex import quote
 from abc import ABC, abstractmethod
 from colossalai.logging import get_dist_logger

-logger = get_dist_logger()

 class MultiNodeRunner(ABC):

     def __init__(self, args):
         self.args = args
         self.user_arguments = self.args.user_args
@@ -35,6 +33,7 @@ class MultiNodeRunner(ABC):
 class PDSHRunner(MultiNodeRunner):

     def __init__(self, args):
         super().__init__(args)
@@ -46,15 +45,13 @@ class PDSHRunner(MultiNodeRunner):
         return "pdsh"

     def parse_user_args(self):
-        return list(
-            map(lambda x: x if x.startswith("-") else f"'{x}'",
-                self.args.user_args))
+        return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args))

     def get_cmd(self, environment, active_devices, args):
         environment['PDSH_RCMD_TYPE'] = 'ssh'

         active_workers = ",".join(active_devices.keys())
-        logger.info("Running on the following workers: %s" % active_workers)
+        print("Running on the following workers: %s" % active_workers)
         pdsh_cmd_args = ['pdsh', '-f', str(1024), '-w', active_workers]
@@ -65,82 +62,8 @@ class PDSHRunner(MultiNodeRunner):
         # https://linux.die.net/man/1/pdsh
         # %n will be replaced by pdsh command
         colossal_launch = [
-            exports,
-            f"cd {os.path.abspath('.')};",
-            sys.executable, "-u", "-m",
-            "torch.distributed.launch",
-            f"--nproc_per_node={args.num_gpus}",
-            f"--master_addr={args.master_addr}",
+            exports, f"cd {os.path.abspath('.')};", sys.executable, "-u", "-m", "torch.distributed.launch",
+            f"--nproc_per_node={args.nproc_per_node}", f"--master_addr={args.master_addr}",
             f"--master_port={args.master_port}"
         ]
         return pdsh_cmd_args + colossal_launch + [self.user_script] + self.user_arguments
-
-class OpenMPIRunner(MultiNodeRunner):
-
-    def __init__(self, args, device_pool):
-        super().__init__(args)
-        self.device_pool = device_pool
-
-    def backend_exists(self):
-        return shutil.which('ompi_info')
-
-    @property
-    def name(self):
-        return "openmpi"
-
-    def get_cmd(self, environment, active_devices):
-        total_process_count = sum(self.device_pool.values())
-        mpirun_cmd = [
-            'mpirun',
-            '-n',
-            f'{total_process_count}',
-            '-hostfile',
-            f'{self.args.hostfile}'
-        ]
-        export_cmd = []
-        for k, v in self.exports.items():
-            export_cmd += ['-x', f'{k}={quote(v)}']
-        python_exec = []
-        python_exec = [sys.executable, "-u", "-m"]
-        return mpirun_cmd + export_cmd + python_exec + [self.user_script
-                                                        ] + self.user_arguments
-
-
-class SLURMRunner(MultiNodeRunner):
-
-    def __init__(self, args):
-        super().__init__(args)
-
-    def backend_exists(self):
-        return shutil.which('slurm_info')
-
-    @property
-    def name(self):
-        return "slurm"
-
-    def get_cmd(self, environment, active_devices, args):
-        assert "-p" in args.launcher_args
-        srun_args = args.launcher_args.strip().split()
-        assert len(srun_args) >= 2, "we need more info about which partition to use."
-        partition_name = srun_args(srun_args.index("-p") + 1)
-        slurm_cmd = [
-            'srun',
-            "-p",
-            f"{partition_name}",
-            "--nodes",
-            f"{args.num_nodes}",
-            "--tasks",
-            f"{args.num_gpus}"
-        ]
-        export_cmd = []
-        for k, v in self.exports.items():
-            export_cmd += ['-x', f'{k}={quote(v)}']
-        python_exec = []
-        python_exec = [sys.executable, "-u", "-m"]
-        return slurm_cmd + export_cmd + python_exec + [self.user_script
-                                                        ] + self.user_arguments

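To make the simplified runner concrete, here is a rough sketch of the command list PDSHRunner.get_cmd() assembles after this change (worker names, working directory, exports string, and user script are all hypothetical):

import sys

exports = 'export PDSH_RCMD_TYPE=ssh;'  # stand-in for the exports string built by the caller
cmd = (['pdsh', '-f', '1024', '-w', 'worker-0,worker-1']             # pdsh_cmd_args
       + [exports, 'cd /home/user/project;',                         # cd into os.path.abspath('.')
          sys.executable, '-u', '-m', 'torch.distributed.launch',
          '--nproc_per_node=8', '--master_addr=worker-0', '--master_port=29500']
       + ['train.py'] + ['--config', "'config.py'"])                 # user script + user args, quoted per parse_user_args()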
View File

@@ -1,22 +1,18 @@
-import argparse
-from argparse import ArgumentParser, REMAINDER
+import click
 import subprocess
 import collections
 import sys
 import os
 import torch
-from colossalai.logging import get_dist_logger
 from colossalai.context import Config
-from .multinode_runner import PDSHRunner, OpenMPIRunner, SLURMRunner
+from .multinode_runner import PDSHRunner
 from copy import deepcopy


 def fetch_hostfile(hostfile_path):
-    logger = get_dist_logger()
     if not os.path.isfile(hostfile_path):
-        logger.warning("Unable to find hostfile, will proceed with training "
-                       "with local resources only")
-        return None
+        click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
+        exit()

     # e.g., worker-0:16
     with open(hostfile_path, 'r') as fd:
@@ -30,11 +26,13 @@ def fetch_hostfile(hostfile_path):
                 hostname, slot_count = line.split(":")
                 slot_count = int(slot_count)
             except ValueError as err:
-                logger.error("Hostfile is not formatted correctly, unable to "
-                             "proceed with training.")
-                raise err
-            device_pool[hostname] = slot_count
+                click.echo(f"Error: Hostfile is not formatted correctly, expected <hostname>:<slot>, but found {line}")
+                exit()
+
+            if hostname in device_pool:
+                click.echo(f"Error: found duplicate host {hostname} in the hostfile")
+                exit()
+
+            device_pool[hostname] = slot_count
     return device_pool
@@ -48,7 +46,7 @@ def _stable_remove_duplicates(data):
     return new_list


-def parse_device_filter(host_info, include_str="", exclude_str=""):
+def parse_device_filter(host_info, include_str=None, exclude_str=None):
     '''Parse an inclusion or exclusion string and filter a hostfile dictionary.

     Examples:
@@ -58,26 +56,25 @@ def parse_device_filter(host_info, include_str="", exclude_str=""):
         slot 0 on worker-1.
     '''
-    logger = get_dist_logger()

     # Constants that define our syntax
     NODE_SEP = '@'
     SLOT_LIST_START = ':'
     SLOT_SEP = ','

     # Ensure include/exclude are mutually exclusive
-    if (include_str != "") and (exclude_str != ""):
-        raise ValueError('include_str and exclude_str are mutually exclusive.')
+    if include_str and exclude_str:
+        click.echo("--include and --exclude are mutually exclusive, only one can be used")
+        exit()

     # no-op
-    if (include_str == "") and (exclude_str == ""):
+    if include_str is None and exclude_str is None:
         return host_info

     # Either build from scratch or remove items
     filtered_hosts = dict()
     if include_str:
         parse_str = include_str
-    if exclude_str != "":
+    elif exclude_str:
         filtered_hosts = deepcopy(host_info)
         parse_str = exclude_str
@@ -90,17 +87,18 @@ def parse_device_filter(host_info, include_str="", exclude_str=""):
             # sanity checks
             if hostname not in host_info:
-                raise ValueError(f"Hostname '{hostname}' not found in hostfile")
+                click.echo(f"Hostname '{hostname}' not found in hostfile")
+                exit()
             for slot in slots:
                 if slot not in host_info[hostname]:
-                    raise ValueError(f"No slot '{slot}' specified on host '{hostname}'")
+                    click.echo(f"No slot '{slot}' specified on host '{hostname}'")

             # If include string, build the list from here
             if include_str:
                 filtered_hosts[hostname] = slots
             elif exclude_str:
                 for slot in slots:
-                    logger.info(f'removing {slot} from {hostname}')
+                    click.echo(f'- removing {slot} from {hostname}')
                     filtered_hosts[hostname].remove(slot)

         # User just specified the whole node
@@ -108,7 +106,8 @@ def parse_device_filter(host_info, include_str="", exclude_str=""):
             hostname = node_config
             # sanity check hostname
             if hostname not in host_info:
-                raise ValueError(f"Hostname '{hostname}' not found in hostfile")
+                click.echo(f"Hostname '{hostname}' not found in hostfile")
+                exit()

             if include_str:
                 filtered_hosts[hostname] = host_info[hostname]
@@ -123,6 +122,8 @@ def parse_device_filter(host_info, include_str="", exclude_str=""):
         # Remove empty hosts
         if len(filtered_hosts[hostname]) == 0:
             del_keys.append(hostname)
+
+    # remove unneeded hosts
     for name in del_keys:
         del filtered_hosts[name]
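Tying fetch_hostfile and parse_device_filter together, a minimal sketch of the inputs they expect (hostnames hypothetical), based on the worker-0:16 comment and the NODE_SEP/SLOT_LIST_START/SLOT_SEP constants above:

# hostfile: one "<hostname>:<slot count>" entry per line, duplicates rejected
#   worker-0:16
#   worker-1:16
device_pool = fetch_hostfile('./hostfile')

# '@' separates nodes, ':' starts a slot list, ',' separates slots:
# keep slots 0 and 1 on worker-0 plus every slot on worker-1
active = parse_device_filter(device_pool, include_str='worker-0:0,1@worker-1')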
@@ -145,18 +146,40 @@ def parse_inclusion_exclusion(device_pool, inclusion, exclusion):

 def launch_multi_processes(args):
-    assert isinstance(args, Config), f'expected args to be of type Config, but got {type(args)}'
-
-    # check
+    """
+    Launch multiple processes on a single node or multiple nodes.
+
+    The overall logic can be summarized as the pseudo code below:
+
+        if hostfile given:
+            hostinfo = parse_hostfile(hostfile)
+            hostinfo = include_or_exclude_hosts(hostinfo)
+            launch_on_multi_nodes(hostinfo)
+        elif hosts given:
+            hostinfo = parse_hosts(hosts)
+            launch_on_multi_nodes(hostinfo)
+        else:
+            launch_on_current_node()
+    """
+    assert isinstance(args, Config)
+
+    # cannot accept hosts and hostfile at the same time
+    if args.host and args.hostfile:
+        click.echo("Error: hostfile and hosts are mutually exclusive, only one is required")
+        exit()
+
+    # check if hostfile is given
     if args.hostfile:
         device_pool = fetch_hostfile(args.hostfile)
     else:
         device_pool = None

+    # filter and only keep the ones needed
     active_devices = None
     if device_pool:
         active_devices = parse_inclusion_exclusion(device_pool, args.include, args.exclude)
         if args.num_nodes > 0:
+            # only keep the first num_nodes to execute jobs
             updated_active_devices = collections.OrderedDict()
             for count, hostname in enumerate(active_devices.keys()):
                 if args.num_nodes == count:
@@ -164,20 +187,36 @@ def launch_multi_processes(args):
                 updated_active_devices[hostname] = active_devices[hostname]
             active_devices = updated_active_devices

-        if args.num_gpus > 0:
+        if args.nproc_per_node > 0:
+            # only keep the first nproc_per_node GPUs on each node
             updated_active_devices = collections.OrderedDict()
-            for hostname in active_devices.keys():
-                updated_active_devices[hostname] = list(range(args.num_gpus))
+            for hostname, devices in active_devices.items():
+                if len(devices) < args.nproc_per_node:
+                    click.echo(
+                        f"Error: The number of available GPUs on {hostname} is smaller than the argument nproc_per_node"
+                    )
+                    exit()
+                updated_active_devices[hostname] = devices[:args.nproc_per_node]
             active_devices = updated_active_devices

     env = os.environ.copy()

+    # use hosts if hostfile is not given
+    if args.host and active_devices is None:
+        hostinfo = collections.OrderedDict()
+        host_list = args.host.strip().split(',')
+        for hostname in host_list:
+            hostinfo[hostname] = args.nproc_per_node
+        active_devices = hostinfo
+
+    # run on the local node if neither hosts nor hostfile is given
     if not active_devices:
-        if args.num_gpus == -1 or args.num_gpus > torch.cuda.device_count():
+        if args.nproc_per_node == -1 or args.nproc_per_node > torch.cuda.device_count():
             nproc_per_node = torch.cuda.device_count()
         else:
-            nproc_per_node = args.num_gpus
+            nproc_per_node = args.nproc_per_node

-        if torch.__version__ <= "1.09":
+        if torch.__version__ <= "1.9":
             cmd = [
                 sys.executable, "-u", "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
                 f"--master_addr={args.master_addr}", f"--master_port={args.master_port}"
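One caveat on the version gate above: comparing torch.__version__ as a plain string misorders releases, e.g. "1.10" sorts before "1.9" lexicographically. A more robust check, sketched with the third-party packaging library:

import torch
from packaging import version

# Version objects compare release segments numerically, so 1.10 > 1.9 here
use_legacy_launch = version.parse(torch.__version__) <= version.parse('1.9')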
@@ -188,17 +227,7 @@ def launch_multi_processes(args):
                 f"--master_port={args.master_port}"
             ] + [args.user_script] + args.user_args
     else:
-        if args.launcher == "torch":
-            runner = PDSHRunner(args)
-        elif args.launcher == "mpi":
-            runner = OpenMPIRunner(args, device_pool)
-        elif args.launcher == "slurm":
-            runner = SLURMRunner(args, device_pool)
-        else:
-            raise NotImplementedError(f"Unknown launcher {args.launcher}")
-
-        if not runner.backend_exists():
-            raise RuntimeError(f"launcher '{args.launcher}' not installed.")
+        runner = PDSHRunner(args)

         curr_path = os.path.abspath('.')
         if 'PYTHONPATH' in env: