mirror of https://github.com/hpcaitech/ColossalAI
ai big-model data-parallelism deep-learning distributed-computing foundation-models heterogeneous-training hpc inference large-scale model-parallelism pipeline-parallelism
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
127 lines
4.2 KiB
127 lines
4.2 KiB
from multiprocessing import Pipe, Process |
|
from multiprocessing import connection as mp_connection |
|
|
|
import click |
|
import fabric |
|
|
|
from .hostinfo import HostInfo, HostInfoList |
|
|
|
|
|
def run_on_host(
    hostinfo: HostInfo,
    workdir: str,
    recv_conn: mp_connection.Connection,
    send_conn: mp_connection.Connection,
    env: dict,
) -> None:
    """
    Use fabric connection to execute command on local or remote hosts.

    This is the body of a worker process. It blocks on ``recv_conn`` and runs
    each received command string inside ``workdir`` with ``env`` exported,
    replying ``"success"`` or ``"failure"`` on ``send_conn`` after every
    command. When the string ``"exit"`` is received the loop stops,
    ``"finish"`` is sent back and the fabric connection is closed.

    Args:
        hostinfo (HostInfo): host information
        workdir (str): the directory to execute the command
        recv_conn (multiprocessing.connection.Connection): receive messages from the master sender
        send_conn (multiprocessing.connection.Connection): send messages to the master receiver
        env (dict): a dictionary for environment variables
    """
    import contextlib
    import shlex

    fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port)

    # Shell-quote each value so characters such as `"`, `$` or backticks are
    # exported literally instead of being interpreted by the shell.
    env_msg = " ".join(f"{k}={shlex.quote(str(v))}" for k, v in env.items())

    def _env_context():
        # Propagate the runtime environment. A bare `export` (empty env) would
        # print every exported variable instead, so skip the prefix then.
        if env_msg:
            return fab_conn.prefix(f"export {env_msg}")
        return contextlib.nullcontext()

    try:
        # keep listening until the master sends "exit"
        while True:
            cmds = recv_conn.recv()
            if cmds == "exit":
                break

            try:
                # cd to the execution directory, then run with the exported env
                with fab_conn.cd(workdir), _env_context():
                    if hostinfo.is_local_host:
                        # execute on the local machine
                        fab_conn.local(cmds, hide=False)
                    else:
                        # execute on the remote machine
                        fab_conn.run(cmds, hide=False)
                send_conn.send("success")
            except Exception as e:
                click.echo(
                    f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}"
                )
                send_conn.send("failure")

        # shutdown
        send_conn.send("finish")
    finally:
        # Always release the fabric connection, even if the pipe to the master
        # breaks (e.g. recv() raising EOFError).
        fab_conn.close()
|
|
|
|
|
class MultiNodeRunner:
    """
    A runner to execute commands on an array of machines. This runner
    is inspired by Nezha (https://github.com/zhuzilin/NeZha).

    For every host, ``connect`` spawns one worker process running
    ``run_on_host`` and keeps the master ends of two pipes to it, keyed by
    hostname: one to send commands and one to receive status messages.
    """

    def __init__(self):
        # hostname -> worker Process
        self.processes = {}
        # hostname -> master end of the pipe used to send commands to the worker
        self.master_send_conns = {}
        # hostname -> master end of the pipe used to receive worker status messages
        self.master_recv_conns = {}

    def connect(self, host_info_list: HostInfoList, workdir: str, env: dict) -> None:
        """
        Establish connections to a list of hosts

        Args:
            host_info_list (HostInfoList): a list of HostInfo objects
            workdir (str): the directory where command is executed
            env (dict): environment variables to propagate to hosts
        """
        for hostinfo in host_info_list:
            master_send_conn, worker_recv_conn = Pipe()
            master_recv_conn, worker_send_conn = Pipe()
            p = Process(target=run_on_host, args=(hostinfo, workdir, worker_recv_conn, worker_send_conn, env))
            p.start()

            # The started worker now holds its own copies of these ends. Closing
            # the parent's copies means that if the worker dies, recv() on
            # master_recv_conn raises EOFError instead of blocking forever.
            worker_recv_conn.close()
            worker_send_conn.close()

            self.processes[hostinfo.hostname] = p
            self.master_recv_conns[hostinfo.hostname] = master_recv_conn
            self.master_send_conns[hostinfo.hostname] = master_send_conn

    def send(self, hostinfo: HostInfo, cmd: str) -> None:
        """
        Send a command to a local/remote host.

        Args:
            hostinfo (HostInfo): host information
            cmd (str): the command to execute

        Raises:
            AssertionError: if ``hostinfo.hostname`` was not connected via ``connect``.
        """
        assert hostinfo.hostname in self.master_send_conns, f"{hostinfo} is not found in the current connections"
        self.master_send_conns[hostinfo.hostname].send(cmd)

    def stop_all(self) -> None:
        """
        Stop connections to all hosts.

        Sends ``"exit"`` to every worker; each worker then replies with
        ``"finish"``, which can be collected with ``recv_from_all``.
        """
        for conn in self.master_send_conns.values():
            conn.send("exit")

    def recv_from_all(self) -> dict:
        """
        Receive messages from all hosts

        Blocks on each host in turn until that host's worker sends a message.

        Returns:
            msg_from_node (dict): a dictionary which contains messages from each node
        """
        return {hostname: conn.recv() for hostname, conn in self.master_recv_conns.items()}
|
|
|