import fcntl
import os
import time
from multiprocessing import Process

import pytest
import torch
import torch.distributed as dist

# NOTE(review): these variables are set before the `internlm.utils.timeout` import below,
# presumably because they are read when that module is imported — confirm before reordering.
os.environ["INTERNLM_ENABLE_TIMEOUT"] = "1"  # noqa  # pylint: disable=wrong-import-position
os.environ["NCCL_TIMEOUT"] = "5"
from internlm.utils.timeout import llm_timeout
from tests.test_utils.common_fixture import (  # noqa # pylint: disable=unused-import
    init_config,
)

# NOTE(review): not referenced in the visible part of this file; each case in
# `timeout_func_list` carries its own world size.
WORLD_SIZE = 2


@llm_timeout(2, "fake_timeout_func")
def fake_timeout_func():
    """Sleep for 10 seconds while wrapped in ``llm_timeout(2, ...)``.

    The sleep is deliberately longer than the limit passed to the decorator
    (presumably expressed in seconds); ``local_timeout`` calls this function
    and expects a ``TimeoutError``.
    """
    sleep_seconds = 10  # larger than the 2 handed to llm_timeout above
    time.sleep(sleep_seconds)


@llm_timeout(10, "nccl_timeout_func")
def nccl_timeout_func(rank):
    """Run a warm-up all-reduce, then leave rank 0 alone in a second all-reduce.

    Args:
        rank: Rank of the calling process; also passed to ``.cuda()`` as the device index.
    """
    # see: https://github.com/pytorch/pytorch/issues/104506#issuecomment-1679762880
    # 'NCCL_ASYNC_ERROR_HANDLING' cannot take effect on the first collective communication,
    # so one all_reduce is issued up front (lazy communicator init) before the one meant to hang.
    tensor = torch.ones([64, 64]).cuda(rank)
    dist.all_reduce(tensor)
    torch.cuda.synchronize()

    if rank != 0:
        # Non-zero ranks never enter the second collective; they only sleep.
        time.sleep(9999)
        return

    dist.all_reduce(tensor)
    torch.cuda.synchronize()  # the main thread hangs here.


@llm_timeout(10, "try_file_lock")
def try_file_lock(rank, stop_file_path):
    """Take an exclusive ``flock`` on ``stop_file_path`` in a way that makes both ranks hang.

    Args:
        rank: Rank of the calling process. Rank 1 waits 5 seconds before opening the
            file; rank 0 sleeps while holding the lock.
        stop_file_path: Path of an existing file to open and lock.
    """
    if rank == 1:
        # Delay rank 1 so that it reaches flock() after rank 0 has had time to lock the file.
        time.sleep(5)

    with open(stop_file_path, "r", encoding="utf-8") as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)  # rank 1 hangs here.
        if rank == 0:
            time.sleep(99999)  # rank 0 hangs here, still holding the lock.
        lock_file.seek(0)
        lock_file.read()
        fcntl.flock(lock_file, fcntl.LOCK_UN)


def local_timeout(rank, _):
    """Check that ``fake_timeout_func`` raises ``TimeoutError`` in this process.

    Args:
        rank: Rank of the calling process; only used in the printed message.
        _: Unused. ``test_timeout`` passes the world size in this position.

    Raises:
        AssertionError: If ``fake_timeout_func`` returns without timing out.
    """
    try:
        fake_timeout_func()
    except TimeoutError as e:
        print(f"local_timeout, rank:{rank}, e:{e}", flush=True)
    else:
        # Raise explicitly: ``assert False`` is stripped when Python runs with -O,
        # which would silently hide a missing timeout.
        raise AssertionError("It should timeout!")


def gpc_timeout(rank, world_size):
    """Initialize the distributed environment and expect ``nccl_timeout_func`` to time out.

    Args:
        rank: Rank of the calling process; exported as both ``RANK`` and ``LOCAL_RANK``.
        world_size: Total number of processes; exported as ``WORLD_SIZE``.

    Raises:
        AssertionError: If ``nccl_timeout_func`` returns without raising ``TimeoutError``.
    """
    from internlm.initialize import initialize_distributed_env

    # Single source for the port so the env var and the keyword argument cannot drift apart.
    master_port = 12377

    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(master_port)
    initialize_distributed_env(config=init_config, launcher="torch", master_port=master_port, args_check=False)

    try:
        nccl_timeout_func(rank)
    except TimeoutError as e:
        print(f"gpc_timeout, rank:{rank}, e:{e}", flush=True)
        time.sleep(5)  # wait for rank 0 to be killed
    else:
        time.sleep(5)  # give the watchdog some time to kill rank 0.
        # Raise explicitly: ``assert False`` is stripped when Python runs with -O.
        raise AssertionError("It should timeout!")


def file_lock_timeout(rank, _, stop_file_path):
    """Expect ``try_file_lock`` to time out while ranks contend for one file lock.

    Rank 0 creates the (empty) lock file before the attempt and removes it afterwards.

    Args:
        rank: Rank of the calling process.
        _: Unused. ``test_timeout`` passes the world size in this position.
        stop_file_path: Path of the file used as the lock target.

    Raises:
        AssertionError: If ``try_file_lock`` returns without raising ``TimeoutError``.
    """
    if rank == 0:
        # Create/truncate the file; the encoding matches the one try_file_lock opens it with.
        with open(stop_file_path, "w", encoding="utf-8"):
            pass
    try:
        try_file_lock(rank, stop_file_path)
    except TimeoutError as e:
        print(e, flush=True)
    else:
        # Raise explicitly: ``assert False`` is stripped when Python runs with -O.
        raise AssertionError("It should timeout!")
    finally:
        # Only the rank that created the file cleans it up.
        if rank == 0:
            os.remove(stop_file_path)


# Each case is (target function, number of processes to spawn, optional extra positional argument).
timeout_func_list = [
    (gpc_timeout, 2, None),
    (local_timeout, 1, None),
    (file_lock_timeout, 2, "test_lock.log"),
]


@pytest.mark.parametrize("timeout_func_and_args", timeout_func_list)
def test_timeout(timeout_func_and_args):
    """Spawn one process per rank for a timeout case, then reap every process.

    Args:
        timeout_func_and_args: Tuple of (target function, world size, extra argument or ``None``).
    """
    timeout_func, world_size, other_args = timeout_func_and_args
    # When present, the extra argument is appended after (rank, world_size).
    extra_args = () if other_args is None else (other_args,)

    workers = []
    for rank in range(world_size):
        worker = Process(target=timeout_func, args=(rank, world_size, *extra_args))
        worker.start()
        workers.append(worker)

    # NOTE(review): child exit codes are not inspected here, so an AssertionError raised
    # inside a child process does not by itself fail this test.
    for worker in workers:
        worker.join(15)
        if not worker.is_alive():
            continue
        # Still running after the join timeout: terminate it and wait for it to exit.
        worker.terminate()
        worker.join()