import torch
import gc
import psutil
from collections import namedtuple

from colossalai.context.parallel_mode import ParallelMode
from colossalai.utils import get_current_device
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from packaging import version

# Module-level state: the fraction of CUDA memory this process may use, and a
# user-set CPU memory budget in bytes (-1 means unset, i.e. fall back to the total).
_GLOBAL_CUDA_MEM_FRACTION = 1.0
_GLOBAL_CPU_MEM_CAPACITY = -1


def _bytes_to_MB(val, decimal=2):
    """Convert a number of bytes to megabytes, using binary notation (1 MB = 1024 * 1024 bytes).

    Args:
        val (int): the number of bytes to convert.
        decimal (int, optional): the number of decimal places to round to. Defaults to 2.

    Returns:
        float: the size in MB.
    """
    return round(val / (1024 * 1024), decimal)


# copied from PatrickStar
def _get_cpu_memory_info():
    ps_mem_info = namedtuple("ps_mem_info", ["total", "free", "cached", "buffers", "used"])
    try:
        # psutil reads the memory info from /proc/meminfo,
        # which reports the host memory rather than that of
        # the container when running inside a container.
        # Here we try to read the container memory using the method described in:
        # https://stackoverflow.com/a/46213331/5163915
        mems = {}
        with open("/sys/fs/cgroup/memory/memory.meminfo", "rb") as f:
            for line in f:
                fields = line.split()
                mems[fields[0]] = int(fields[1]) * 1024
        total = mems[b"MemTotal:"]
        free = mems[b"MemFree:"]
        cached = mems[b"Cached:"]
        buffers = mems[b"Buffers:"]
        used = total - free - cached - buffers
        if used < 0:
            used = total - free
        mem_info = ps_mem_info(total=total, free=free, cached=cached, buffers=buffers, used=used)
    except FileNotFoundError:
        mems = psutil.virtual_memory()
        mem_info = ps_mem_info(
            total=mems.total,
            free=mems.free,
            cached=mems.cached,
            buffers=mems.buffers,
            used=mems.used,
        )
    return mem_info
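
# A minimal usage sketch (illustrative): _get_cpu_memory_info returns a
# namedtuple, so fields are read by name, e.g.:
#
#   info = _get_cpu_memory_info()
#   print(f"CPU used: {_bytes_to_MB(info.used)} / {_bytes_to_MB(info.total)} MB")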


def report_memory_usage(message, logger=None, report_cpu=False):
    """Calculate and log the current memory usage (in MB).

    Args:
        message (str): A prefix message to add to the log.
        logger (:class:`colossalai.logging.DistributedLogger`): The logger used to record memory information.
        report_cpu (bool, optional): Whether to also report CPU memory. Defaults to False.

    Raises:
        EnvironmentError: Raised if no distributed environment has been initialized.
    """
    if not gpc.is_initialized(ParallelMode.GLOBAL):
        raise EnvironmentError("No distributed environment is initialized")

    gpu_allocated = _bytes_to_MB(torch.cuda.memory_allocated())
    gpu_max_allocated = _bytes_to_MB(torch.cuda.max_memory_allocated())
    gpu_cached = _bytes_to_MB(torch.cuda.memory_reserved())
    gpu_max_cached = _bytes_to_MB(torch.cuda.max_memory_reserved())

    full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \
        + f"cached {gpu_cached} MB, max cached {gpu_max_cached} MB"

    if report_cpu:
        # Python does not garbage-collect immediately, so collect explicitly to get accurate RAM reports
        gc.collect()
        vm_stats = psutil.virtual_memory()
        vm_used = _bytes_to_MB(vm_stats.total - vm_stats.available)
        full_log += f", CPU Virtual Memory: used = {vm_used} MB, percent = {vm_stats.percent}%"

    if logger is None:
        logger = get_dist_logger()
    logger.info(full_log)

    # reset the peak memory counter so the next call reports the peak usage since this point
    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
        torch.cuda.reset_peak_memory_stats()
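
# A usage sketch (assumes a distributed environment has already been launched,
# e.g. via colossalai.launch, since report_memory_usage requires gpc to be
# initialized):
#
#   report_memory_usage("after forward", report_cpu=True)
#   # -> "after forward: GPU: allocated 512.0 MB, max allocated 1024.0 MB, ..."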


def colo_device_memory_capacity(device: torch.device) -> int:
    """
    Get the memory capacity of the given device.

    Args:
        device (torch.device): a device

    Returns:
        int: size in bytes
    """
    assert isinstance(device, torch.device)
    if device.type == 'cpu':
        # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N of the overall CPU memory.
        return colo_get_cpu_memory_capacity() // gpc.num_processes_on_current_node
    if device.type == 'cuda':
        return int(torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION)
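
# A usage sketch (illustrative values):
#
#   cuda_cap = colo_device_memory_capacity(torch.device('cuda'))  # total GPU memory * fraction
#   cpu_cap = colo_device_memory_capacity(torch.device('cpu'))    # CPU budget / processes on node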


def colo_device_memory_used(device: torch.device) -> int:
    """
    Get the memory used on the given device by the current process.

    Args:
        device (torch.device): a device

    Returns:
        int: memory size in bytes
    """
    if device.type == 'cpu':
        mem_info = _get_cpu_memory_info()
        # In the context of 1-CPU-N-GPU, each process is assumed to consume the same amount of memory,
        # so the usage of the current process is 1/N of the total CPU memory used.
        ret = mem_info.used // gpc.num_processes_on_current_node
        return ret
    elif device.type == 'cuda':
        ret: int = torch.cuda.memory_allocated(device)
        # reset the peak memory counter so the next call reports the peak usage since this point
        if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
            torch.cuda.reset_peak_memory_stats(device)
        return ret
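
# A usage sketch (illustrative):
#
#   used = colo_device_memory_used(torch.device('cuda'))
#   print(f"CUDA memory in use: {_bytes_to_MB(used)} MB")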


def colo_set_process_memory_fraction(ratio: float) -> None:
    """Set the fraction of CUDA memory that the current process may use on its GPU.

    Args:
        ratio (float): a fraction between 0.0 and 1.0.
    """
    if version.parse(torch.__version__) < version.parse('1.8'):
        logger = get_dist_logger('colo_set_process_memory_fraction')
        logger.warning('colo_set_process_memory_fraction failed because torch version is less than 1.8')
        return
    global _GLOBAL_CUDA_MEM_FRACTION
    _GLOBAL_CUDA_MEM_FRACTION = ratio
    torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())
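
# A usage sketch: cap this process at 40% of its GPU's total memory; afterwards
# colo_device_memory_capacity reports the reduced capacity accordingly:
#
#   colo_set_process_memory_fraction(0.4)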


def colo_set_cpu_memory_capacity(size: int) -> None:
    """Set the CPU memory budget in bytes, clamped to the total physical memory."""
    global _GLOBAL_CPU_MEM_CAPACITY
    mem_info = _get_cpu_memory_info()
    _GLOBAL_CPU_MEM_CAPACITY = min(size, mem_info.total)


def colo_get_cpu_memory_capacity() -> int:
    """
    Get the CPU memory capacity. We may not use all of it.

    Returns:
        int: the CPU memory capacity in bytes. If no budget has been set via
        colo_set_cpu_memory_capacity, the total physical memory size is returned.
    """
    global _GLOBAL_CPU_MEM_CAPACITY
    if _GLOBAL_CPU_MEM_CAPACITY == -1:
        mem_info = _get_cpu_memory_info()
        return mem_info.total
    else:
        return _GLOBAL_CPU_MEM_CAPACITY
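
# A usage sketch: budget 32 GiB of CPU memory (clamped to the physical total),
# then read it back:
#
#   colo_set_cpu_memory_capacity(32 * 1024 ** 3)
#   assert colo_get_cpu_memory_capacity() <= 32 * 1024 ** 3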