mirror of https://github.com/hpcaitech/ColossalAI

import os
import sys
from typing import List

import click
import torch
from packaging import version

from colossalai.context import Config

from .hostinfo import HostInfo, HostInfoList
from .multinode_runner import MultiNodeRunner

# Constants that define our syntax
NODE_SEP = ","


def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
    """
    Parse the hostfile to obtain a list of hosts.

    A hostfile should contain one hostname per line:
        worker-0
        worker-1
        worker-2
        ...

    Args:
        hostfile_path (str): the path to the hostfile
        ssh_port (int): the SSH port used to connect to each host
    """
    if not os.path.isfile(hostfile_path):
        click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
        sys.exit(1)

    with open(hostfile_path, "r") as fd:
        device_pool = HostInfoList()

        for line in fd:
            hostname = line.strip()

            # skip empty lines
            if hostname == "":
                continue

            # build the HostInfo object
            hostinfo = HostInfo(hostname=hostname, port=ssh_port)

            # a hostname must appear only once in the hostfile
            if device_pool.has(hostname):
                click.echo(f"Error: found duplicate host {hostname} in the hostfile")
                sys.exit(1)

            device_pool.append(hostinfo)
    return device_pool

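# Example (a minimal sketch; the hostfile path and hostnames below are
# hypothetical, not part of this module):
#
#     pool = fetch_hostfile("./hostfile", ssh_port=22)
#     # a hostfile listing worker-0 and worker-1 yields a HostInfoList with
#     # two HostInfo entries, each bound to SSH port 22
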
def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList:
    """Parse an inclusion or exclusion string and filter the device pool.

    Examples:
        include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1.
        exclude_str="worker-1" will use all available devices except worker-1.

    Args:
        device_pool (HostInfoList): a list of HostInfo objects
        include_str (str): --include option passed by user, default None
        exclude_str (str): --exclude option passed by user, default None

    Returns:
        filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
    """
    # --include and --exclude are mutually exclusive
    if include_str and exclude_str:
        click.echo("--include and --exclude are mutually exclusive, only one can be used")
        sys.exit(1)

    # no-op: nothing to filter
    if include_str is None and exclude_str is None:
        return device_pool

    # include: build the filtered list from scratch
    # exclude: remove items from the existing pool in place
    if include_str:
        parse_str = include_str
        filtered_hosts = HostInfoList()
    else:
        parse_str = exclude_str
        filtered_hosts = device_pool

    # process each node in the comma-separated list
    for hostname in parse_str.split(NODE_SEP):
        # sanity-check the hostname before looking it up
        if not device_pool.has(hostname):
            click.echo(f"Error: Hostname '{hostname}' not found in hostfile")
            sys.exit(1)

        if include_str:
            filtered_hosts.append(device_pool.get_hostinfo(hostname))
        else:
            filtered_hosts.remove(hostname)

    return filtered_hosts

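# Example (hostnames mirror the docstring above and are illustrative only):
#
#     # keep only worker-0 and worker-1 from the parsed hostfile
#     active = parse_device_filter(pool, include_str="worker-0,worker-1")
#
#     # or run on every host except worker-1 (note: this removes worker-1
#     # from the passed-in pool in place)
#     active = parse_device_filter(pool, exclude_str="worker-1")
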
def get_launch_command(
    master_addr: str,
    master_port: int,
    nproc_per_node: int,
    user_script: str,
    user_args: List[str],
    node_rank: int,
    num_nodes: int,
    extra_launch_args: str = None,
) -> str:
    """
    Generate a command for distributed training.

    Args:
        master_addr (str): the host of the master node
        master_port (int): the port of the master node
        nproc_per_node (int): the number of processes to launch on each node
        user_script (str): the user Python file
        user_args (List[str]): the arguments for the user script
        node_rank (int): the unique ID for the node
        num_nodes (int): the number of nodes to execute jobs
        extra_launch_args (str): extra arguments for the launcher as a
            comma-separated string, e.g. "key1=value1,key2", default None

    Returns:
        cmd (str): the command to start distributed training
    """

    def _arg_dict_to_list(arg_dict):
        ret = []

        for k, v in arg_dict.items():
            if v:
                ret.append(f"--{k}={v}")
            else:
                ret.append(f"--{k}")
        return ret

    # convert the comma-separated string into a dict,
    # e.g. "key1=value1,key2" -> {"key1": "value1", "key2": None}
    if extra_launch_args:
        extra_launch_args_dict = dict()
        for arg in extra_launch_args.split(","):
            if "=" in arg:
                k, v = arg.split("=", 1)
                extra_launch_args_dict[k] = v
            else:
                extra_launch_args_dict[arg] = None
        extra_launch_args = extra_launch_args_dict
    else:
        extra_launch_args = dict()

    torch_version = version.parse(torch.__version__)
    assert torch_version.major >= 1

    if torch_version.major == 1 and torch_version.minor < 9:
        # torch.distributed.launch is used for torch < 1.9
        cmd = [
            sys.executable,
            "-m",
            "torch.distributed.launch",
            f"--nproc_per_node={nproc_per_node}",
            f"--master_addr={master_addr}",
            f"--master_port={master_port}",
            f"--nnodes={num_nodes}",
            f"--node_rank={node_rank}",
        ]
    else:
        # rendezvous arguments for the torch >= 1.9 launchers
        default_torchrun_rdzv_args = dict(master_addr=master_addr, master_port=master_port)

        # let user-supplied values override the rendezvous defaults
        for key in default_torchrun_rdzv_args.keys():
            if key in extra_launch_args:
                value = extra_launch_args.pop(key)
                default_torchrun_rdzv_args[key] = value

        if torch_version.major == 1 and torch_version.minor == 9:
            # torch.distributed.run is used for torch == 1.9
            cmd = [
                sys.executable,
                "-m",
                "torch.distributed.run",
                f"--nproc_per_node={nproc_per_node}",
                f"--nnodes={num_nodes}",
                f"--node_rank={node_rank}",
            ]
        else:
            # torchrun is used for torch > 1.9
            cmd = [
                "torchrun",
                f"--nproc_per_node={nproc_per_node}",
                f"--nnodes={num_nodes}",
                f"--node_rank={node_rank}",
            ]
        cmd += _arg_dict_to_list(default_torchrun_rdzv_args)

    cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
    cmd = " ".join(cmd)
    return cmd

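# Example (hypothetical values, assuming a torch release newer than 1.9 is
# installed, so the torchrun branch is taken):
#
#     get_launch_command(
#         master_addr="192.168.0.1", master_port=29500, nproc_per_node=8,
#         user_script="train.py", user_args=["--epochs", "10"],
#         node_rank=0, num_nodes=2,
#     )
#
# is expected to return a single-line command string (wrapped here for
# readability):
#
#     torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0
#         --master_addr=192.168.0.1 --master_port=29500 train.py --epochs 10
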
def launch_multi_processes(args: Config) -> None:
    """
    Launch multiple processes on a single node or multiple nodes.

    The overall logic can be summarized as the pseudocode below:

        if hostfile given:
            hostinfo = parse_hostfile(hostfile)
            hostinfo = include_or_exclude_hosts(hostinfo)
            launch_on_multi_nodes(hostinfo)
        elif hosts given:
            hostinfo = parse_hosts(hosts)
            launch_on_multi_nodes(hostinfo)
        else:
            launch_on_current_node()

    Args:
        args (Config): the arguments taken from command line
    """
    assert isinstance(args, Config)

    if args.nproc_per_node is None:
        click.echo("--nproc_per_node did not receive any value")
        sys.exit(1)

    # cannot accept hosts and hostfile at the same time
    if args.host and args.hostfile:
        click.echo("Error: hostfile and hosts are mutually exclusive, only one is required")
        sys.exit(1)

    # check if hostfile is given
    if args.hostfile:
        device_pool = fetch_hostfile(args.hostfile, ssh_port=args.ssh_port)
        active_device_pool = parse_device_filter(device_pool, args.include, args.exclude)

        if args.num_nodes > 0:
            # only keep the first num_nodes hosts to execute jobs
            updated_active_device_pool = HostInfoList()
            for count, hostinfo in enumerate(active_device_pool):
                if count == args.num_nodes:
                    break
                updated_active_device_pool.append(hostinfo)
            active_device_pool = updated_active_device_pool
    else:
        active_device_pool = None

    # use hosts if hostfile is not given
    if args.host and active_device_pool is None:
        active_device_pool = HostInfoList()
        host_list = args.host.strip().split(NODE_SEP)
        for hostname in host_list:
            hostinfo = HostInfo(hostname=hostname, port=args.ssh_port)
            active_device_pool.append(hostinfo)

    if not active_device_pool:
        # run on the local node if neither hosts nor hostfile is given
        # add the local node to the host info list
        active_device_pool = HostInfoList()
        localhost_info = HostInfo(hostname="127.0.0.1", port=args.ssh_port)
        active_device_pool.append(localhost_info)

    # launch distributed processes
    runner = MultiNodeRunner()
    curr_path = os.path.abspath(".")

    # collect the current environment variables
    env = dict()
    for k, v in os.environ.items():
        # multi-line env vars are not supported
        if v and "\n" not in v:
            env[k] = v

    # establish remote connections
    runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env)

    # overwrite master addr when num_nodes > 1 and it is not specified
    if len(active_device_pool) > 1 and args.master_addr == "127.0.0.1":
        args.master_addr = active_device_pool.hostinfo_list[0].hostname

    # send the distributed launching command to every node
    for node_id, hostinfo in enumerate(active_device_pool):
        cmd = get_launch_command(
            master_addr=args.master_addr,
            master_port=args.master_port,
            nproc_per_node=args.nproc_per_node,
            user_script=args.user_script,
            user_args=args.user_args,
            node_rank=node_id,
            num_nodes=len(active_device_pool),
            extra_launch_args=args.extra_launch_args,
        )
        runner.send(hostinfo=hostinfo, cmd=cmd)

    # wait for the training results from all nodes
    msg_from_node = runner.recv_from_all()
    has_error = False

    # print node status
    click.echo("\n====== Training on All Nodes =====")
    for hostname, msg in msg_from_node.items():
        click.echo(f"{hostname}: {msg}")

        # check if a process failed
        if msg == "failure":
            has_error = True

    # stop all nodes
    runner.stop_all()

    # receive the stop status
    msg_from_node = runner.recv_from_all()

    # print node status
    click.echo("\n====== Stopping All Nodes =====")
    for hostname, msg in msg_from_node.items():
        click.echo(f"{hostname}: {msg}")

    # exit with a proper code so the launcher
    # behaves like a normal process
    if has_error:
        sys.exit(1)
    else:
        sys.exit(0)

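# Example (a minimal sketch, not a definitive invocation; it assumes that
# colossalai.context.Config can be built from a plain dict and exposes the
# keys as attributes, and that the hosts listed in ./hostfile are reachable
# over SSH):
#
#     args = Config(
#         dict(
#             hostfile="./hostfile", host=None, include=None, exclude=None,
#             num_nodes=-1, nproc_per_node=8, master_addr="127.0.0.1",
#             master_port=29500, ssh_port=22, extra_launch_args=None,
#             user_script="train.py", user_args=[],
#         )
#     )
#     launch_multi_processes(args)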