mirror of https://github.com/hpcaitech/ColossalAI

[CLI] refactored the launch CLI and fixed bugs in multi-node launching (#844)

* [cli] fixed multi-node job launching
* [cli] fixed a bug in version comparison
* [cli] support launching with env var
* added docstring
* [cli] added extra launch arguments
* [cli] added default launch rdzv args
* [cli] fixed version comparison
* [cli] added docstring examples and requirement
* polish docstring
* polish code
* polish code

parent e5ea3fdeef    commit cf6d1c9284
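For a quick sense of the refactored CLI before reading the diff, a representative multi-node invocation combining the new --ssh-port and --extra_launch_args options might look like this (a hedged sketch: node01/node02 and train.py are placeholders, and max_restarts is a standard torchrun flag passed through verbatim):

    colossalai run --host node01,node02 --master_addr node01 --nproc_per_node 8 \
        --ssh-port 22 --extra_launch_args max_restarts=3 train.py --batch-size 32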
@@ -5,27 +5,34 @@ from colossalai.context import Config
 @click.command(help="Launch distributed training on a single node or multiple nodes",
                context_settings=dict(ignore_unknown_options=True))
-@click.option("-H", "-host", "--host", type=str, default=None, help="the list of machines to launch")
-@click.option("--hostfile",
+@click.option("-H",
+              "-host",
+              "--host",
               type=str,
               default=None,
-              help="Hostfile path that defines the device pool available to the job (e.g. worker-name:number of slots)")
-@click.option(
-    "--include",
-    type=str,
-    default=None,
-    help=
-    "Specify computing devices to use during execution. String format is NODE_SPEC@NODE_SPEC where NODE_SPEC=<worker-name>:<list-of-slots>"
-)
+              help="the list of hostnames to launch in the format <host1>,<host2>")
+@click.option("--hostfile",
+              type=str,
+              default=None,
+              help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname")
+@click.option("--include",
+              type=str,
+              default=None,
+              help="Specify computing devices to use during execution. String format is <host1>,<host2>,"
+              " only effective when used with --hostfile.")
 @click.option(
     "--exclude",
     type=str,
     default=None,
     help=
-    "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include."
-)
-@click.option("--num_nodes", type=int, default=-1, help="Total number of worker nodes to use.")
-@click.option("--nproc_per_node", type=int, default=-1, help="Number of GPUs to use on each node.")
+    "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include,"
+    " only effective when used with --hostfile.")
+@click.option("--num_nodes",
+              type=int,
+              default=-1,
+              help="Total number of worker nodes to use, only effective when used with --hostfile.")
+@click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.")
 @click.option("--master_port",
               type=int,
               default=29500,
@@ -35,34 +42,43 @@ from colossalai.context import Config
               default="127.0.0.1",
               help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.")
 @click.option(
-    "--launcher",
-    type=click.Choice(['torch', 'openmpi', 'slurm'], case_sensitive=False),
-    default="torch",
-    help="(optional) choose launcher backend for multi-node training. Options currently include PDSH, OpenMPI, SLURM.")
-@click.option("--launcher_args",
-              type=str,
-              default=None,
-              help="(optional) pass launcher specific arguments as a single quoted argument.")
+    "--extra_launch_args",
+    type=str,
+    default=None,
+    help=
+    "Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
+    "This will be converted to --arg1=1 --arg2=2 during execution")
+@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
 @click.argument("user_script", type=str)
 @click.argument('user_args', nargs=-1)
 def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str,
-        master_port: int, launcher: str, launcher_args: str, user_script: str, user_args: str):
+        master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None:
     """
     To launch multiple processes on a single node or multiple nodes via command line.
 
     Usage::
-        # run on the current node with all available GPUs
-        colossalai run train.py
+        # run with 4 GPUs on the current node using the default port 29500
+        colossalai run --nproc_per_node 4 train.py
 
-        # run with only 2 GPUs on the current node
-        colossalai run --nprocs_per_node 2 train.py
+        # run with 2 GPUs on the current node at port 29550
+        colossalai run --nproc_per_node 2 --master_port 29550 train.py
 
         # run on two nodes
-        colossalai run --host <host1>,<host2> train.py
+        colossalai run --host <host1>,<host2> --master_addr host1 --nproc_per_node 4 train.py
 
         # run with hostfile
-        colossalai run --hostfile <file_path> train.py
+        colossalai run --hostfile <file_path> --master_addr <host> --nproc_per_node 4 train.py
+
+        # run with hostfile with only included hosts
+        colossalai run --hostfile <file_path> --master_addr host1 --include host1,host2 --nproc_per_node 4 train.py
+
+        # run with hostfile excluding the hosts selected
+        colossalai run --hostfile <file_path> --master_addr host1 --exclude host2 --nproc_per_node 4 train.py
     """
+    if not user_script.endswith('.py'):
+        click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help')
+        exit()
+
     args_dict = locals()
     args = Config(args_dict)
     args.user_args = list(args.user_args)
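The help text above promises that --extra_launch_args arg1=1,arg2=2 is converted to --arg1=1 --arg2=2 on the launcher command line. A minimal sketch of that conversion rule (the helper name expand_extra_launch_args is hypothetical; the real parsing lives in get_launch_command further down):

    def expand_extra_launch_args(extra: str) -> str:
        # "arg1=1,arg2=2" -> "--arg1=1 --arg2=2"; bare names such as "standalone" become bare flags
        flags = []
        for item in extra.split(','):
            if '=' in item:
                k, v = item.split('=')
                flags.append(f'--{k}={v}')
            else:
                flags.append(f'--{item}')
        return ' '.join(flags)

    assert expand_extra_launch_args('arg1=1,arg2=2') == '--arg1=1 --arg2=2'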
@@ -0,0 +1,122 @@
+from typing import List
+import socket
+
+
+class HostInfo:
+    """
+    A data class to store host connection-related data.
+
+    Args:
+        hostname (str): name or IP address of the host
+        port (str): the port for ssh connection
+    """
+
+    def __init__(
+        self,
+        hostname: str,
+        port: str = None,
+    ):
+        self.hostname = hostname
+        self.port = port
+        self.is_local_host = HostInfo.is_host_localhost(hostname, port)
+
+    @staticmethod
+    def is_host_localhost(hostname: str, port: str = None) -> bool:
+        """
+        Check if the host refers to the local machine.
+
+        Args:
+            hostname (str): name or IP address of the host
+            port (str): the port for ssh connection
+
+        Returns:
+            bool: True if it is local, False otherwise
+        """
+
+        if port is None:
+            port = 22    # no port specified, let's just use the ssh port
+        hostname = socket.getfqdn(hostname)
+        if hostname in ("localhost", "127.0.0.1", "0.0.0.0"):
+            return True
+        localhost = socket.gethostname()
+        localaddrs = socket.getaddrinfo(localhost, port)
+        targetaddrs = socket.getaddrinfo(hostname, port)
+        for (family, socktype, proto, canonname, sockaddr) in localaddrs:
+            for (rfamily, rsocktype, rproto, rcanonname, rsockaddr) in targetaddrs:
+                if rsockaddr[0] == sockaddr[0]:
+                    return True
+        return False
+
+    def __str__(self):
+        return f'hostname: {self.hostname}, port: {self.port}'
+
+    def __repr__(self):
+        return self.__str__()
+
+
+class HostInfoList:
+    """
+    A data class to store a list of HostInfo objects.
+    """
+
+    def __init__(self):
+        self.hostinfo_list = []
+
+    def append(self, hostinfo: HostInfo) -> None:
+        """
+        Add a HostInfo object to the list.
+
+        Args:
+            hostinfo (HostInfo): host information
+        """
+
+        self.hostinfo_list.append(hostinfo)
+
+    def remove(self, hostname: str) -> None:
+        """
+        Remove a HostInfo object from the list by hostname.
+
+        Args:
+            hostname (str): the name of the host
+        """
+
+        hostinfo = self.get_hostinfo(hostname)
+        self.hostinfo_list.remove(hostinfo)
+
+    def get_hostinfo(self, hostname: str) -> HostInfo:
+        """
+        Return the HostInfo object which matches with the hostname.
+
+        Args:
+            hostname (str): the name of the host
+
+        Returns:
+            hostinfo (HostInfo): the HostInfo object which matches with the hostname
+        """
+
+        for hostinfo in self.hostinfo_list:
+            if hostinfo.hostname == hostname:
+                return hostinfo
+        raise Exception(f"Hostname {hostname} is not found")
+
+    def has(self, hostname: str) -> bool:
+        """
+        Check if the hostname has been added.
+
+        Args:
+            hostname (str): the name of the host
+
+        Returns:
+            bool: True if added, False otherwise
+        """
+        for hostinfo in self.hostinfo_list:
+            if hostinfo.hostname == hostname:
+                return True
+        return False
+
+    def __iter__(self):
+        return iter(self.hostinfo_list)
+
+    def __len__(self):
+        return len(self.hostinfo_list)
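A short usage sketch for the new hostinfo module (assuming it is importable as colossalai.cli.launcher.hostinfo, matching the relative imports in the next file; note that HostInfo probes locality via DNS on construction, so hostnames must be resolvable, which is why localhost is used here):

    from colossalai.cli.launcher.hostinfo import HostInfo, HostInfoList

    hosts = HostInfoList()
    hosts.append(HostInfo(hostname='localhost'))    # is_local_host is computed on construction

    assert hosts.has('localhost') and len(hosts) == 1
    print(hosts.get_hostinfo('localhost'))          # hostname: localhost, port: None

    hosts.remove('localhost')                       # remove() looks the host up by name
    assert not hosts.has('localhost')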
@@ -1,69 +1,120 @@
-import os
-import sys
-import shutil
-from shlex import quote
-from abc import ABC, abstractmethod
-from colossalai.logging import get_dist_logger
-
-
-class MultiNodeRunner(ABC):
-
-    def __init__(self, args):
-        self.args = args
-        self.user_arguments = self.args.user_args
-        self.user_script = args.user_script
-        self.exports = {}
-
-    @abstractmethod
-    def backend_exists(self):
-        """Return whether the corresponding backend exists"""
-
-    @abstractmethod
-    def get_cmd(self, environment, active_devices):
-        """Return the command to execute on node"""
-
-    def add_export(self, key, var):
-        self.exports[key.strip()] = var.strip()
-
-    @property
-    def name(self):
-        """Return the name of the backend"""
-        return self.__class__.__name__
-
-
-class PDSHRunner(MultiNodeRunner):
-
-    def __init__(self, args):
-        super().__init__(args)
-
-    def backend_exists(self):
-        return shutil.which('pdsh')
-
-    @property
-    def name(self):
-        return "pdsh"
-
-    def parse_user_args(self):
-        return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args))
-
-    def get_cmd(self, environment, active_devices, args):
-        environment['PDSH_RCMD_TYPE'] = 'ssh'
-
-        active_workers = ",".join(active_devices.keys())
-        print("Running on the following workers: %s" % active_workers)
-
-        pdsh_cmd_args = ['pdsh', '-f', str(1024), '-w', active_workers]
-
-        exports = ""
-        for key, val in self.exports.items():
-            exports += f"export {key}={quote(val)}; "
-
-        # https://linux.die.net/man/1/pdsh
-        # %n will be replaced by pdsh command
-        colossal_launch = [
-            exports, f"cd {os.path.abspath('.')};", sys.executable, "-u", "-m", "torch.distributed.launch",
-            f"--nproc_per_node={args.nproc_per_node}", f"--master_addr={args.master_addr}",
-            f"--master_port={args.master_port}"
-        ]
-        return pdsh_cmd_args + colossal_launch + [self.user_script] + self.user_arguments
+import fabric
+from fabric import Connection
+from .hostinfo import HostInfo, HostInfoList
+from multiprocessing import Pipe, Process
+from multiprocessing import connection as mp_connection
+import click
+
+
+def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection,
+                send_conn: mp_connection.Connection, env: dict) -> None:
+    """
+    Use fabric connection to execute command on local or remote hosts.
+
+    Args:
+        hostinfo (HostInfo): host information
+        workdir (str): the directory to execute the command
+        recv_conn (multiprocessing.connection.Connection): receive messages from the master sender
+        send_conn (multiprocessing.connection.Connection): send messages to the master receiver
+        env (dict): a dictionary for environment variables
+    """
+
+    fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port)
+    finish = False
+    env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()])
+
+    # keep listening until exit
+    while not finish:
+        # receive cmd
+        cmds = recv_conn.recv()
+
+        if cmds == 'exit':
+            # exit from the loop
+            finish = True
+            break
+        else:
+            # execute the commands
+            try:
+                # cd to execute directory
+                with fab_conn.cd(workdir):
+                    # propagate the runtime environment
+                    with fab_conn.prefix(f"export {env_msg}"):
+                        if hostinfo.is_local_host:
+                            # execute on the local machine
+                            fab_conn.local(cmds, hide=False)
+                        else:
+                            # execute on the remote machine
+                            fab_conn.run(cmds, hide=False)
+                    send_conn.send('success')
+            except:
+                click.echo(f"Error: failed to run {cmds} on {hostinfo.hostname}")
+                send_conn.send('failure')
+
+    # shutdown
+    send_conn.send("finish")
+    fab_conn.close()
+
+
+class MultiNodeRunner:
+    """
+    A runner to execute commands on an array of machines. This runner
+    is inspired by Nezha (https://github.com/zhuzilin/NeZha).
+    """
+
+    def __init__(self):
+        self.processes = {}
+        self.master_send_conns = {}
+        self.master_recv_conns = {}
+
+    def connect(self, host_info_list: HostInfoList, workdir: str, env: dict) -> None:
+        """
+        Establish connections to a list of hosts
+
+        Args:
+            host_info_list (HostInfoList): a list of HostInfo objects
+            workdir (str): the directory where command is executed
+            env (dict): environment variables to propagate to hosts
+        """
+        for hostinfo in host_info_list:
+            master_send_conn, worker_recv_conn = Pipe()
+            master_recv_conn, worker_send_conn = Pipe()
+            p = Process(target=run_on_host, args=(hostinfo, workdir, worker_recv_conn, worker_send_conn, env))
+            p.start()
+            self.processes[hostinfo.hostname] = p
+            self.master_recv_conns[hostinfo.hostname] = master_recv_conn
+            self.master_send_conns[hostinfo.hostname] = master_send_conn
+
+    def send(self, hostinfo: HostInfo, cmd: str) -> None:
+        """
+        Send a command to a local/remote host.
+
+        Args:
+            hostinfo (HostInfo): host information
+            cmd (str): the command to execute
+        """
+
+        assert hostinfo.hostname in self.master_send_conns, \
+            f'{hostinfo} is not found in the current connections'
+        conn = self.master_send_conns[hostinfo.hostname]
+        conn.send(cmd)
+
+    def stop_all(self) -> None:
+        """
+        Stop connections to all hosts.
+        """
+
+        for hostname, conn in self.master_send_conns.items():
+            conn.send('exit')
+
+    def recv_from_all(self) -> dict:
+        """
+        Receive messages from all hosts
+
+        Returns:
+            msg_from_node (dict): a dictionary which contains messages from each node
+        """
+
+        msg_from_node = dict()
+        for hostname, conn in self.master_recv_conns.items():
+            msg_from_node[hostname] = conn.recv()
+        return msg_from_node
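A hedged sketch of the master-side protocol this runner implements, mirroring how launch_multi_processes drives it in the next file (localhost keeps the example self-contained; fabric must be installed, and the import paths are assumed to match this PR's package layout):

    import os
    from colossalai.cli.launcher.hostinfo import HostInfo, HostInfoList
    from colossalai.cli.launcher.multinode_runner import MultiNodeRunner

    if __name__ == '__main__':    # guard required on spawn-based platforms for multiprocessing
        hosts = HostInfoList()
        hosts.append(HostInfo(hostname='localhost'))

        runner = MultiNodeRunner()
        env = {k: v for k, v in os.environ.items() if v and '\n' not in v}
        runner.connect(host_info_list=hosts, workdir=os.path.abspath('.'), env=env)

        for hostinfo in hosts:
            runner.send(hostinfo=hostinfo, cmd='echo hello')    # any shell command works

        print(runner.recv_from_all())    # {'localhost': 'success'} on success
        runner.stop_all()                # ask each worker loop to exit
        runner.recv_from_all()           # collect the final 'finish' messages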
@@ -1,65 +1,72 @@
 import click
-import subprocess
-import collections
 import sys
 import os
 import torch
 from colossalai.context import Config
-from .multinode_runner import PDSHRunner
-from copy import deepcopy
+from .multinode_runner import MultiNodeRunner
+from .hostinfo import HostInfo, HostInfoList
+from typing import List
+from packaging import version
+
+# Constants that define our syntax
+NODE_SEP = ','
 
 
-def fetch_hostfile(hostfile_path):
+def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
+    """
+    Parse the hostfile to obtain a list of hosts.
+
+    A hostfile should look like:
+    worker-0
+    worker-1
+    worker-2
+    ...
+
+    Args:
+        hostfile_path (str): the path to the hostfile
+        ssh_port (int): the port to connect to the host
+    """
+
     if not os.path.isfile(hostfile_path):
         click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
         exit()
 
-    # e.g., worker-0:16
     with open(hostfile_path, 'r') as fd:
-        device_pool = collections.OrderedDict()
+        device_pool = HostInfoList()
+
         for line in fd.readlines():
             line = line.strip()
             if line == '':
                 # skip empty lines
                 continue
-            try:
-                hostname, slot_count = line.split(":")
-                slot_count = int(slot_count)
-            except ValueError as err:
-                click.echo(f"Error: Hostfile is not formatted correctly, expected <hostname>:<slot>, but found {line}")
-                exit()
 
-            if hostname in device_pool:
+            # build the HostInfo object
+            hostname = line.strip()
+            hostinfo = HostInfo(hostname=hostname, port=ssh_port)
+
+            if device_pool.has(hostname):
                 click.echo(f"Error: found duplicate host {hostname} in the hostfile")
                 exit()
-            device_pool[hostname] = slot_count
+
+            device_pool.append(hostinfo)
     return device_pool
 
 
-def _stable_remove_duplicates(data):
-    # Create a new list in the same order as original but with duplicates
-    # removed, should never be more than ~16 elements so simple is best
-    new_list = []
-    for x in data:
-        if x not in new_list:
-            new_list.append(x)
-    return new_list
-
-
-def parse_device_filter(host_info, include_str=None, exclude_str=None):
+def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList:
     '''Parse an inclusion or exclusion string and filter a hostfile dictionary.
 
     Examples:
-        include_str="worker-0@worker-1:0,2" will use all slots on worker-0 and
-          slots [0, 2] on worker-1.
-        exclude_str="worker-1:0" will use all available devices except
-          slot 0 on worker-1.
-    '''
-
-    # Constants that define our syntax
-    NODE_SEP = '@'
-    SLOT_LIST_START = ':'
-    SLOT_SEP = ','
+        include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1.
+        exclude_str="worker-1" will use all available devices except worker-1.
+
+    Args:
+        device_pool (HostInfoList): a list of HostInfo objects
+        include_str (str): --include option passed by user, default None
+        exclude_str (str): --exclude option passed by user, default None
+
+    Returns:
+        filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
+    '''
 
     # Ensure include/exclude are mutually exclusive
     if include_str and exclude_str:
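Since fetch_hostfile above now expects one hostname per line (the old <hostname>:<slot count> syntax is gone), a minimal round-trip looks like this (a hedged sketch: the module path colossalai.cli.launcher.run is assumed, and localhost keeps it runnable since each entry is resolved via DNS when HostInfo is built):

    with open('hostfile', 'w') as f:
        f.write('localhost\n')    # real entries are ssh-reachable worker hostnames

    from colossalai.cli.launcher.run import fetch_hostfile

    pool = fetch_hostfile('hostfile', ssh_port=22)    # returns a HostInfoList
    print(len(pool), pool.has('localhost'))           # 1 True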
@@ -68,176 +75,207 @@ def parse_device_filter(host_info, include_str=None, exclude_str=None):
     # no-op
     if include_str is None and exclude_str is None:
-        return host_info
+        return device_pool
 
     # Either build from scratch or remove items
-    filtered_hosts = dict()
     if include_str:
         parse_str = include_str
+        filtered_hosts = HostInfoList()
     elif exclude_str:
-        filtered_hosts = deepcopy(host_info)
         parse_str = exclude_str
+        filtered_hosts = device_pool
 
     # foreach node in the list
     for node_config in parse_str.split(NODE_SEP):
-        # Node can either be alone or node:slot,slot,slot
-        if SLOT_LIST_START in node_config:
-            hostname, slots = node_config.split(SLOT_LIST_START)
-            slots = [int(x) for x in slots.split(SLOT_SEP)]
-
-            # sanity checks
-            if hostname not in host_info:
-                click.echo(f"Hostname '{hostname}' not found in hostfile")
-                exit()
-            for slot in slots:
-                if slot not in host_info[hostname]:
-                    click.echo(f"No slot '{slot}' specified on host '{hostname}'")
-
-            # If include string, build the list from here
-            if include_str:
-                filtered_hosts[hostname] = slots
-            elif exclude_str:
-                for slot in slots:
-                    click.echo(f'- removing {slot} from {hostname}')
-                    filtered_hosts[hostname].remove(slot)
-
-        # User just specified the whole node
-        else:
-            hostname = node_config
-            # sanity check hostname
-            if hostname not in host_info:
-                click.echo(f"Hostname '{hostname}' not found in hostfile")
-                exit()
-
-            if include_str:
-                filtered_hosts[hostname] = host_info[hostname]
-            elif exclude_str:
-                filtered_hosts[hostname] = []
-
-    # Post-processing to remove duplicates and empty nodes
-    del_keys = []
-    for hostname in filtered_hosts:
-        # Remove duplicates
-        filtered_hosts[hostname] = _stable_remove_duplicates(filtered_hosts[hostname])
-        # Remove empty hosts
-        if len(filtered_hosts[hostname]) == 0:
-            del_keys.append(hostname)
-
-    # remove unneeded hosts
-    for name in del_keys:
-        del filtered_hosts[name]
-
-    # Lastly, go over filtered_hosts and convert to a OrderedDict() to ensure
-    # we map ranks to nodes correctly by maintaining host_info ordering.
-    ordered_hosts = collections.OrderedDict()
-    for host in host_info:
-        if host in filtered_hosts:
-            ordered_hosts[host] = filtered_hosts[host]
-
-    return ordered_hosts
-
-
-def parse_inclusion_exclusion(device_pool, inclusion, exclusion):
-    active_devices = collections.OrderedDict()
-    for hostname, slots in device_pool.items():
-        active_devices[hostname] = list(range(slots))
-
-    return parse_device_filter(active_devices, include_str=inclusion, exclude_str=exclusion)
-
-
-def launch_multi_processes(args):
+        hostname = node_config
+        hostinfo = device_pool.get_hostinfo(hostname)
+        # sanity check hostname
+        if not device_pool.has(hostname):
+            click.echo(f"Error: Hostname '{hostname}' not found in hostfile")
+            exit()
+
+        if include_str:
+            filtered_hosts.append(hostinfo)
+        elif exclude_str:
+            filtered_hosts.remove(hostname)
+
+    return filtered_hosts
+
+
+def get_launch_command(
+    master_addr: str,
+    master_port: int,
+    nproc_per_node: int,
+    user_script: str,
+    user_args: List[str],
+    node_rank: int,
+    num_nodes: int,
+    extra_launch_args: str = None,
+) -> str:
+    """
+    Generate a command for distributed training.
+
+    Args:
+        master_addr (str): the host of the master node
+        master_port (str): the port of the master node
+        nproc_per_node (str): the number of processes to launch on each node
+        user_script (str): the user Python file
+        user_args (str): the arguments for the user script
+        node_rank (int): the unique ID for the node
+        num_nodes (int): the number of nodes to execute jobs
+        extra_launch_args (str): extra arguments to pass to the torch distributed launcher
+
+    Returns:
+        cmd (str): the command to start distributed training
+    """
+
+    def _arg_dict_to_list(arg_dict):
+        ret = []
+
+        for k, v in arg_dict.items():
+            if v:
+                ret.append(f'--{k}={v}')
+            else:
+                ret.append(f'--{k}')
+        return ret
+
+    if extra_launch_args:
+        extra_launch_args_dict = dict()
+        for arg in extra_launch_args.split(','):
+            if '=' in arg:
+                k, v = arg.split('=')
+                extra_launch_args_dict[k] = v
+            else:
+                extra_launch_args_dict[arg] = None
+        extra_launch_args = extra_launch_args_dict
+    else:
+        extra_launch_args = dict()
+
+    torch_version = version.parse(torch.__version__)
+    assert torch_version.major == 1
+
+    if torch_version.minor < 9:
+        cmd = [
+            sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
+            f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}",
+            f"--node_rank={node_rank}"
+        ]
+    else:
+        # extra launch args for torch distributed launcher with torch >= 1.9
+        default_torchrun_rdzv_args = dict(rdzv_backend="c10d",
+                                          rdzv_endpoint=f"{master_addr}:{master_port}",
+                                          rdzv_id="colossalai-default-job")
+
+        # update rdzv arguments
+        for key in default_torchrun_rdzv_args.keys():
+            if key in extra_launch_args:
+                value = extra_launch_args.pop(key)
+                default_torchrun_rdzv_args[key] = value
+
+        if torch_version.minor < 10:
+            cmd = [
+                sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}",
+                f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
+            ]
+        else:
+            cmd = [
+                "torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
+            ]
+        cmd += _arg_dict_to_list(default_torchrun_rdzv_args)
+
+    cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
+    cmd = ' '.join(cmd)
+    return cmd
+
+
+def launch_multi_processes(args: Config) -> None:
     """
     Launch multiple processes on a single node or multiple nodes.
 
     The overall logic can be summarized as the pseudo code below:
 
     if hostfile given:
         hostinfo = parse_hostfile(hostfile)
         hostinfo = include_or_exclude_hosts(hostinfo)
         launch_on_multi_nodes(hostinfo)
     elif hosts given:
         hostinfo = parse_hosts(hosts)
         launch_on_multi_nodes(hostinfo)
     else:
         launch_on_current_node()
+
+    Args:
+        args (Config): the arguments taken from command line
     """
     assert isinstance(args, Config)
 
+    if args.nproc_per_node is None:
+        click.echo("--nproc_per_node did not receive any value")
+        exit()
+
     # cannot accept hosts and hostfile at the same time
     if args.host and args.hostfile:
         click.echo("Error: hostfile and hosts are mutually exclusive, only one is required")
 
     # check if hostfile is given
     if args.hostfile:
-        device_pool = fetch_hostfile(args.hostfile)
-    else:
-        device_pool = None
-
-    # filter and only keep the ones needed
-    active_devices = None
-    if device_pool:
-        active_devices = parse_inclusion_exclusion(device_pool, args.include, args.exclude)
+        device_pool = fetch_hostfile(args.hostfile, ssh_port=args.ssh_port)
+        active_device_pool = parse_device_filter(device_pool, args.include, args.exclude)
+
         if args.num_nodes > 0:
             # only keep the first num_nodes to execute jobs
-            updated_active_devices = collections.OrderedDict()
-            for count, hostname in enumerate(active_devices.keys()):
+            updated_active_device_pool = HostInfoList()
+            for count, hostinfo in enumerate(active_device_pool):
                 if args.num_nodes == count:
                     break
-                updated_active_devices[hostname] = active_devices[hostname]
-            active_devices = updated_active_devices
-
-        if args.nproc_per_node > 0:
-            # only keep the first
-            updated_active_devices = collections.OrderedDict()
-            for hostname, active_devices in active_devices.items():
-                if len(active_devices) < args.nproc_per_node:
-                    click.echo(
-                        f"Error: The number of available GPUs on {hostname} is smaller than the argument nproc_per_node"
-                    )
-                    exit()
-                updated_active_devices[hostname] = active_devices[args.nproc_per_node]
-            active_devices = updated_active_devices
+                updated_active_device_pool.append(hostinfo)
+            active_device_pool = updated_active_device_pool
+    else:
+        active_device_pool = None
 
     env = os.environ.copy()
 
     # use hosts if hostfile is not given
-    if args.host and active_devices is None:
-        hostinfo = collections.OrderedDict()
-        host_list = args.host.strip().split(',')
+    if args.host and active_device_pool is None:
+        active_device_pool = HostInfoList()
+        host_list = args.host.strip().split(NODE_SEP)
         for hostname in host_list:
-            hostinfo[hostname] = args.nproc_per_node
-        active_devices = hostinfo
-
-    # run on local node if not hosts or hostfile is given
-    if not active_devices:
-        if args.nproc_per_node == -1 or args.nproc_per_node > torch.cuda.device_count():
-            nproc_per_node = torch.cuda.device_count()
-        else:
-            nproc_per_node = args.nproc_per_node
-
-        if torch.__version__ <= "1.9":
-            cmd = [
-                sys.executable, "-u", "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
-                f"--master_addr={args.master_addr}", f"--master_port={args.master_port}"
-            ] + [args.user_script] + args.user_args
-        else:
-            cmd = [
-                "torchrun", f"--nproc_per_node={nproc_per_node}", f"--master_addr={args.master_addr}",
-                f"--master_port={args.master_port}"
-            ] + [args.user_script] + args.user_args
-    else:
-        runner = PDSHRunner(args)
-
-        curr_path = os.path.abspath('.')
-        if 'PYTHONPATH' in env:
-            env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH']
-        else:
-            env['PYTHONPATH'] = curr_path
-
-        cmd = runner.get_cmd(env, active_devices, args)
-
-    result = subprocess.Popen(cmd, env=env)
-    result.wait()
-    if result.returncode > 0:
-        sys.exit(result.returncode)
+            hostinfo = HostInfo(hostname=hostname, port=args.ssh_port)
+            active_device_pool.append(hostinfo)
+
+    if not active_device_pool:
+        # run on local node if not hosts or hostfile is given
+        # add local node to host info list
+        active_device_pool = HostInfoList()
+        localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port)
+        active_device_pool.append(localhost_info)
+
+    # launch distributed processes
+    runner = MultiNodeRunner()
+    curr_path = os.path.abspath('.')
+
+    # collect current path env
+    env = dict()
+    for k, v in os.environ.items():
+        # do not support multi-line env var
+        if v and '\n' not in v:
+            env[k] = v
+
+    # establish remote connection
+    runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env)
+
+    # execute distributed launching command
+    for node_id, hostinfo in enumerate(active_device_pool):
+        cmd = get_launch_command(master_addr=args.master_addr,
+                                 master_port=args.master_port,
+                                 nproc_per_node=args.nproc_per_node,
+                                 user_script=args.user_script,
+                                 user_args=args.user_args,
+                                 node_rank=node_id,
+                                 num_nodes=len(active_device_pool),
+                                 extra_launch_args=args.extra_launch_args)
+        runner.send(hostinfo=hostinfo, cmd=cmd)
+
+    runner.recv_from_all()
+    runner.stop_all()
+    runner.recv_from_all()
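For reference, a hedged sketch of the command string get_launch_command emits on torch >= 1.10 (the module path is assumed; all values are placeholders):

    from colossalai.cli.launcher.run import get_launch_command

    cmd = get_launch_command(master_addr='worker-0',
                             master_port=29500,
                             nproc_per_node=8,
                             user_script='train.py',
                             user_args=['--config', 'config.py'],
                             node_rank=0,
                             num_nodes=2)
    print(cmd)
    # -> torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --rdzv_backend=c10d --rdzv_endpoint=worker-0:29500 --rdzv_id=colossalai-default-job train.py --config config.py

Note that master_addr/master_port no longer become --master_addr/--master_port flags on newer torch; they are folded into the c10d rendezvous endpoint, and any rdzv_* keys passed via --extra_launch_args override these defaults.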
@@ -6,3 +6,4 @@ packaging
 pre-commit
 rich
 click
+fabric