mirror of https://github.com/hpcaitech/ColossalAI
[cli] fixed single-node process launching
parent 61c20b44bc
commit d522cb704e
@@ -25,7 +25,7 @@ from colossalai.context import Config
     "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include."
 )
 @click.option("--num_nodes", type=int, default=-1, help="Total number of worker nodes to use.")
-@click.option("--nprocs_per_node", type=int, default=-1, help="Number of GPUs to use on each node.")
+@click.option("--nproc_per_node", type=int, default=-1, help="Number of GPUs to use on each node.")
 @click.option("--master_port",
               type=int,
               default=29500,
@@ -45,7 +45,7 @@ from colossalai.context import Config
               help="(optional) pass launcher specific arguments as a single quoted argument.")
 @click.argument("user_script", type=str)
 @click.argument('user_args', nargs=-1)
-def run(host: str, hostfile: str, num_nodes: int, nprocs_per_node: int, include: str, exclude: str, master_addr: str,
+def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str,
         master_port: int, launcher: str, launcher_args: str, user_script: str, user_args: str):
     """
     To launch multiple processes on a single node or multiple nodes via command line.
@@ -66,5 +66,4 @@ def run(host: str, hostfile: str, num_nodes: int, nprocs_per_node: int, include:
     args_dict = locals()
     args = Config(args_dict)
     args.user_args = list(args.user_args)
-    # (lsg) TODO: fix this function
-    # launch_multi_processes(args)
+    launch_multi_processes(args)
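Note: with the option renamed to --nproc_per_node and the call to launch_multi_processes re-enabled, a single-node launch would be invoked along the lines of `colossalai run --nproc_per_node 4 --master_port 29500 train.py`. This is illustrative only: the `colossalai run` entry-point name, the script name and the values are assumptions, not part of this diff.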
@@ -1,16 +1,14 @@
 import os
 import sys
 import shutil
 import subprocess
 import warnings
 from shlex import quote
 from abc import ABC, abstractmethod

-from colossalai.logging import get_dist_logger
-
-logger = get_dist_logger()

 class MultiNodeRunner(ABC):
+
     def __init__(self, args):
         self.args = args
         self.user_arguments = self.args.user_args
@@ -35,6 +33,7 @@ class MultiNodeRunner(ABC):


 class PDSHRunner(MultiNodeRunner):
+
     def __init__(self, args):
         super().__init__(args)

@@ -46,15 +45,13 @@ class PDSHRunner(MultiNodeRunner):
         return "pdsh"

     def parse_user_args(self):
-        return list(
-            map(lambda x: x if x.startswith("-") else f"'{x}'",
-                self.args.user_args))
+        return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args))

     def get_cmd(self, environment, active_devices, args):
         environment['PDSH_RCMD_TYPE'] = 'ssh'

         active_workers = ",".join(active_devices.keys())
-        logger.info("Running on the following workers: %s" % active_workers)
+        print("Running on the following workers: %s" % active_workers)

         pdsh_cmd_args = ['pdsh', '-f', str(1024), '-w', active_workers]

@@ -65,82 +62,8 @@ class PDSHRunner(MultiNodeRunner):
         # https://linux.die.net/man/1/pdsh
         # %n will be replaced by pdsh command
         colossal_launch = [
-            exports,
-            f"cd {os.path.abspath('.')};",
-            sys.executable, "-u", "-m",
-            "torch.distributed.launch",
-            f"--nproc_per_node={args.num_gpus}",
-            f"--master_addr={args.master_addr}",
+            exports, f"cd {os.path.abspath('.')};", sys.executable, "-u", "-m", "torch.distributed.launch",
+            f"--nproc_per_node={args.nproc_per_node}", f"--master_addr={args.master_addr}",
             f"--master_port={args.master_port}"
         ]
         return pdsh_cmd_args + colossal_launch + [self.user_script] + self.user_arguments
-
-
-class OpenMPIRunner(MultiNodeRunner):
-    def __init__(self, args, device_pool):
-        super().__init__(args)
-        self.device_pool = device_pool
-
-    def backend_exists(self):
-        return shutil.which('ompi_info')
-
-    @property
-    def name(self):
-        return "openmpi"
-
-    def get_cmd(self, environment, active_devices):
-        total_process_count = sum(self.device_pool.values())
-
-        mpirun_cmd = [
-            'mpirun',
-            '-n',
-            f'{total_process_count}',
-            '-hostfile',
-            f'{self.args.hostfile}'
-        ]
-
-        export_cmd = []
-        for k, v in self.exports.items():
-            export_cmd += ['-x', f'{k}={quote(v)}']
-
-        python_exec = []
-        python_exec = [sys.executable, "-u", "-m"]
-
-        return mpirun_cmd + export_cmd + python_exec + [self.user_script
-                                                        ] + self.user_arguments
-
-
-class SLURMRunner(MultiNodeRunner):
-    def __init__(self, args):
-        super().__init__(args)
-
-    def backend_exists(self):
-        return shutil.which('slurm_info')
-
-    @property
-    def name(self):
-        return "slurm"
-
-    def get_cmd(self, environment, active_devices, args):
-
-        assert "-p" in args.launcher_args
-        srun_args = args.launcher_args.strip().split()
-        assert len(srun_args) >= 2, "we need more info about which partition to use."
-        partition_name = srun_args(srun_args.index("-p")+1)
-        slurm_cmd = [
-            'srun',
-            "-p",
-            f"{partition_name}",
-            "--nodes",
-            f"{args.num_nodes}",
-            "--tasks",
-            f"{args.num_gpus}"
-        ]
-
-        export_cmd = []
-        for k, v in self.exports.items():
-            export_cmd += ['-x', f'{k}={quote(v)}']
-
-        python_exec = []
-        python_exec = [sys.executable, "-u", "-m"]
-
-        return slurm_cmd + export_cmd + python_exec + [self.user_script
-                                                       ] + self.user_arguments
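Note: for orientation, the list that PDSHRunner.get_cmd builds has roughly the shape sketched below. This is a hedged illustration only; the worker names, GPU count, paths and user arguments are placeholders, and the exports entry stands for whatever environment-export prefix the runner assembles elsewhere.

# Approximate shape of the command PDSHRunner.get_cmd returns: pdsh fans the
# torch.distributed.launch invocation out to the active workers over ssh.
example_cmd = [
    'pdsh', '-f', '1024', '-w', 'worker-0,worker-1',             # pdsh_cmd_args
    'export MASTER_ADDR=worker-0;',                              # exports (assumed format)
    'cd /home/user/project;',                                    # run from the current working directory
    '/usr/bin/python', '-u', '-m', 'torch.distributed.launch',
    '--nproc_per_node=8', '--master_addr=worker-0', '--master_port=29500',
    'train.py', '--config', 'config.py',                         # user script and its arguments
]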
@@ -1,22 +1,18 @@
-import argparse
-from argparse import ArgumentParser, REMAINDER
+import click
 import subprocess
 import collections
 import sys
 import os
 import torch
-from colossalai.logging import get_dist_logger
 from colossalai.context import Config
-from .multinode_runner import PDSHRunner, OpenMPIRunner, SLURMRunner
+from .multinode_runner import PDSHRunner
 from copy import deepcopy


 def fetch_hostfile(hostfile_path):
-    logger = get_dist_logger()
     if not os.path.isfile(hostfile_path):
-        logger.warning("Unable to find hostfile, will proceed with training "
-                       "with local resources only")
-        return None
+        click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
+        exit()

     # e.g., worker-0:16
     with open(hostfile_path, 'r') as fd:
@@ -30,11 +26,13 @@ def fetch_hostfile(hostfile_path):
             hostname, slot_count = line.split(":")
             slot_count = int(slot_count)
         except ValueError as err:
-            logger.error("Hostfile is not formatted correctly, unable to "
-                         "proceed with training.")
-            raise err
-        device_pool[hostname] = slot_count
+            click.echo(f"Error: Hostfile is not formatted correctly, expected <hostname>:<slot>, but found {line}")
+            exit()
+
+        if hostname in device_pool:
+            click.echo(f"Error: found duplicate host {hostname} in the hostfile")
+            exit()
+        device_pool[hostname] = slot_count
     return device_pool
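Note: as the "# e.g., worker-0:16" comment indicates, the hostfile lists one worker per line in <hostname>:<slot count> form. A hedged illustration, with made-up names and counts:

# A two-node hostfile would contain lines such as:
#   worker-0:16
#   worker-1:16
# for which fetch_hostfile builds a mapping roughly like:
device_pool = {"worker-0": 16, "worker-1": 16}   # hostname -> slot count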
@@ -48,7 +46,7 @@ def _stable_remove_duplicates(data):
     return new_list


-def parse_device_filter(host_info, include_str="", exclude_str=""):
+def parse_device_filter(host_info, include_str=None, exclude_str=None):
     '''Parse an inclusion or exclusion string and filter a hostfile dictionary.

     Examples:
@@ -58,26 +56,25 @@ def parse_device_filter(host_info, include_str="", exclude_str=""):
         slot 0 on worker-1.
     '''

-    logger = get_dist_logger()
-
     # Constants that define our syntax
     NODE_SEP = '@'
     SLOT_LIST_START = ':'
     SLOT_SEP = ','

     # Ensure include/exclude are mutually exclusive
-    if (include_str != "") and (exclude_str != ""):
-        raise ValueError('include_str and exclude_str are mutually exclusive.')
+    if include_str and exclude_str:
+        click.echo("--include and --exclude are mutually exclusive, only one can be used")
+        exit()

     # no-op
-    if (include_str == "") and (exclude_str == ""):
+    if include_str is None and exclude_str is None:
         return host_info

     # Either build from scratch or remove items
     filtered_hosts = dict()
     if include_str:
         parse_str = include_str
-    if exclude_str != "":
+    elif exclude_str:
         filtered_hosts = deepcopy(host_info)
         parse_str = exclude_str
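Note: a hedged illustration of the filter syntax defined by NODE_SEP, SLOT_LIST_START and SLOT_SEP above; the host names and slot ids are hypothetical and only mirror the behaviour the docstring describes.

# Hosts are separated by '@', a slot list starts after ':', slots are separated by ','.
host_info = {"worker-0": [0, 1, 2, 3], "worker-1": [0, 1, 2, 3]}

# include only slots 0 and 2 of worker-0:
#   parse_device_filter(host_info, include_str="worker-0:0,2")
#   -> {"worker-0": [0, 2]}
# exclude slot 0 on worker-1 and keep everything else:
#   parse_device_filter(host_info, exclude_str="worker-1:0")
#   -> {"worker-0": [0, 1, 2, 3], "worker-1": [1, 2, 3]}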
@@ -90,17 +87,18 @@ def parse_device_filter(host_info, include_str="", exclude_str=""):

             # sanity checks
             if hostname not in host_info:
-                raise ValueError(f"Hostname '{hostname}' not found in hostfile")
+                click.echo(f"Hostname '{hostname}' not found in hostfile")
+                exit()
             for slot in slots:
                 if slot not in host_info[hostname]:
-                    raise ValueError(f"No slot '{slot}' specified on host '{hostname}'")
+                    click.echo(f"No slot '{slot}' specified on host '{hostname}'")

             # If include string, build the list from here
             if include_str:
                 filtered_hosts[hostname] = slots
             elif exclude_str:
                 for slot in slots:
-                    logger.info(f'removing {slot} from {hostname}')
+                    click.echo(f'- removing {slot} from {hostname}')
                     filtered_hosts[hostname].remove(slot)

         # User just specified the whole node
@@ -108,7 +106,8 @@ def parse_device_filter(host_info, include_str="", exclude_str=""):
             hostname = node_config
             # sanity check hostname
             if hostname not in host_info:
-                raise ValueError(f"Hostname '{hostname}' not found in hostfile")
+                click.echo(f"Hostname '{hostname}' not found in hostfile")
+                exit()

             if include_str:
                 filtered_hosts[hostname] = host_info[hostname]
@@ -123,6 +122,8 @@ def parse_device_filter(host_info, include_str="", exclude_str=""):
         # Remove empty hosts
         if len(filtered_hosts[hostname]) == 0:
             del_keys.append(hostname)
+
+    # remove unneeded hosts
     for name in del_keys:
         del filtered_hosts[name]
@@ -145,18 +146,40 @@ def parse_inclusion_exclusion(device_pool, inclusion, exclusion):


 def launch_multi_processes(args):
-    assert isinstance(args, Config), f'expected args to be of type Config, but got {type(args)}'
+    """
+    Launch multiple processes on a single node or multiple nodes.
+
+    The overall logic can be summarized as the pseudo code below:
+
+        if hostfile given:
+            hostinfo = parse_hostfile(hostfile)
+            hostinfo = include_or_exclude_hosts(hostinfo)
+            launch_on_multi_nodes(hostinfo)
+        elif hosts given:
+            hostinfo = parse_hosts(hosts)
+            launch_on_multi_nodes(hostinfo)
+        else:
+            launch_on_current_node()
+    """
+    assert isinstance(args, Config)

-    # check
+    # cannot accept hosts and hostfile at the same time
+    if args.host and args.hostfile:
+        click.echo("Error: hostfile and hosts are mutually exclusive, only one is required")
+
+    # check if hostfile is given
     if args.hostfile:
         device_pool = fetch_hostfile(args.hostfile)
     else:
         device_pool = None

     # filter and only keep the ones needed
     active_devices = None
     if device_pool:
         active_devices = parse_inclusion_exclusion(device_pool, args.include, args.exclude)

     if args.num_nodes > 0:
         # only keep the first num_nodes to execute jobs
         updated_active_devices = collections.OrderedDict()
         for count, hostname in enumerate(active_devices.keys()):
             if args.num_nodes == count:
@@ -164,20 +187,36 @@ def launch_multi_processes(args):
                 updated_active_devices[hostname] = active_devices[hostname]
         active_devices = updated_active_devices

-    if args.num_gpus > 0:
+    if args.nproc_per_node > 0:
         # only keep the first
         updated_active_devices = collections.OrderedDict()
-        for hostname in active_devices.keys():
-            updated_active_devices[hostname] = list(range(args.num_gpus))
+        for hostname, active_devices in active_devices.items():
+            if len(active_devices) < args.nproc_per_node:
+                click.echo(
+                    f"Error: The number of available GPUs on {hostname} is smaller than the argument nproc_per_node"
+                )
+                exit()
+            updated_active_devices[hostname] = active_devices[args.nproc_per_node]
         active_devices = updated_active_devices

     env = os.environ.copy()

+    # use hosts if hostfile is not given
+    if args.host and active_devices is None:
+        hostinfo = collections.OrderedDict()
+        host_list = args.host.strip().split(',')
+        for hostname in host_list:
+            hostinfo[hostname] = args.nproc_per_node
+        active_devices = hostinfo
+
     # run on local node if not hosts or hostfile is given
     if not active_devices:
-        if args.num_gpus == -1 or args.num_gpus > torch.cuda.device_count():
+        if args.nproc_per_node == -1 or args.nproc_per_node > torch.cuda.device_count():
             nproc_per_node = torch.cuda.device_count()
         else:
-            nproc_per_node = args.num_gpus
-        if torch.__version__ <= "1.09":
+            nproc_per_node = args.nproc_per_node
+
+        if torch.__version__ <= "1.9":
             cmd = [
+                sys.executable, "-u", "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
+                f"--master_addr={args.master_addr}", f"--master_port={args.master_port}"
@@ -188,17 +227,7 @@ def launch_multi_processes(args):
                 f"--master_port={args.master_port}"
             ] + [args.user_script] + args.user_args
     else:
-        if args.launcher == "torch":
-            runner = PDSHRunner(args)
-        elif args.launcher == "mpi":
-            runner = OpenMPIRunner(args, device_pool)
-        elif args.launcher == "slurm":
-            runner = SLURMRunner(args, device_pool)
-        else:
-            raise NotImplementedError(f"Unknown launcher {args.launcher}")
-
-        if not runner.backend_exists():
-            raise RuntimeError(f"launcher '{args.launcher}' not installed.")
+        runner = PDSHRunner(args)

         curr_path = os.path.abspath('.')
         if 'PYTHONPATH' in env:
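Note: on the single-node path this commit fixes, the assembled command reduces to roughly the sketch below. This is a hedged example; the interpreter path, GPU count, addresses and user script are placeholders, and the tail of the function that actually executes the command lies outside the hunks shown here.

# With 4 visible GPUs and an older torch, launch_multi_processes would build
# something equivalent to:
cmd = [
    "/usr/bin/python", "-u", "-m", "torch.distributed.launch",
    "--nproc_per_node=4", "--master_addr=127.0.0.1", "--master_port=29500",
    "train.py", "--config", "config.py",   # args.user_script followed by args.user_args
]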