2022-04-19 02:59:44 +00:00
|
|
|
import os
|
2023-01-11 07:17:17 +00:00
|
|
|
import sys
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
import click
|
2022-04-19 02:59:44 +00:00
|
|
|
import torch
|
2023-01-11 07:17:17 +00:00
|
|
|
from packaging import version
|
|
|
|
|
2022-04-19 07:14:54 +00:00
|
|
|
from colossalai.context import Config
|
2023-01-11 07:17:17 +00:00
|
|
|
|
2022-04-24 05:26:26 +00:00
|
|
|
from .hostinfo import HostInfo, HostInfoList
|
2023-01-11 07:17:17 +00:00
|
|
|
from .multinode_runner import MultiNodeRunner
|
2022-04-19 02:59:44 +00:00
|
|
|
|
2022-04-24 05:26:26 +00:00
|
|
|
# Constants that define our syntax
# NODE_SEP separates host names in the --host, --include and --exclude
# option strings, e.g. "worker-0,worker-1".
NODE_SEP = ','
|
|
|
|
|
|
|
|
|
|
|
|
def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
    """
    Parse the hostfile to obtain a list of hosts.

    A hostfile lists one hostname per line; surrounding whitespace and blank
    lines are ignored. It should look like:

        worker-0
        worker-1
        worker-2
        ...

    Args:
        hostfile_path (str): the path to the hostfile
        ssh_port (int): the port used to connect to every host in the file

    Returns:
        HostInfoList: one HostInfo per hostname, in file order.

    Exits the process with status 1 (after printing an error to stderr) if the
    file does not exist or a hostname appears more than once.
    """
    if not os.path.isfile(hostfile_path):
        click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}", err=True)
        # sys.exit(1) instead of the site builtin exit(), which would report
        # success (status 0) and is missing when Python runs with -S
        sys.exit(1)

    device_pool = HostInfoList()
    with open(hostfile_path, 'r') as fd:
        # iterate the file lazily rather than materialising it with readlines()
        for line in fd:
            hostname = line.strip()
            if not hostname:
                # skip empty lines
                continue

            if device_pool.has(hostname):
                click.echo(f"Error: found duplicate host {hostname} in the hostfile", err=True)
                sys.exit(1)

            device_pool.append(HostInfo(hostname=hostname, port=ssh_port))
    return device_pool
|
2022-04-19 02:59:44 +00:00
|
|
|
|
2022-04-19 07:14:54 +00:00
|
|
|
|
2022-04-24 05:26:26 +00:00
|
|
|
def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList:
    '''Parse an inclusion or exclusion string and filter a list of hosts.

    Examples:
        include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1.
        exclude_str="worker-1" will use all available devices except worker-1.

    Host names are separated by ``NODE_SEP``. Whitespace around each name and
    empty entries (e.g. from a trailing comma) are ignored. A name listed more
    than once is applied only once.

    Args:
        device_pool (HostInfoList): a list of HostInfo objects
        include_str (str): --include option passed by user, default None
        exclude_str (str): --exclude option passed by user, default None

    Returns:
        filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion.
            ``device_pool`` itself is returned when neither filter is given.
            Otherwise a new list is built and ``device_pool`` is not modified.

    Exits the process with status 1 (after printing an error to stderr) if both
    filters are given or a listed host is not present in ``device_pool``.
    '''
    # Ensure include/exclude are mutually exclusive
    if include_str and exclude_str:
        click.echo("--include and --exclude are mutually exclusive, only one can be used", err=True)
        sys.exit(1)

    # no-op; truthiness (not `is None`) so that an empty string does not fall
    # through to the parsing below with nothing to parse
    if not include_str and not exclude_str:
        return device_pool

    parse_str = include_str or exclude_str

    # validate every requested hostname before building the result, so the
    # friendly error is shown instead of a lookup failure
    requested = []
    for name in parse_str.split(NODE_SEP):
        hostname = name.strip()
        if not hostname or hostname in requested:
            continue
        if not device_pool.has(hostname):
            click.echo(f"Error: Hostname '{hostname}' not found in hostfile", err=True)
            sys.exit(1)
        requested.append(hostname)

    filtered_hosts = HostInfoList()
    if include_str:
        # keep the order in which the user listed the hosts
        for hostname in requested:
            filtered_hosts.append(device_pool.get_hostinfo(hostname))
    else:
        # build a fresh list rather than removing entries from the caller's pool
        excluded = set(requested)
        for hostinfo in device_pool:
            if hostinfo.hostname not in excluded:
                filtered_hosts.append(hostinfo)
    return filtered_hosts
|
|
|
|
|
|
|
|
|
|
|
|
def get_launch_command(
    master_addr: str,
    master_port: int,
    nproc_per_node: int,
    user_script: str,
    user_args: List[str],
    node_rank: int,
    num_nodes: int,
    extra_launch_args: str = None,
) -> str:
    """
    Generate a command for distributed training.

    Args:
        master_addr (str): the host of the master node
        master_port (int): the port of the master node
        nproc_per_node (int): the number of processes to launch on each node
        user_script (str): the user Python file
        user_args (List[str]): the arguments for the user script
        node_rank (int): the unique ID for the node
        num_nodes (int): the number of nodes to execute jobs
        extra_launch_args (str, optional): comma-separated extra arguments for
            the torch launcher, e.g. ``"max_restarts=3,standalone"``.
            ``key=value`` items become ``--key=value`` and bare ``key`` items
            become ``--key``. With torch >= 1.9, ``master_addr`` /
            ``master_port`` given here override the function arguments.

    Returns:
        cmd (str): the command to start distributed training
    """

    def _arg_dict_to_list(arg_dict):
        # a None value marks a bare flag (``--key``); anything else is ``--key=value``
        return [f'--{k}' if v is None else f'--{k}={v}' for k, v in arg_dict.items()]

    extra_args = {}
    if extra_launch_args:
        for item in extra_launch_args.split(','):
            item = item.strip()
            if not item:
                # tolerate stray or trailing commas
                continue
            # split on the first '=' only, so values may themselves contain '='
            key, sep, value = item.partition('=')
            extra_args[key] = value if sep else None

    torch_version = version.parse(torch.__version__)
    assert torch_version.major >= 1
    # compare (major, minor) together: checking only `minor` would send
    # torch 2.x (minor 0-8) down the legacy torch.distributed.launch path
    major_minor = (torch_version.major, torch_version.minor)

    if major_minor < (1, 9):
        cmd = [
            sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
            f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}",
            f"--node_rank={node_rank}"
        ]
    else:
        # extra launch args for torch distributed launcher with torch >= 1.9
        rdzv_args = dict(master_addr=master_addr, master_port=master_port)

        # user-supplied rdzv arguments take precedence over the defaults
        for key in rdzv_args:
            if key in extra_args:
                rdzv_args[key] = extra_args.pop(key)

        if major_minor < (1, 10):
            cmd = [
                sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}",
                f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
            ]
        else:
            cmd = ["torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}"]
        cmd += _arg_dict_to_list(rdzv_args)

    # list(...) so a tuple of user args (e.g. from click nargs=-1) also works
    cmd += _arg_dict_to_list(extra_args) + [user_script] + list(user_args)
    # NOTE(review): items are joined without shell quoting; arguments that
    # contain spaces or shell metacharacters are not protected here
    return ' '.join(cmd)
|
2022-04-19 02:59:44 +00:00
|
|
|
|
2022-04-19 07:14:54 +00:00
|
|
|
|
2022-04-24 05:26:26 +00:00
|
|
|
def _resolve_device_pool(args: Config) -> HostInfoList:
    """Build the hosts to launch on from --hostfile, --host, or the local node."""
    active_device_pool = None

    # check if hostfile is given
    if args.hostfile:
        device_pool = fetch_hostfile(args.hostfile, ssh_port=args.ssh_port)
        active_device_pool = parse_device_filter(device_pool, args.include, args.exclude)

        if args.num_nodes > 0:
            # only keep the first num_nodes to execute jobs
            trimmed_pool = HostInfoList()
            for count, hostinfo in enumerate(active_device_pool):
                if count == args.num_nodes:
                    break
                trimmed_pool.append(hostinfo)
            active_device_pool = trimmed_pool

    # use hosts if hostfile is not given
    elif args.host:
        active_device_pool = HostInfoList()
        for name in args.host.split(NODE_SEP):
            hostname = name.strip()
            if hostname:
                active_device_pool.append(HostInfo(hostname=hostname, port=args.ssh_port))

    if not active_device_pool:
        # run on local node if no hosts or hostfile is given; this also
        # applies when the hostfile/filter yielded an empty host list
        active_device_pool = HostInfoList()
        active_device_pool.append(HostInfo(hostname='127.0.0.1', port=args.ssh_port))
    return active_device_pool


def _print_node_status(header: str, msg_from_node) -> None:
    """Echo a section header followed by one ``hostname: message`` line per node."""
    click.echo(header)
    for hostname, msg in msg_from_node.items():
        click.echo(f"{hostname}: {msg}")


def launch_multi_processes(args: Config) -> None:
    """
    Launch multiple processes on a single node or multiple nodes.

    The overall logic can be summarized as the pseudo code below:

        if hostfile given:
            hostinfo = parse_hostfile(hostfile)
            hostinfo = include_or_exclude_hosts(hostinfo)
            launch_on_multi_nodes(hostinfo)
        elif hosts given:
            hostinfo = parse_hosts(hosts)
            launch_on_multi_nodes(hostinfo)
        else:
            launch_on_current_node()

    Args:
        args (Config): the arguments taken from command line

    This function always terminates the process. The exit status is 1 if an
    argument is invalid or any node reports "failure" during training, and 0
    otherwise.
    """
    assert isinstance(args, Config)

    if args.nproc_per_node is None:
        click.echo("--nproc_per_node did not receive any value", err=True)
        sys.exit(1)

    # cannot accept hosts and hostfile at the same time
    if args.host and args.hostfile:
        click.echo("Error: hostfile and hosts are mutually exclusive, only one is required", err=True)
        sys.exit(1)

    active_device_pool = _resolve_device_pool(args)

    # launch distributed processes
    runner = MultiNodeRunner()
    curr_path = os.path.abspath('.')

    # collect current env; multi-line env vars are not supported
    env = {k: v for k, v in os.environ.items() if v and '\n' not in v}

    # establish remote connection
    runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env)

    # overwrite master addr when num_nodes > 1 and not specified
    if len(active_device_pool) > 1 and args.master_addr == "127.0.0.1":
        args.master_addr = active_device_pool.hostinfo_list[0].hostname

    # execute distributed launching command
    num_nodes = len(active_device_pool)
    for node_id, hostinfo in enumerate(active_device_pool):
        cmd = get_launch_command(master_addr=args.master_addr,
                                 master_port=args.master_port,
                                 nproc_per_node=args.nproc_per_node,
                                 user_script=args.user_script,
                                 user_args=args.user_args,
                                 node_rank=node_id,
                                 num_nodes=num_nodes,
                                 extra_launch_args=args.extra_launch_args)
        runner.send(hostinfo=hostinfo, cmd=cmd)

    # start training
    msg_from_node = runner.recv_from_all()
    _print_node_status("\n====== Training on All Nodes =====", msg_from_node)
    # check if a process failed
    has_error = any(msg == "failure" for _, msg in msg_from_node.items())

    # stop all nodes and report their stop status
    runner.stop_all()
    msg_from_node = runner.recv_from_all()
    _print_node_status("\n====== Stopping All Nodes =====", msg_from_node)

    # give the process an exit code so that it behaves like a normal process
    sys.exit(1 if has_error else 0)
|