2022-04-19 02:59:44 +00:00
|
|
|
import argparse
|
|
|
|
from argparse import ArgumentParser, REMAINDER
|
|
|
|
import subprocess
|
|
|
|
import collections
|
|
|
|
import sys
|
|
|
|
import os
|
|
|
|
import torch
|
|
|
|
from colossalai.logging import get_dist_logger
|
2022-04-19 07:14:54 +00:00
|
|
|
from colossalai.context import Config
|
2022-04-19 02:59:44 +00:00
|
|
|
from .multinode_runner import PDSHRunner, OpenMPIRunner, SLURMRunner
|
2022-04-19 07:14:54 +00:00
|
|
|
from copy import deepcopy
|
2022-04-19 02:59:44 +00:00
|
|
|
|
|
|
|
|
|
|
|
def fetch_hostfile(hostfile_path):
    """Parse a hostfile into an ordered mapping of hostname -> slot count.

    Every non-empty line must have the form ``<hostname>:<slot_count>``,
    e.g. ``worker-0:16``. Whitespace around the hostname and the count is
    ignored, and lines starting with ``#`` are treated as comments.

    Args:
        hostfile_path (str): path to the hostfile.

    Returns:
        collections.OrderedDict or None: hostnames in file order mapped to
        their integer slot counts, or ``None`` when ``hostfile_path`` is not
        an existing file (a warning is logged and local resources are used).

    Raises:
        ValueError: if a line is not of the form ``<hostname>:<int>`` or the
            hostname is empty.
    """
    logger = get_dist_logger()

    if not os.path.isfile(hostfile_path):
        logger.warning("Unable to find hostfile, will proceed with training "
                       "with local resources only")
        return None

    device_pool = collections.OrderedDict()
    with open(hostfile_path, 'r') as fd:
        for line in fd:
            line = line.strip()
            # skip empty lines and comments
            if not line or line.startswith('#'):
                continue
            try:
                hostname, slot_count = line.split(":")
                # strip so "worker-0 : 8" yields the key "worker-0", which is
                # the name --include/--exclude filters will look up
                hostname = hostname.strip()
                if not hostname:
                    raise ValueError(f"Empty hostname in hostfile line '{line}'")
                slot_count = int(slot_count)
            except ValueError:
                logger.error("Hostfile is not formatted correctly, unable to "
                             "proceed with training.")
                raise
            device_pool[hostname] = slot_count

    return device_pool
|
|
|
|
|
2022-04-19 07:14:54 +00:00
|
|
|
|
2022-04-19 02:59:44 +00:00
|
|
|
def _stable_remove_duplicates(data):
|
|
|
|
# Create a new list in the same order as original but with duplicates
|
|
|
|
# removed, should never be more than ~16 elements so simple is best
|
|
|
|
new_list = []
|
|
|
|
for x in data:
|
|
|
|
if x not in new_list:
|
|
|
|
new_list.append(x)
|
|
|
|
return new_list
|
|
|
|
|
2022-04-19 07:14:54 +00:00
|
|
|
|
2022-04-19 02:59:44 +00:00
|
|
|
def parse_device_filter(host_info, include_str="", exclude_str=""):
    '''Parse an inclusion or exclusion string and filter a hostfile dictionary.

    Examples:
        include_str="worker-0@worker-1:0,2" will use all slots on worker-0 and
        slots [0, 2] on worker-1.
        exclude_str="worker-1:0" will use all available devices except
        slot 0 on worker-1.

    A host may appear more than once in ``include_str``; the selected slots
    are merged. Listing the same slot (or host) more than once in
    ``exclude_str`` is harmless.

    Args:
        host_info (dict): hostname -> list of available slot indices. It is
            not modified.
        include_str (str): hosts/slots to keep. Empty or ``None`` means unset.
        exclude_str (str): hosts/slots to drop. Empty or ``None`` means unset.

    Returns:
        ``host_info`` itself when neither filter is set; otherwise a
        ``collections.OrderedDict`` of hostname -> selected slots, ordered as
        in ``host_info``, with duplicate slots and slot-less hosts removed.

    Raises:
        ValueError: if both filters are set, a hostname is not in
            ``host_info``, a slot is not an integer, or a slot is not
            available on its host.
    '''
    logger = get_dist_logger()

    # Constants that define our syntax
    NODE_SEP = '@'
    SLOT_LIST_START = ':'
    SLOT_SEP = ','

    # Ensure include/exclude are mutually exclusive
    if include_str and exclude_str:
        raise ValueError('include_str and exclude_str are mutually exclusive.')

    # no-op
    if not include_str and not exclude_str:
        return host_info

    if include_str:
        # Build the selection from scratch
        filtered_hosts = dict()
        parse_str = include_str
    else:
        # Start from everything and remove items; deep copy so the caller's
        # slot lists are never mutated
        filtered_hosts = deepcopy(host_info)
        parse_str = exclude_str

    # foreach node in the list
    for node_config in parse_str.split(NODE_SEP):
        # Node can either be alone or node:slot,slot,slot
        if SLOT_LIST_START in node_config:
            hostname, slots = node_config.split(SLOT_LIST_START)
            slots = [int(x) for x in slots.split(SLOT_SEP)]

            # sanity checks
            if hostname not in host_info:
                raise ValueError(f"Hostname '{hostname}' not found in hostfile")
            for slot in slots:
                if slot not in host_info[hostname]:
                    raise ValueError(f"No slot '{slot}' specified on host '{hostname}'")

            if include_str:
                # merge rather than overwrite, so "h:0@h:1" selects both slots
                filtered_hosts.setdefault(hostname, []).extend(slots)
            else:
                for slot in slots:
                    logger.info(f'removing {slot} from {hostname}')
                    # the slot may already be gone (listed twice, or the whole
                    # host was excluded earlier), so filter instead of list.remove
                    filtered_hosts[hostname] = [s for s in filtered_hosts[hostname] if s != slot]

        # User just specified the whole node
        else:
            hostname = node_config
            # sanity check hostname
            if hostname not in host_info:
                raise ValueError(f"Hostname '{hostname}' not found in hostfile")

            if include_str:
                # extend a fresh list so host_info's list is never aliased
                filtered_hosts.setdefault(hostname, []).extend(host_info[hostname])
            else:
                filtered_hosts[hostname] = []

    # Rebuild in host_info order so ranks are mapped to nodes consistently,
    # dropping duplicate slots and hosts that ended up with no slots.
    ordered_hosts = collections.OrderedDict()
    for host in host_info:
        if host not in filtered_hosts:
            continue
        slots = _stable_remove_duplicates(filtered_hosts[host])
        if slots:
            ordered_hosts[host] = slots

    return ordered_hosts
|
|
|
|
|
2022-04-19 07:14:54 +00:00
|
|
|
|
2022-04-19 02:59:44 +00:00
|
|
|
def parse_inclusion_exclusion(device_pool, inclusion, exclusion):
    """Expand per-host slot counts into slot-index lists and apply the filters.

    Args:
        device_pool (Mapping[str, int]): hostname -> number of slots, as
            produced by ``fetch_hostfile``.
        inclusion (str): include filter, forwarded as ``include_str``.
        exclusion (str): exclude filter, forwarded as ``exclude_str``.

    Returns:
        The result of ``parse_device_filter`` applied to an ordered mapping
        ``{hostname: [0, 1, ..., count - 1]}``.
    """
    slot_lists = collections.OrderedDict(
        (host, list(range(count))) for host, count in device_pool.items())
    return parse_device_filter(slot_lists, include_str=inclusion, exclude_str=exclusion)
|
2022-04-19 02:59:44 +00:00
|
|
|
|
2022-04-19 07:14:54 +00:00
|
|
|
|
|
|
|
def _torch_major_minor():
    """Return ``(major, minor)`` parsed from ``torch.__version__``, or ``None`` if unparseable."""
    # e.g. "1.9.1", "1.11.0+cu113": drop any local "+..." suffix, keep the first two fields
    release = str(torch.__version__).split('+', 1)[0]
    parts = release.split('.')
    try:
        return int(parts[0]), int(parts[1])
    except (IndexError, ValueError):
        return None


def _select_active_devices(args, device_pool):
    """Apply include/exclude, ``num_nodes`` and ``num_gpus`` to ``device_pool``; ``None`` if there is no pool."""
    if not device_pool:
        return None

    active_devices = parse_inclusion_exclusion(device_pool, args.include, args.exclude)

    if args.num_nodes > 0:
        # keep only the first num_nodes hosts, preserving hostfile order
        active_devices = collections.OrderedDict(list(active_devices.items())[:args.num_nodes])

    if args.num_gpus > 0:
        # every remaining host uses slots 0..num_gpus-1
        active_devices = collections.OrderedDict(
            (hostname, list(range(args.num_gpus))) for hostname in active_devices)

    return active_devices


def _build_local_cmd(args):
    """Build the single-node launch command for ``args.user_script``."""
    device_count = torch.cuda.device_count()
    if args.num_gpus == -1 or args.num_gpus > device_count:
        nproc_per_node = device_count
    else:
        nproc_per_node = args.num_gpus

    # Compare versions numerically: the previous string comparison
    # (torch.__version__ <= "1.09") was False for "1.8.x"/"1.9.x", selecting
    # torchrun on versions that do not ship it (torchrun exists from 1.10).
    version = _torch_major_minor()
    if version is not None and version < (1, 10):
        launcher = [sys.executable, "-u", "-m", "torch.distributed.launch"]
    else:
        launcher = ["torchrun"]

    return launcher + [
        f"--nproc_per_node={nproc_per_node}",
        f"--master_addr={args.master_addr}",
        f"--master_port={args.master_port}",
    ] + [args.user_script] + args.user_args


def _build_multinode_cmd(args, device_pool, active_devices, env):
    """Pick the runner for ``args.launcher`` and build its command; prepends cwd to ``env['PYTHONPATH']``."""
    if args.launcher == "torch":
        runner = PDSHRunner(args)
    elif args.launcher == "mpi":
        runner = OpenMPIRunner(args, device_pool)
    elif args.launcher == "slurm":
        runner = SLURMRunner(args, device_pool)
    else:
        raise NotImplementedError(f"Unknown launcher {args.launcher}")

    if not runner.backend_exists():
        raise RuntimeError(f"launcher '{args.launcher}' not installed.")

    curr_path = os.path.abspath('.')
    if 'PYTHONPATH' in env:
        env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH']
    else:
        env['PYTHONPATH'] = curr_path

    return runner.get_cmd(env, active_devices, args)


def launch_multi_processes(args):
    """Launch ``args.user_script`` locally or across the hosts of a hostfile.

    When ``args.hostfile`` yields a non-empty device selection (after
    ``args.include``/``args.exclude``/``args.num_nodes``/``args.num_gpus``),
    the multi-node runner chosen by ``args.launcher`` builds the command.
    Otherwise a single-node ``torch.distributed.launch``/``torchrun`` command
    is used. The command is run with a copy of ``os.environ`` and this
    function waits for it.

    Args:
        args (Config): parsed launcher options.

    Raises:
        NotImplementedError: for an unknown ``args.launcher``.
        RuntimeError: if the selected launcher backend is not installed.
        SystemExit: with a non-zero status if the launched command fails.
    """
    assert isinstance(args, Config), f'expected args to be of type Config, but got {type(args)}'

    device_pool = fetch_hostfile(args.hostfile) if args.hostfile else None
    active_devices = _select_active_devices(args, device_pool)

    env = os.environ.copy()
    if not active_devices:
        cmd = _build_local_cmd(args)
    else:
        cmd = _build_multinode_cmd(args, device_pool, active_devices, env)

    result = subprocess.Popen(cmd, env=env)
    result.wait()

    returncode = result.returncode
    if returncode != 0:
        # Popen reports termination by signal N as -N; map it to the shell
        # convention 128+N instead of silently exiting with status 0.
        sys.exit(returncode if returncode > 0 else 128 - returncode)
|