[CLI] refactored the launch CLI and fixed bugs in multi-node launching (#844)

* [cli] fixed multi-node job launching

* [cli] fixed a bug in version comparison

* [cli] support launching with env var

* added docstring

* [cli] added extra launch arguments

* [cli] added default launch rdzv args

* [cli] fixed version comparison

* [cli] added docstring examples and requirement

* polish docstring

* polish code

Frank Lee 2022-04-24 13:26:26 +08:00 committed by GitHub
parent e5ea3fdeef
commit cf6d1c9284
5 changed files with 468 additions and 240 deletions

View File

@@ -5,27 +5,34 @@ from colossalai.context import Config
 @click.command(help="Launch distributed training on a single node or multiple nodes",
                context_settings=dict(ignore_unknown_options=True))
-@click.option("-H", "-host", "--host", type=str, default=None, help="the list of machines to launch")
-@click.option("--hostfile",
+@click.option("-H",
+              "-host",
+              "--host",
               type=str,
               default=None,
-              help="Hostfile path that defines the device pool available to the job (e.g. worker-name:number of slots)")
+              help="the list of hostnames to launch in the format <host1>,<host2>")
 @click.option(
-    "--include",
+    "--hostfile",
     type=str,
     default=None,
-    help=
-    "Specify computing devices to use during execution. String format is NODE_SPEC@NODE_SPEC where NODE_SPEC=<worker-name>:<list-of-slots>"
-)
+    help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname")
+@click.option("--include",
+              type=str,
+              default=None,
+              help="Specify computing devices to use during execution. String format is <host1>,<host2>,"
+              " only effective when used with --hostfile.")
 @click.option(
     "--exclude",
     type=str,
     default=None,
     help=
-    "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include."
-)
-@click.option("--num_nodes", type=int, default=-1, help="Total number of worker nodes to use.")
-@click.option("--nproc_per_node", type=int, default=-1, help="Number of GPUs to use on each node.")
+    "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include,"
+    " only effective when used with --hostfile.")
+@click.option("--num_nodes",
+              type=int,
+              default=-1,
+              help="Total number of worker nodes to use, only effective when used with --hostfile.")
+@click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.")
 @click.option("--master_port",
               type=int,
               default=29500,
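
Editor's note: given the --hostfile semantics above, a valid hostfile is simply one hostname per line (the worker names below are placeholders):

    worker-0
    worker-1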
@@ -35,34 +42,43 @@ from colossalai.context import Config
     default="127.0.0.1",
     help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.")
 @click.option(
-    "--launcher",
-    type=click.Choice(['torch', 'openmpi', 'slurm'], case_sensitive=False),
-    default="torch",
-    help="(optional) choose launcher backend for multi-node training. Options currently include PDSH, OpenMPI, SLURM.")
-@click.option("--launcher_args",
-              type=str,
-              default=None,
-              help="(optional) pass launcher specific arguments as a single quoted argument.")
+    "--extra_launch_args",
+    type=str,
+    default=None,
+    help=
+    "Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
+    "This will be converted to --arg1=1 --arg2=2 during execution")
+@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
 @click.argument("user_script", type=str)
 @click.argument('user_args', nargs=-1)
 def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str,
-        master_port: int, launcher: str, launcher_args: str, user_script: str, user_args: str):
+        master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None:
     """
     To launch multiple processes on a single node or multiple nodes via command line.

     Usage::
-        # run on the current node with all available GPUs
-        colossalai run train.py
+        # run with 4 GPUs on the current node with the default port 29500
+        colossalai run --nproc_per_node 4 train.py

-        # run with only 2 GPUs on the current node
-        colossalai run --nprocs_per_node 2 train.py
+        # run with 4 GPUs on the current node at port 29550
+        colossalai run --nproc_per_node 4 --master_port 29550 train.py

         # run on two nodes
-        colossalai run --host <host1>,<host2> train.py
+        colossalai run --host <host1>,<host2> --master_addr host1 --nproc_per_node 4 train.py

         # run with hostfile
-        colossalai run --hostfile <file_path> train.py
+        colossalai run --hostfile <file_path> --master_addr <host> --nproc_per_node 4 train.py
+
+        # run with hostfile with only the included hosts
+        colossalai run --hostfile <file_path> --master_addr host1 --include host1,host2 --nproc_per_node 4 train.py
+
+        # run with hostfile excluding the selected hosts
+        colossalai run --hostfile <file_path> --master_addr host1 --exclude host2 --nproc_per_node 4 train.py
     """
+    if not user_script.endswith('.py'):
+        click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help')
+        exit()
+
     args_dict = locals()
     args = Config(args_dict)
     args.user_args = list(args.user_args)
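
Editor's note: as a concrete illustration of the --extra_launch_args semantics described above (the flags shown are only an assumed example; any torch launcher flag can be passed through):

    colossalai run --nproc_per_node 4 --extra_launch_args rdzv_backend=c10d,standalone train.py

is expanded to --rdzv_backend=c10d --standalone on the generated torch launcher command (see get_launch_command below).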

View File

@@ -0,0 +1,122 @@
from typing import List
import socket


class HostInfo:
    """
    A data class to store host connection-related data.

    Args:
        hostname (str): name or IP address of the host
        port (str): the port for ssh connection
    """

    def __init__(
        self,
        hostname: str,
        port: str = None,
    ):
        self.hostname = hostname
        self.port = port
        self.is_local_host = HostInfo.is_host_localhost(hostname, port)

    @staticmethod
    def is_host_localhost(hostname: str, port: str = None) -> bool:
        """
        Check if the host refers to the local machine.

        Args:
            hostname (str): name or IP address of the host
            port (str): the port for ssh connection

        Returns:
            bool: True if it is local, False otherwise
        """
        if port is None:
            port = 22    # no port specified, fall back to the default ssh port

        hostname = socket.getfqdn(hostname)
        if hostname in ("localhost", "127.0.0.1", "0.0.0.0"):
            return True

        localhost = socket.gethostname()
        localaddrs = socket.getaddrinfo(localhost, port)
        targetaddrs = socket.getaddrinfo(hostname, port)

        for (family, socktype, proto, canonname, sockaddr) in localaddrs:
            for (rfamily, rsocktype, rproto, rcanonname, rsockaddr) in targetaddrs:
                if rsockaddr[0] == sockaddr[0]:
                    return True
        return False

    def __str__(self):
        return f'hostname: {self.hostname}, port: {self.port}'

    def __repr__(self):
        return self.__str__()


class HostInfoList:
    """
    A data class to store a list of HostInfo objects.
    """

    def __init__(self):
        self.hostinfo_list = []

    def append(self, hostinfo: HostInfo) -> None:
        """
        Add a HostInfo object to the list.

        Args:
            hostinfo (HostInfo): host information
        """
        self.hostinfo_list.append(hostinfo)

    def remove(self, hostname: str) -> None:
        """
        Remove the HostInfo object which matches the hostname from the list.

        Args:
            hostname (str): the name of the host
        """
        hostinfo = self.get_hostinfo(hostname)
        self.hostinfo_list.remove(hostinfo)

    def get_hostinfo(self, hostname: str) -> HostInfo:
        """
        Return the HostInfo object which matches the hostname.

        Args:
            hostname (str): the name of the host

        Returns:
            hostinfo (HostInfo): the HostInfo object which matches the hostname
        """
        for hostinfo in self.hostinfo_list:
            if hostinfo.hostname == hostname:
                return hostinfo
        raise Exception(f"Hostname {hostname} is not found")

    def has(self, hostname: str) -> bool:
        """
        Check if the hostname has been added.

        Args:
            hostname (str): the name of the host

        Returns:
            bool: True if added, False otherwise
        """
        for hostinfo in self.hostinfo_list:
            if hostinfo.hostname == hostname:
                return True
        return False

    def __iter__(self):
        return iter(self.hostinfo_list)

    def __len__(self):
        return len(self.hostinfo_list)
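
Editor's note: a minimal usage sketch of the two classes above; the import path is assumed from the relative imports used elsewhere in this commit, and localhost addresses are used so the snippet resolves without a cluster:

    from colossalai.cli.launcher.hostinfo import HostInfo, HostInfoList    # assumed path

    pool = HostInfoList()
    for name in ('localhost', '127.0.0.1'):
        # port=None makes is_host_localhost fall back to the default ssh port 22
        pool.append(HostInfo(hostname=name, port=None))

    assert len(pool) == 2 and pool.has('localhost')
    print(pool.get_hostinfo('localhost'))    # hostname: localhost, port: None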

View File

@@ -1,69 +1,120 @@
-import os
-import sys
-import shutil
-from shlex import quote
-from abc import ABC, abstractmethod
+import fabric
+from fabric import Connection
+from .hostinfo import HostInfo, HostInfoList
+from multiprocessing import Pipe, Process
+from multiprocessing import connection as mp_connection
 import click
-from colossalai.logging import get_dist_logger


-class MultiNodeRunner(ABC):
-
-    def __init__(self, args):
-        self.args = args
-        self.user_arguments = self.args.user_args
-        self.user_script = args.user_script
-        self.exports = {}
-
-    @abstractmethod
-    def backend_exists(self):
-        """Return whether the corresponding backend exists"""
-
-    @abstractmethod
-    def get_cmd(self, environment, active_devices):
-        """Return the command to execute on node"""
-
-    def add_export(self, key, var):
-        self.exports[key.strip()] = var.strip()
-
-    @property
-    def name(self):
-        """Return the name of the backend"""
-        return self.__class__.__name__
-
-
-class PDSHRunner(MultiNodeRunner):
-
-    def __init__(self, args):
-        super().__init__(args)
-
-    def backend_exists(self):
-        return shutil.which('pdsh')
-
-    @property
-    def name(self):
-        return "pdsh"
-
-    def parse_user_args(self):
-        return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args))
-
-    def get_cmd(self, environment, active_devices, args):
-        environment['PDSH_RCMD_TYPE'] = 'ssh'
-
-        active_workers = ",".join(active_devices.keys())
-        print("Running on the following workers: %s" % active_workers)
-
-        pdsh_cmd_args = ['pdsh', '-f', str(1024), '-w', active_workers]
-
-        exports = ""
-        for key, val in self.exports.items():
-            exports += f"export {key}={quote(val)}; "
-
-        # https://linux.die.net/man/1/pdsh
-        # %n will be replaced by pdsh command
-        colossal_launch = [
-            exports, f"cd {os.path.abspath('.')};", sys.executable, "-u", "-m", "torch.distributed.launch",
-            f"--nproc_per_node={args.nproc_per_node}", f"--master_addr={args.master_addr}",
-            f"--master_port={args.master_port}"
-        ]
-        return pdsh_cmd_args + colossal_launch + [self.user_script] + self.user_arguments
+def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection,
+                send_conn: mp_connection.Connection, env: dict) -> None:
+    """
+    Use a fabric connection to execute commands on a local or remote host.
+
+    Args:
+        hostinfo (HostInfo): host information
+        workdir (str): the directory to execute the command
+        recv_conn (multiprocessing.connection.Connection): receive messages from the master sender
+        send_conn (multiprocessing.connection.Connection): send messages to the master receiver
+        env (dict): a dictionary for environment variables
+    """
+    fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port)
+    finish = False
+    env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()])
+
+    # keep listening until exit
+    while not finish:
+        # receive cmd
+        cmds = recv_conn.recv()
+
+        if cmds == 'exit':
+            # exit from the loop
+            finish = True
+            break
+        else:
+            # execute the commands
+            try:
+                # cd to the working directory
+                with fab_conn.cd(workdir):
+                    # propagate the runtime environment
+                    with fab_conn.prefix(f"export {env_msg}"):
+                        if hostinfo.is_local_host:
+                            # execute on the local machine
+                            fab_conn.local(cmds, hide=False)
+                        else:
+                            # execute on the remote machine
+                            fab_conn.run(cmds, hide=False)
+                    send_conn.send('success')
+            except Exception:
+                click.echo(f"Error: failed to run {cmds} on {hostinfo.hostname}")
+                send_conn.send('failure')
+
+    # shutdown
+    send_conn.send("finish")
+    fab_conn.close()
+
+
+class MultiNodeRunner:
+    """
+    A runner to execute commands on an array of machines. This runner
+    is inspired by Nezha (https://github.com/zhuzilin/NeZha).
+    """
+
+    def __init__(self):
+        self.processes = {}
+        self.master_send_conns = {}
+        self.master_recv_conns = {}
+
+    def connect(self, host_info_list: HostInfoList, workdir: str, env: dict) -> None:
+        """
+        Establish connections to a list of hosts.
+
+        Args:
+            host_info_list (HostInfoList): a list of HostInfo objects
+            workdir (str): the directory where the command is executed
+            env (dict): environment variables to propagate to hosts
+        """
+        for hostinfo in host_info_list:
+            master_send_conn, worker_recv_conn = Pipe()
+            master_recv_conn, worker_send_conn = Pipe()
+            p = Process(target=run_on_host, args=(hostinfo, workdir, worker_recv_conn, worker_send_conn, env))
+            p.start()
+            self.processes[hostinfo.hostname] = p
+            self.master_recv_conns[hostinfo.hostname] = master_recv_conn
+            self.master_send_conns[hostinfo.hostname] = master_send_conn
+
+    def send(self, hostinfo: HostInfo, cmd: str) -> None:
+        """
+        Send a command to a local/remote host.
+
+        Args:
+            hostinfo (HostInfo): host information
+            cmd (str): the command to execute
+        """
+        assert hostinfo.hostname in self.master_send_conns, \
+            f'{hostinfo} is not found in the current connections'
+        conn = self.master_send_conns[hostinfo.hostname]
+        conn.send(cmd)
+
+    def stop_all(self) -> None:
+        """
+        Stop connections to all hosts.
+        """
+        for hostname, conn in self.master_send_conns.items():
+            conn.send('exit')
+
+    def recv_from_all(self) -> dict:
+        """
+        Receive messages from all hosts.
+
+        Returns:
+            msg_from_node (dict): a dictionary which contains messages from each node
+        """
+        msg_from_node = dict()
+        for hostname, conn in self.master_recv_conns.items():
+            msg_from_node[hostname] = conn.recv()
+        return msg_from_node
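
Editor's note: a sketch of how the master process is expected to drive this runner end to end; localhost is used so the snippet runs without a cluster, and the import paths are assumed from this commit's layout:

    import os
    from colossalai.cli.launcher.hostinfo import HostInfo, HostInfoList    # assumed path
    from colossalai.cli.launcher.multinode_runner import MultiNodeRunner    # assumed path

    hosts = HostInfoList()
    hosts.append(HostInfo(hostname='127.0.0.1', port=None))

    runner = MultiNodeRunner()
    # spawn one worker process per host and open the message pipes
    runner.connect(host_info_list=hosts, workdir=os.path.abspath('.'), env={'DUMMY_VAR': '1'})
    runner.send(hostinfo=hosts.get_hostinfo('127.0.0.1'), cmd='echo hello')
    print(runner.recv_from_all())    # {'127.0.0.1': 'success'} if the command ran
    runner.stop_all()                # ask every worker loop to exit
    runner.recv_from_all()           # drain the final 'finish' messages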

View File

@@ -1,65 +1,72 @@
 import click
-import subprocess
-import collections
 import sys
 import os
 import torch
 from colossalai.context import Config
-from .multinode_runner import PDSHRunner
-from copy import deepcopy
+from .multinode_runner import MultiNodeRunner
+from .hostinfo import HostInfo, HostInfoList
+from typing import List
+from packaging import version
+
+# Constants that define our syntax
+NODE_SEP = ','


-def fetch_hostfile(hostfile_path):
+def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
+    """
+    Parse the hostfile to obtain a list of hosts.
+
+    A hostfile should look like:
+    worker-0
+    worker-1
+    worker-2
+    ...
+
+    Args:
+        hostfile_path (str): the path to the hostfile
+        ssh_port (int): the port to connect to the host
+    """
     if not os.path.isfile(hostfile_path):
         click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
         exit()

-    # e.g., worker-0:16
     with open(hostfile_path, 'r') as fd:
-        device_pool = collections.OrderedDict()
+        device_pool = HostInfoList()

         for line in fd.readlines():
             line = line.strip()
             if line == '':
                 # skip empty lines
                 continue
-            try:
-                hostname, slot_count = line.split(":")
-                slot_count = int(slot_count)
-            except ValueError as err:
-                click.echo(f"Error: Hostfile is not formatted correctly, expected <hostname>:<slot>, but found {line}")
-                exit()
-            if hostname in device_pool:
+
+            # build the HostInfo object
+            hostname = line.strip()
+            hostinfo = HostInfo(hostname=hostname, port=ssh_port)
+
+            if device_pool.has(hostname):
                 click.echo(f"Error: found duplicate host {hostname} in the hostfile")
                 exit()
-            device_pool[hostname] = slot_count
+
+            device_pool.append(hostinfo)
     return device_pool


-def _stable_remove_duplicates(data):
-    # Create a new list in the same order as original but with duplicates
-    # removed, should never be more than ~16 elements so simple is best
-    new_list = []
-    for x in data:
-        if x not in new_list:
-            new_list.append(x)
-    return new_list
-
-
-def parse_device_filter(host_info, include_str=None, exclude_str=None):
+def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList:
     '''Parse an inclusion or exclusion string and filter a hostfile dictionary.

     Examples:
-        include_str="worker-0@worker-1:0,2" will use all slots on worker-0 and
-        slots [0, 2] on worker-1.
-        exclude_str="worker-1:0" will use all available devices except
-        slot 0 on worker-1.
-    '''
-
-    # Constants that define our syntax
-    NODE_SEP = '@'
-    SLOT_LIST_START = ':'
-    SLOT_SEP = ','
+        include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1.
+        exclude_str="worker-1" will use all available devices except worker-1.
+
+    Args:
+        device_pool (HostInfoList): a list of HostInfo objects
+        include_str (str): --include option passed by user, default None
+        exclude_str (str): --exclude option passed by user, default None
+
+    Returns:
+        filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
+    '''

     # Ensure include/exclude are mutually exclusive
     if include_str and exclude_str:
@@ -68,176 +75,207 @@ def parse_device_filter(host_info, include_str=None, exclude_str=None):
     # no-op
     if include_str is None and exclude_str is None:
-        return host_info
+        return device_pool

     # Either build from scratch or remove items
-    filtered_hosts = dict()
     if include_str:
         parse_str = include_str
+        filtered_hosts = HostInfoList()
     elif exclude_str:
-        filtered_hosts = deepcopy(host_info)
         parse_str = exclude_str
+        filtered_hosts = device_pool

     # foreach node in the list
     for node_config in parse_str.split(NODE_SEP):
-        # Node can either be alone or node:slot,slot,slot
-        if SLOT_LIST_START in node_config:
-            hostname, slots = node_config.split(SLOT_LIST_START)
-            slots = [int(x) for x in slots.split(SLOT_SEP)]
-
-            # sanity checks
-            if hostname not in host_info:
-                click.echo(f"Hostname '{hostname}' not found in hostfile")
-                exit()
-            for slot in slots:
-                if slot not in host_info[hostname]:
-                    click.echo(f"No slot '{slot}' specified on host '{hostname}'")
-
-            # If include string, build the list from here
-            if include_str:
-                filtered_hosts[hostname] = slots
-            elif exclude_str:
-                for slot in slots:
-                    click.echo(f'- removing {slot} from {hostname}')
-                    filtered_hosts[hostname].remove(slot)
-
-        # User just specified the whole node
-        else:
-            hostname = node_config
-            # sanity check hostname
-            if hostname not in host_info:
-                click.echo(f"Hostname '{hostname}' not found in hostfile")
-                exit()
-
-            if include_str:
-                filtered_hosts[hostname] = host_info[hostname]
-            elif exclude_str:
-                filtered_hosts[hostname] = []
-
-    # Post-processing to remove duplicates and empty nodes
-    del_keys = []
-    for hostname in filtered_hosts:
-        # Remove duplicates
-        filtered_hosts[hostname] = _stable_remove_duplicates(filtered_hosts[hostname])
-        # Remove empty hosts
-        if len(filtered_hosts[hostname]) == 0:
-            del_keys.append(hostname)
-
-    # remove unneeded hosts
-    for name in del_keys:
-        del filtered_hosts[name]
-
-    # Lastly, go over filtered_hosts and convert to a OrderedDict() to ensure
-    # we map ranks to nodes correctly by maintaining host_info ordering.
-    ordered_hosts = collections.OrderedDict()
-    for host in host_info:
-        if host in filtered_hosts:
-            ordered_hosts[host] = filtered_hosts[host]
-
-    return ordered_hosts
-
-
-def parse_inclusion_exclusion(device_pool, inclusion, exclusion):
-    active_devices = collections.OrderedDict()
-    for hostname, slots in device_pool.items():
-        active_devices[hostname] = list(range(slots))
-
-    return parse_device_filter(active_devices, include_str=inclusion, exclude_str=exclusion)
-
-
-def launch_multi_processes(args):
+        hostname = node_config
+        hostinfo = device_pool.get_hostinfo(hostname)
+
+        # sanity check hostname
+        if not device_pool.has(hostname):
+            click.echo(f"Error: Hostname '{hostname}' not found in hostfile")
+            exit()
+
+        if include_str:
+            filtered_hosts.append(hostinfo)
+        elif exclude_str:
+            filtered_hosts.remove(hostname)
+
+    return filtered_hosts
+
+
+def get_launch_command(
+    master_addr: str,
+    master_port: int,
+    nproc_per_node: int,
+    user_script: str,
+    user_args: List[str],
+    node_rank: int,
+    num_nodes: int,
+    extra_launch_args: str = None,
+) -> str:
+    """
+    Generate a command for distributed training.
+
+    Args:
+        master_addr (str): the host of the master node
+        master_port (str): the port of the master node
+        nproc_per_node (str): the number of processes to launch on each node
+        user_script (str): the user Python file
+        user_args (str): the arguments for the user script
+        node_rank (int): the unique ID for the node
+        num_nodes (int): the number of nodes to execute jobs
+
+    Returns:
+        cmd (str): the command to start distributed training
+    """
+
+    def _arg_dict_to_list(arg_dict):
+        ret = []
+
+        for k, v in arg_dict.items():
+            if v:
+                ret.append(f'--{k}={v}')
+            else:
+                ret.append(f'--{k}')
+        return ret
+
+    if extra_launch_args:
+        extra_launch_args_dict = dict()
+        for arg in extra_launch_args.split(','):
+            if '=' in arg:
+                k, v = arg.split('=')
+                extra_launch_args_dict[k] = v
+            else:
+                extra_launch_args_dict[arg] = None
+        extra_launch_args = extra_launch_args_dict
+    else:
+        extra_launch_args = dict()
+
+    torch_version = version.parse(torch.__version__)
+    assert torch_version.major == 1
+
+    if torch_version.minor < 9:
+        cmd = [
+            sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
+            f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}",
+            f"--node_rank={node_rank}"
+        ]
+    else:
+        # extra launch args for torch distributed launcher with torch >= 1.9
+        default_torchrun_rdzv_args = dict(rdzv_backend="c10d",
+                                          rdzv_endpoint=f"{master_addr}:{master_port}",
+                                          rdzv_id="colossalai-default-job")
+
+        # update rdzv arguments
+        for key in default_torchrun_rdzv_args.keys():
+            if key in extra_launch_args:
+                value = extra_launch_args.pop(key)
+                default_torchrun_rdzv_args[key] = value
+
+        if torch_version.minor < 10:
+            cmd = [
+                sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}",
+                f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
+            ]
+        else:
+            cmd = [
+                "torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
+            ]
+        cmd += _arg_dict_to_list(default_torchrun_rdzv_args)
+
+    cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
+    cmd = ' '.join(cmd)
+    return cmd
+
+
+def launch_multi_processes(args: Config) -> None:
     """
     Launch multiple processes on a single node or multiple nodes.

     The overall logic can be summarized as the pseudo code below:

         if hostfile given:
             hostinfo = parse_hostfile(hostfile)
             hostinfo = include_or_exclude_hosts(hostinfo)
             launch_on_multi_nodes(hostinfo)
         elif hosts given:
             hostinfo = parse_hosts(hosts)
             launch_on_multi_nodes(hostinfo)
         else:
             launch_on_current_node()
+
+    Args:
+        args (Config): the arguments taken from command line
     """
     assert isinstance(args, Config)

+    if args.nproc_per_node is None:
+        click.echo("--nproc_per_node did not receive any value")
+        exit()
+
     # cannot accept hosts and hostfile at the same time
     if args.host and args.hostfile:
         click.echo("Error: hostfile and hosts are mutually exclusive, only one is required")

     # check if hostfile is given
     if args.hostfile:
-        device_pool = fetch_hostfile(args.hostfile)
-    else:
-        device_pool = None
-
-    # filter and only keep the ones needed
-    active_devices = None
-    if device_pool:
-        active_devices = parse_inclusion_exclusion(device_pool, args.include, args.exclude)
+        device_pool = fetch_hostfile(args.hostfile, ssh_port=args.ssh_port)
+        active_device_pool = parse_device_filter(device_pool, args.include, args.exclude)
+
         if args.num_nodes > 0:
             # only keep the first num_nodes to execute jobs
-            updated_active_devices = collections.OrderedDict()
-            for count, hostname in enumerate(active_devices.keys()):
+            updated_active_device_pool = HostInfoList()
+            for count, hostinfo in enumerate(active_device_pool):
                 if args.num_nodes == count:
                     break
-                updated_active_devices[hostname] = active_devices[hostname]
-            active_devices = updated_active_devices
-
-        if args.nproc_per_node > 0:
-            # only keep the first
-            updated_active_devices = collections.OrderedDict()
-            for hostname, active_devices in active_devices.items():
-                if len(active_devices) < args.nproc_per_node:
-                    click.echo(
-                        f"Error: The number of available GPUs on {hostname} is smaller than the argument nproc_per_node"
-                    )
-                    exit()
-                updated_active_devices[hostname] = active_devices[args.nproc_per_node]
-            active_devices = updated_active_devices
+                updated_active_device_pool.append(hostinfo)
+            active_device_pool = updated_active_device_pool
+    else:
+        active_device_pool = None

     env = os.environ.copy()

     # use hosts if hostfile is not given
-    if args.host and active_devices is None:
-        hostinfo = collections.OrderedDict()
-        host_list = args.host.strip().split(',')
+    if args.host and active_device_pool is None:
+        active_device_pool = HostInfoList()
+        host_list = args.host.strip().split(NODE_SEP)
         for hostname in host_list:
-            hostinfo[hostname] = args.nproc_per_node
-        active_devices = hostinfo
-
-    # run on local node if not hosts or hostfile is given
-    if not active_devices:
-        if args.nproc_per_node == -1 or args.nproc_per_node > torch.cuda.device_count():
-            nproc_per_node = torch.cuda.device_count()
-        else:
-            nproc_per_node = args.nproc_per_node
-
-        if torch.__version__ <= "1.9":
-            cmd = [
-                sys.executable, "-u", "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
-                f"--master_addr={args.master_addr}", f"--master_port={args.master_port}"
-            ] + [args.user_script] + args.user_args
-        else:
-            cmd = [
-                "torchrun", f"--nproc_per_node={nproc_per_node}", f"--master_addr={args.master_addr}",
-                f"--master_port={args.master_port}"
-            ] + [args.user_script] + args.user_args
-    else:
-        runner = PDSHRunner(args)
-
-        curr_path = os.path.abspath('.')
-        if 'PYTHONPATH' in env:
-            env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH']
-        else:
-            env['PYTHONPATH'] = curr_path
-
-        cmd = runner.get_cmd(env, active_devices, args)
-
-        result = subprocess.Popen(cmd, env=env)
-        result.wait()
-        if result.returncode > 0:
-            sys.exit(result.returncode)
+            hostinfo = HostInfo(hostname=hostname, port=args.ssh_port)
+            active_device_pool.append(hostinfo)
+
+    if not active_device_pool:
+        # run on the local node if neither hosts nor hostfile is given
+        # add local node to the host info list
+        active_device_pool = HostInfoList()
+        localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port)
+        active_device_pool.append(localhost_info)
+
+    # launch distributed processes
+    runner = MultiNodeRunner()
+    curr_path = os.path.abspath('.')
+
+    # collect current path env
+    env = dict()
+    for k, v in os.environ.items():
+        # do not support multi-line env var
+        if v and '\n' not in v:
+            env[k] = v
+
+    # establish remote connection
+    runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env)
+
+    # execute the distributed launching command on each node
+    for node_id, hostinfo in enumerate(active_device_pool):
+        cmd = get_launch_command(master_addr=args.master_addr,
+                                 master_port=args.master_port,
+                                 nproc_per_node=args.nproc_per_node,
+                                 user_script=args.user_script,
+                                 user_args=args.user_args,
+                                 node_rank=node_id,
+                                 num_nodes=len(active_device_pool),
+                                 extra_launch_args=args.extra_launch_args)
+        runner.send(hostinfo=hostinfo, cmd=cmd)
+
+    runner.recv_from_all()
+    runner.stop_all()
+    runner.recv_from_all()
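
Editor's note: for reference, a sketch of the command get_launch_command produces for node_rank 0 under torch >= 1.10, assuming master_addr=host1, master_port=29500, nproc_per_node=4, num_nodes=2 and no extra launch args:

    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --rdzv_backend=c10d --rdzv_endpoint=host1:29500 --rdzv_id=colossalai-default-job train.py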

View File

@@ -6,3 +6,4 @@ packaging
 pre-commit
 rich
 click
+fabric