import fabric
from .hostinfo import HostInfo, HostInfoList
from multiprocessing import Pipe, Process
from multiprocessing import connection as mp_connection
import click


def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection,
                send_conn: mp_connection.Connection, env: dict) -> None:
    """
    Use fabric connection to execute command on local or remote hosts.

    Args:
        hostinfo (HostInfo): host information
        workdir (str): the directory to execute the command
        recv_conn (multiprocessing.connection.Connection): receive messages from the master sender
        send_conn (multiprocessing.connection.Connection): send messages to the master receiver
        env (dict): a dictionary for environment variables
    """

    fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port)
    finish = False
    env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()])

    # keep listening until exit
    while not finish:
        # receive cmd
        cmds = recv_conn.recv()

        if cmds == 'exit':
            # exit from the loop
            finish = True
            break
        else:
            # execute the commands
            try:
                # cd to execute directory
                with fab_conn.cd(workdir):
                    # propagate the runtime environment
                    with fab_conn.prefix(f"export {env_msg}"):
                        if hostinfo.is_local_host:
                            # execute on the local machine
                            fab_conn.local(cmds, hide=False)
                        else:
                            # execute on the remote machine
                            fab_conn.run(cmds, hide=False)
                    send_conn.send('success')
            except:
                click.echo(f"Error: failed to run {cmds} on {hostinfo.hostname}")
                send_conn.send('failure')

    # shutdown
    send_conn.send("finish")
    fab_conn.close()


class MultiNodeRunner:
    """
    A runner to execute commands on an array of machines. This runner
    is inspired by Nezha (https://github.com/zhuzilin/NeZha).
    """

    def __init__(self):
        self.processes = {}
        self.master_send_conns = {}
        self.master_recv_conns = {}

    def connect(self, host_info_list: HostInfoList, workdir: str, env: dict) -> None:
        """
        Establish connections to a list of hosts

        Args:
            host_info_list (HostInfoList): a list of HostInfo objects
            workdir (str): the directory where command is executed
            env (dict): environment variables to propagate to hosts
        """
        for hostinfo in host_info_list:
            master_send_conn, worker_recv_conn = Pipe()
            master_recv_conn, worker_send_conn = Pipe()
            p = Process(target=run_on_host, args=(hostinfo, workdir, worker_recv_conn, worker_send_conn, env))
            p.start()
            self.processes[hostinfo.hostname] = p
            self.master_recv_conns[hostinfo.hostname] = master_recv_conn
            self.master_send_conns[hostinfo.hostname] = master_send_conn

    def send(self, hostinfo: HostInfo, cmd: str) -> None:
        """
        Send a command to a local/remote host.

        Args:
            hostinfo (HostInfo): host information
            cmd (str): the command to execute
        """

        assert hostinfo.hostname in self.master_send_conns, \
            f'{hostinfo} is not found in the current connections'
        conn = self.master_send_conns[hostinfo.hostname]
        conn.send(cmd)

    def stop_all(self) -> None:
        """
        Stop connections to all hosts.
        """

        for hostname, conn in self.master_send_conns.items():
            conn.send('exit')

    def recv_from_all(self) -> dict:
        """
        Receive messages from all hosts

        Returns:
            msg_from_node (dict): a dictionry which contains messages from each node
        """

        msg_from_node = dict()
        for hostname, conn in self.master_recv_conns.items():
            msg_from_node[hostname] = conn.recv()
        return msg_from_node