import torch
import functools
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from typing import List, Optional, Tuple
from time import time
from colossalai.gemini import Chunk, ChunkManager
from .placement_policy import PlacementPolicyFactory


class GeminiManager:
    """
    Stateful Tensor Manager, inspired by PatrickStar.

    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818

    Before each compute step the manager works out how much CUDA memory the
    step's chunks still need and asks the placement policy to evict movable
    CUDA chunks. It also accumulates per-iteration statistics: host-to-device
    and device-to-host volume, and layout, eviction and demand-computation time.

    Args:
        placement_policy: name of a policy known to ``PlacementPolicyFactory``.
        chunk_manager: the ``ChunkManager`` whose chunks are being placed.
    """

    def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
        assert placement_policy in PlacementPolicyFactory.get_polocy_names(), \
            f"unknown placement policy: {placement_policy!r}"
        policy_cls = PlacementPolicyFactory.create(placement_policy)
        self._chunk_manager = chunk_manager
        # Only policies that declare `need_mem_stats` get a memory statistics collector.
        self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
        self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
        # Chunk tuples in the order they were computed during warmup (see _record_chunks_order).
        self._compute_list: List[Tuple[Chunk, ...]] = []
        # Index of the current compute step within the iteration; reset to -1 by post_iter.
        self._compute_idx: int = -1

        # Per-iteration statistics, reset by post_iter.
        self._h2d_volume = 0
        self._d2h_volume = 0
        self._layout_time = 0
        self._evict_time = 0
        # The first iteration after construction is the warmup iteration.
        self._warmup = True
        self._comp_cuda_demand_time = 0

    def pre_iter(self):
        """Start memory-statistics collection if this is the warmup iteration."""
        if self._mem_stats_collector is not None and self._warmup:
            self._mem_stats_collector.start_collection()

    def post_iter(self):
        """This function must be called when each iteration finishes.

        Finishes statistics collection after the warmup iteration, leaves warmup
        mode, and resets the compute index and the per-iteration counters.
        """
        if self._mem_stats_collector is not None and self._warmup:
            self._mem_stats_collector.finish_collection()
        self._warmup = False
        self._compute_idx = -1
        self._h2d_volume = 0
        self._d2h_volume = 0
        self._layout_time = 0
        self._evict_time = 0
        self._comp_cuda_demand_time = 0

    def adjust_layout(self, chunks: Tuple[Chunk, ...], group_name: str) -> None:
        """Adjust the layout of stateful tensors before ``chunks`` are computed.

        The placement policy uses the information provided by the memory
        statistics collector (if any) to evict movable CUDA chunks of chunk
        group ``group_name``, making room for ``chunks``.

        Args:
            chunks: the chunks needed by the upcoming compute step.
            group_name: key into ``chunk_manager.chunk_groups`` whose CUDA chunks
                are candidates for eviction.
        """
        start = time()
        self._record_chunks_order(chunks)
        cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_name)
        self._layout_time += time() - start
        vol, evict_time = self._placement_policy.evict_tensors(hold_cuda_tensor_list,
                                                               cuda_demand=cuda_demand,
                                                               warmup=self._warmup,
                                                               compute_list=self._compute_list,
                                                               compute_idx=self._compute_idx)
        self._d2h_volume += vol
        self._evict_time += evict_time
        # No data is moved here; only the expected host-to-device volume for the
        # chunks of this step is recorded.
        self._h2d_volume += cuda_demand

    def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_name: str):
        """Return ``(cuda_demand, can_evict_chunks)`` for the upcoming compute step.

        ``cuda_demand`` is the sum of ``chunk.mem`` over the chunks in ``chunks``
        that are on CPU or empty. ``can_evict_chunks`` lists the chunks of group
        ``group_name`` that are non-empty, on CUDA and movable.

        ``compute_idx`` and ``warmup`` do not affect the result; they are kept
        so existing call sites keep working.
        """
        # Deliberately not memoized: the result depends on mutable chunk state
        # (device_type, is_empty, can_move_device) that a cache key would not
        # capture, and an lru_cache on a method would also pin `self` forever.
        start = time()
        cuda_demand = sum(chunk.mem for chunk in chunks if chunk.device_type == 'cpu' or chunk.is_empty)
        self._comp_cuda_demand_time += time() - start
        can_evict_chunks = [
            chunk for chunk in self._chunk_manager.chunk_groups[group_name]
            if not chunk.is_empty and chunk.device_type == 'cuda' and chunk.can_move_device
        ]
        return cuda_demand, can_evict_chunks

    def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
        """Advance the compute index and, during warmup, remember ``chunks``.

        The order is only recorded when the placement policy needs memory stats.
        """
        self._compute_idx += 1
        if self._warmup and self._placement_policy.need_mem_stats:
            self._compute_list.append(chunks)

    @property
    def default_device(self):
        """Default device reported by the placement policy."""
        return self._placement_policy.get_default_device()

    def sample_overall_data(self):
        """Forward to the statistics collector's ``sample_overall_data``, if one exists."""
        if self._mem_stats_collector is not None:
            self._mem_stats_collector.sample_overall_data()

    def sample_model_data(self):
        """Forward to the statistics collector's ``sample_model_data``, if one exists."""
        if self._mem_stats_collector is not None:
            self._mem_stats_collector.sample_model_data()

    @property
    def chunk_manager(self):
        """The ``ChunkManager`` passed to the constructor."""
        return self._chunk_manager

    @property
    def cuda_margin_mem(self) -> Optional[float]:
        """CUDA margin memory from the statistics collector, or ``None`` without one."""
        if self._mem_stats_collector is not None:
            return self._mem_stats_collector.cuda_margin_mem
        return None

    @property
    def is_cuda_margin_mem_avail(self) -> bool:
        """Whether ``cuda_margin_mem`` can be provided (the policy needs mem stats)."""
        return self._placement_policy.need_mem_stats

    @staticmethod
    def get_default_device(policy_name: str) -> torch.device:
        """Default device of the placement policy named ``policy_name``."""
        return PlacementPolicyFactory.get_default_device(policy_name)