import os
import sys
from typing import List

import click
import torch
from packaging import version

from colossalai.context import Config

from .hostinfo import HostInfo, HostInfoList
from .multinode_runner import MultiNodeRunner

# Separator between hostnames in the comma-delimited strings parsed in this module
# (the include/exclude filters and the host list).
NODE_SEP = ","


def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
    """
    Parse the hostfile to obtain a list of hosts.

    A hostfile should look like:
    worker-0
    worker-1
    worker-2
    ...

    Blank lines are skipped and surrounding whitespace is stripped from every hostname.

    Args:
        hostfile_path (str): the path to the hostfile
        ssh_port (int): the port to connect to the host

    Returns:
        device_pool (HostInfoList): one HostInfo per hostname, in the order they appear in the file

    Raises:
        SystemExit: with exit code 1 if the hostfile does not exist or lists the same hostname twice
    """

    if not os.path.isfile(hostfile_path):
        click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
        # exit with a non-zero status so that callers/shell scripts can detect the failure
        sys.exit(1)

    device_pool = HostInfoList()

    with open(hostfile_path, "r") as fd:
        # iterate the file lazily instead of materialising every line first
        for line in fd:
            hostname = line.strip()
            if not hostname:
                # skip empty lines
                continue

            if device_pool.has(hostname):
                click.echo(f"Error: found duplicate host {hostname} in the hostfile")
                sys.exit(1)

            # every host shares the same ssh port
            device_pool.append(HostInfo(hostname=hostname, port=ssh_port))

    return device_pool


def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList:
    """Parse an inclusion or exclusion string and filter a hostfile dictionary.

    Examples:
        include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1.
        exclude_str="worker-1" will use all available devices except worker-1.

    Whitespace around each hostname in the filter string is ignored, so
    "worker-0, worker-1" is equivalent to "worker-0,worker-1".

    Args:
        device_pool (HostInfoList): a list of HostInfo objects
        include_str (str): --include option passed by user, default None
        exclude_str (str): --exclude option passed by user, default None

    Returns:
        filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion.
            When neither filter is given (None or empty string), ``device_pool`` itself is returned.
            Otherwise a new HostInfoList is returned and ``device_pool`` is left unmodified.

    Raises:
        SystemExit: with exit code 1 if both filters are given or a hostname is not in ``device_pool``
    """

    # Ensure include/exclude are mutually exclusive
    if include_str and exclude_str:
        click.echo("--include and --exclude are mutually exclusive, only one can be used")
        sys.exit(1)

    # no-op; an empty string is treated the same as "not given" so that we never
    # fall through with nothing to parse
    if not include_str and not exclude_str:
        return device_pool

    parse_str = include_str if include_str else exclude_str
    hostnames = [node_config.strip() for node_config in parse_str.split(NODE_SEP)]

    # sanity check every hostname *before* looking anything up, so that an unknown
    # host always yields the readable error below
    for hostname in hostnames:
        if not device_pool.has(hostname):
            click.echo(f"Error: Hostname '{hostname}' not found in hostfile")
            sys.exit(1)

    # always build a fresh list so that the caller's device_pool is not mutated
    filtered_hosts = HostInfoList()

    if include_str:
        # keep the order in which the user listed the hosts
        for hostname in hostnames:
            filtered_hosts.append(device_pool.get_hostinfo(hostname))
    else:
        # keep the hostfile order, dropping the excluded hosts
        excluded = set(hostnames)
        for hostinfo in device_pool:
            if hostinfo.hostname not in excluded:
                filtered_hosts.append(hostinfo)

    return filtered_hosts


def get_launch_command(
    master_addr: str,
    master_port: int,
    nproc_per_node: int,
    user_script: str,
    user_args: List[str],
    node_rank: int,
    num_nodes: int,
    run_as_module: bool,
    extra_launch_args: str = None,
) -> str:
    """
    Generate a command for distributed training.

    The launcher is chosen from the installed torch version:
    ``python -m torch.distributed.launch`` for torch < 1.9,
    ``python -m torch.distributed.run`` for torch == 1.9 and ``torchrun`` otherwise.

    Args:
        master_addr (str): the host of the master node
        master_port (int): the port of the master node
        nproc_per_node (int): the number of processes to launch on each node
        user_script (str): the user Python file (or module name if ``run_as_module`` is set)
        user_args (List[str]): the arguments for the user script
        node_rank (int): the unique ID for the node
        num_nodes (int): the number of nodes to execute jobs
        run_as_module (bool): if True, ``-m`` is inserted before ``user_script``
        extra_launch_args (str): comma-separated extra launcher arguments, each either
            ``key=value`` or a bare ``key`` (emitted as a flag), default None.
            For torch >= 1.9, ``master_addr`` / ``master_port`` given here override the
            corresponding function arguments.

    Returns:
        cmd (str): the command to start distributed training (tokens joined by a single space)

    Raises:
        ValueError: if ``run_as_module`` is set with torch < 2.0
    """

    def _arg_dict_to_list(arg_dict):
        # a key without a value (None or empty string) is emitted as a bare flag;
        # any other value, including 0, is emitted as --key=value
        ret = []

        for k, v in arg_dict.items():
            if v is None or v == "":
                ret.append(f"--{k}")
            else:
                ret.append(f"--{k}={v}")
        return ret

    # parse "k1=v1,k2,k3=v3" into {"k1": "v1", "k2": None, "k3": "v3"}
    extra_launch_args_dict = dict()
    if extra_launch_args:
        for arg in extra_launch_args.split(","):
            arg = arg.strip()
            if not arg:
                # tolerate stray/trailing commas instead of emitting a bare "--"
                continue
            # split on the first "=" only, so that values may themselves contain "="
            key, sep, value = arg.partition("=")
            extra_launch_args_dict[key] = value if sep else None

    torch_version = version.parse(torch.__version__)
    assert torch_version.major >= 1
    if torch_version.major < 2 and run_as_module:
        raise ValueError("Torch version < 2.0 does not support running as module")

    if torch_version.major == 1 and torch_version.minor < 9:
        # torch distributed launch cmd with torch < 1.9
        cmd = [
            sys.executable,
            "-m",
            "torch.distributed.launch",
            f"--nproc_per_node={nproc_per_node}",
            f"--master_addr={master_addr}",
            f"--master_port={master_port}",
            f"--nnodes={num_nodes}",
            f"--node_rank={node_rank}",
        ]
    else:
        # extra launch args for torch distributed launcher with torch >= 1.9
        default_torchrun_rdzv_args = dict(master_addr=master_addr, master_port=master_port)

        # update rdzv arguments: user-supplied values take precedence and are removed
        # from the extra args so that they are not emitted twice
        for key in default_torchrun_rdzv_args.keys():
            if key in extra_launch_args_dict:
                default_torchrun_rdzv_args[key] = extra_launch_args_dict.pop(key)

        if torch_version.major == 1 and torch_version.minor == 9:
            # torch distributed launch cmd with torch == 1.9
            cmd = [
                sys.executable,
                "-m",
                "torch.distributed.run",
                f"--nproc_per_node={nproc_per_node}",
                f"--nnodes={num_nodes}",
                f"--node_rank={node_rank}",
            ]
        else:
            # torch distributed launch cmd with torch > 1.9
            cmd = [
                "torchrun",
                f"--nproc_per_node={nproc_per_node}",
                f"--nnodes={num_nodes}",
                f"--node_rank={node_rank}",
            ]
        cmd += _arg_dict_to_list(default_torchrun_rdzv_args)

    cmd += _arg_dict_to_list(extra_launch_args_dict)
    if run_as_module:
        cmd.append("-m")
    # unpack instead of list concatenation so that a tuple of user args works as well
    cmd += [user_script, *(user_args or [])]
    # NOTE: tokens are joined verbatim without shell quoting
    return " ".join(cmd)


def launch_multi_processes(args: Config) -> None:
    """
    Launch multiple processes on a single node or multiple nodes.

    The overall logic can be summarized as the pseudo code below:

        if hostfile given:
            hostinfo = parse_hostfile(hostfile)
            hostinfo = include_or_exclude_hosts(hostinfo)
            hostinfo = keep_first_num_nodes(hostinfo)
            launch_on_multi_nodes(hostinfo)
        elif hosts given:
            hostinfo = parse_hosts(hosts)
            launch_on_multi_nodes(hostinfo)
        else:
            launch_on_current_node()

    Args:
        args (Config): the arguments taken from command line

    Raises:
        SystemExit: always. The exit code is 1 if the arguments are invalid or any node
            reported "failure", and 0 otherwise.
    """
    assert isinstance(args, Config)

    if args.nproc_per_node is None:
        click.echo("--nproc_per_node did not receive any value")
        sys.exit(1)

    # cannot accept hosts and hostfile at the same time
    if args.host and args.hostfile:
        click.echo("Error: hostfile and hosts are mutually exclusive, only one is required")
        # abort instead of silently continuing with the hostfile
        sys.exit(1)

    # check if hostfile is given
    if args.hostfile:
        device_pool = fetch_hostfile(args.hostfile, ssh_port=args.ssh_port)
        active_device_pool = parse_device_filter(device_pool, args.include, args.exclude)

        if args.num_nodes > 0:
            # only keep the first num_nodes to execute jobs
            updated_active_device_pool = HostInfoList()
            for count, hostinfo in enumerate(active_device_pool):
                if count >= args.num_nodes:
                    break
                updated_active_device_pool.append(hostinfo)
            active_device_pool = updated_active_device_pool
    else:
        active_device_pool = None

    # use hosts if hostfile is not given
    if args.host and active_device_pool is None:
        active_device_pool = HostInfoList()
        for hostname in args.host.split(NODE_SEP):
            # tolerate spaces around the separator and stray/trailing commas
            hostname = hostname.strip()
            if not hostname:
                continue
            active_device_pool.append(HostInfo(hostname=hostname, port=args.ssh_port))

    if not active_device_pool:
        # run on local node if not hosts or hostfile is given
        # add local node to host info list
        active_device_pool = HostInfoList()
        localhost_info = HostInfo(hostname="127.0.0.1", port=args.ssh_port)
        active_device_pool.append(localhost_info)

    # launch distributed processes
    runner = MultiNodeRunner()
    curr_path = os.path.abspath(".")

    # collect current path env
    # do not support multi-line env var; empty values are dropped as well
    env = {k: v for k, v in os.environ.items() if v and "\n" not in v}

    # establish remote connection
    runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env)

    # overwrite master addr when num_nodes > 1 and not specified
    # ("127.0.0.1" is taken to mean "not specified")
    if len(active_device_pool) > 1 and args.master_addr == "127.0.0.1":
        args.master_addr = active_device_pool.hostinfo_list[0].hostname

    # execute distributed launching command
    # the position of a host in the pool is used as its node rank
    num_nodes = len(active_device_pool)
    for node_id, hostinfo in enumerate(active_device_pool):
        cmd = get_launch_command(
            master_addr=args.master_addr,
            master_port=args.master_port,
            nproc_per_node=args.nproc_per_node,
            user_script=args.user_script,
            user_args=args.user_args,
            node_rank=node_id,
            num_nodes=num_nodes,
            run_as_module=args.m,
            extra_launch_args=args.extra_launch_args,
        )
        runner.send(hostinfo=hostinfo, cmd=cmd)

    # start training
    msg_from_node = runner.recv_from_all()
    has_error = False

    # print node status
    click.echo("\n====== Training on All Nodes =====")
    for hostname, msg in msg_from_node.items():
        click.echo(f"{hostname}: {msg}")

        # check if a process failed
        if msg == "failure":
            has_error = True

    # stop all nodes
    runner.stop_all()

    # receive the stop status
    msg_from_node = runner.recv_from_all()

    # print node status
    click.echo("\n====== Stopping All Nodes =====")
    for hostname, msg in msg_from_node.items():
        click.echo(f"{hostname}: {msg}")

    # give the process an exit code
    # so that it behaves like a normal process
    sys.exit(1 if has_error else 0)