from enum import Enum
from os import stat
from typing import Dict, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.gemini.tensor_utils import (colo_model_data_tensor_move_inline, colo_tensor_mem_usage)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from colossalai.gemini.stateful_tensor import (StatefulTensor, TensorState)
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy


class OptimState(Enum):
    SCALED = 1
    UNSCALED = 2


class ShardedOptimizerV2(ColossalaiOptimizer):
    """A wrapper for optimizer. ``ShardedOptimizerV2`` and ``ShardedModelV2`` implement Zero Redundancy Optimizer (ZeRO).

    By default the ZeRO optimizer stage 3 offload Optimizer States on CPU.

    We apply the Device-aware Operator Placement technique for OS placement from the following paper.

    `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_

    GPU margin space is the remaining space after removing peak non-model data from the overall GPU memory,
    which is detected by a runtime memory tracer. 

    We place as many OS chunks in the margin space as possible. 

    The size of margin space can be controlled by ``gpu_margin_mem_ratio``.
    If it is set as ``0.0``, it is the same as classical ZeRO optimizer.

    Note:
        You must use ``ShardedOptimizerV2`` with ``ShardedModelV2``.

    Note:
        Make sure you set ``tensor_placement_policy`` in ``ShardedModelV2`` to `"auto"`,
        if you set ``gpu_margin_mem_ratio > 0``.

    Args:
        sharded_model (ShardedModelV2): A sharded model initialized by class ShardedModelV2. The optimizer will use the
            shard strategy provided by sharded model to shard param fp32 tensors.
        optimizer (Optimizer): An Optimizer instance.
        gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) 
            which will be used when using hybrid CPU optimizer. 
            This argument is meaningless when `tensor_placement_policy` of `ShardedModelV2` is not "auto".
            Defaults to 0.0.
        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
        min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
        growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
        growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
        hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
        max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
        dp_process_group (Optional[ProcessGroup], optional): data paralle process group. Defaults to None.
        mp_process_group (Optional[ProcessGroup], optional): model paralle process group. Defaults to None.

    .. _PatrickStar\: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
        https://arxiv.org/abs/2108.05818
    """

    def __init__(self,
                 sharded_model: ShardedModelV2,
                 optimizer: Optimizer,
                 gpu_margin_mem_ratio: float = 0.0,
                 initial_scale: float = 2**32,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32,
                 dp_process_group: Optional[ProcessGroup] = None,
                 mp_process_group: Optional[ProcessGroup] = None,
                 verbose: bool = False) -> None:
        assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
        assert not isinstance(optimizer, ShardedOptimizerV2), 'Nested ShardedOptimizerV2 is not supported.'

        super().__init__(optimizer)
        self.shard_strategy = sharded_model.shard_strategy
        self.model: ShardedModelV2 = sharded_model

        self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
        assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
        # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
        # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
        # and it must set `num_fp32_shards_per_param` correctly
        self._should_move_fp32_shards_h2d: bool = sharded_model.cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr(
            optimizer, 'num_fp32_shards_per_param', 0) >= 2
        self.device = sharded_model._tensor_placement_policy.device or torch.device('cpu')
        self.optim_state: OptimState = OptimState.UNSCALED
        self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
        self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)
        # Grad scaler
        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                             min_scale=min_scale,
                                             growth_factor=growth_factor,
                                             backoff_factor=backoff_factor,
                                             growth_interval=growth_interval,
                                             hysteresis=hysteresis,
                                             max_scale=max_scale)
        self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
        self._logger = get_dist_logger("ShardedOptimizerV2")
        self._verbose = verbose

        # Store fp32 param shards
        self._register_master_weight()
        if self.gpu_margin_mem_ratio != 0.0 and not isinstance(sharded_model._tensor_placement_policy,
                                                               AutoTensorPlacementPolicy):
            self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"',
                                 ranks=[0])

        if self._verbose:
            self._logger.debug(
                f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0])

        self._use_memory_tracer = self.model.use_memory_tracer

    @property
    def loss_scale(self):
        return self.grad_scaler.scale.item()

    def get_memory_usage(self) -> Tuple[int, int]:
        """ Get the memory usage of the optimizer. Including master_params (param fp32),
        momentum (``self.state[p]['exp_avg']``) variance (``self.state[p]['exp_avg_sq']``)

        Returns:
            Tuple[int, int]: cuda/cpu memory usage in Byte.
        """
        cuda_use = 0
        cpu_use = 0

        def update_mem_use(t):
            nonlocal cuda_use
            nonlocal cpu_use
            t_cuda_use, t_cpu_use = colo_tensor_mem_usage(t)
            cuda_use += t_cuda_use
            cpu_use += t_cpu_use

        for _, p_fp32 in self.master_params.items():
            update_mem_use(p_fp32)
        for group in self.optim.param_groups:
            for p in group['params']:
                state = self.optim.state[p]
                for k, v in state.items():
                    update_mem_use(v)

        return cuda_use, cpu_use

    def zero_grad(self, *args, **kwargs):
        self._zero_grad()

    def backward(self, loss: Tensor) -> None:
        loss = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
        self.model.backward(loss)

    def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
        # This function is called except the last stage of pipeline parallel
        # It receives the scaled grad from the previous rank
        # No need to scale the grad again
        # Need to unscale when optimizing
        self.optim_state = OptimState.SCALED
        self.model.backward_by_grad(tensor, grad)

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        if self.optim_state == OptimState.SCALED:
            self._prepare_grads()
            self._unscale_grads()
        return super().clip_grad_norm(model, max_norm)

    def step(self, *args, **kwargs):

        # unscale grads if scaled
        if self.optim_state == OptimState.SCALED:
            self._prepare_grads()
            self._unscale_grads()

        self._maybe_move_fp32_shards()
        found_inf = self._check_overflow()
        self.grad_scaler.update(found_inf)

        if found_inf:
            self._logger.warning('found inf during ShardedOptimV2 step')
            self._zero_grad(recover_data=True)
            return

        self._point_param_fp16_to_master_param()

        if self._verbose:
            gpu_mem, cpu_mem = self.get_memory_usage()
            self._logger.debug(
                f"Before step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!",
                ranks=[0])
        ret = self.optim.step(*args, **kwargs)

        if self._verbose:
            gpu_mem, cpu_mem = self.get_memory_usage()
            self._logger.debug(
                f"After step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!",
                ranks=[0])

        self._copy_master_model_to_model_fp16()
        return ret

    def _check_overflow(self):
        # clear previous overflow record
        self._found_overflow.fill_(self.model.overflow_counter)

        # all-reduce across dp group
        dist.all_reduce(self._found_overflow, group=self.dp_process_group)

        # all-reduce over model parallel group
        dist.all_reduce(self._found_overflow, group=self.mp_process_group)

        return self._found_overflow.item() > 0

    def _unscale_grads(self):
        assert self.optim_state == OptimState.SCALED
        for group in self.optim.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.grad.data.div_(self.loss_scale)
        self.optim_state = OptimState.UNSCALED

    def _zero_grad(self, recover_data: bool = False):
        """zero grad and maybe recover fp16 params
        When `reuse_fp16_shard` is enabled,
        p.colo_attr.sharded_data_tensor stores grad here.
        We have to recover them from fp32 params.

        Args:
            recover_data (bool, optional): Whether to recover fp16 param from fp32 param. Defaults to False.
        """
        # We must set grad to None
        # Because grad here is sharded
        # But next backward pass will create a full grad first
        # Which leads to wrong accumulation
        self.optim.zero_grad(set_to_none=True)
        for group in self.optim.param_groups:
            for p in group['params']:
                # p.colo_attr.sharded_data_tensor stores grad now
                # we have to recover fp16 param
                reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0)
                if recover_data and reuse_fp16_shard:
                    self._copy_master_param_to_param_fp16(p)
                else:
                    # release saved gradient
                    p.colo_attr.saved_grad.set_null()
        self.model.overflow_counter = 0    # set overflow counter to zero

    def sync_grad(self):
        pass

    def _register_master_weight(self):
        self.master_params: Dict[Parameter, StatefulTensor] = {}
        for group in self.optim.param_groups:
            for p in group['params']:
                assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
                shard_flag = not p.colo_attr.sharded_data_tensor.is_sharded and p.colo_attr.is_replicated
                if shard_flag:
                    # we always shard replicated paramters
                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                self.master_params[p] = StatefulTensor(cast_tensor_to_fp32(p.colo_attr.data_payload.to(self.device)))
                if shard_flag:
                    # In this branch, there's no need to shard param
                    # So we gather here
                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)

    def _maybe_move_fp32_shards(self):
        if self._should_move_fp32_shards_h2d:
            self._should_move_fp32_shards_h2d = False
            available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio
            fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
            fp32_shards_used_cuda_margin_mem = 0
            for group in self.optim.param_groups:
                for p in group['params']:
                    if p.colo_attr.saved_grad.is_null():
                        continue
                    shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
                    if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                        colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device())
                        colo_model_data_tensor_move_inline(p.colo_attr.saved_grad, torch.cuda.current_device())
                        p.colo_attr.offload_grad = False
                        fp32_shards_used_cuda_margin_mem += shard_mem
                        state = self.optim.state[p]
                        for k, v in state.items():
                            if isinstance(v, Tensor):
                                state[k] = v.cuda()

    def _prepare_grads(self):
        for group in self.optim.param_groups:
            for p in group['params']:
                if p.colo_attr.saved_grad.is_null():
                    continue
                p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE)
                # If reuse_fp16_shard, grad fp16 which wasn't be offloaded may be evicted to CPU
                if not p.colo_attr.offload_grad:
                    colo_model_data_tensor_move_inline(p.colo_attr.saved_grad, torch.cuda.current_device())
                # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
                # If we change p.grad directly
                # it may raise error because of different shape/dtype/device of p.data and p.grad
                # We just set p.data = p.colo_attr.saved_grad.payload here
                p.data = p.colo_attr.grad_payload
                p.grad = p.colo_attr.grad_payload
                # Set p.data to empty tensor, in case of memory leaking
                p.colo_attr.set_data_none()

    def _point_param_fp16_to_master_param(self):
        # assign master param pointers to p.data.
        # We will not trigger data copy here.
        for group in self.optim.param_groups:
            for p in group['params']:
                self.master_params[p].trans_state(TensorState.COMPUTE)
                p.data = self.master_params[p].payload
                # Now p.data is sharded
                # So optimizer states are sharded naturally

    def _copy_master_model_to_model_fp16(self):
        # Copy master param data (fp32) to payload of colo_attr (fp16)
        # TODO() improve efficiency by gathering tensors into a chunk and transfering
        # a chunk.
        for group in self.optim.param_groups:
            for p in group['params']:
                self._copy_master_param_to_param_fp16(p)

    def _copy_master_param_to_param_fp16(self, p):
        # flush gradient
        if p.colo_attr.sharded_data_tensor.payload_size == 0:
            # here reuse_fp16_shard is True
            # in order to use copy below, we should give sharded data tensor a payload
            p.colo_attr.sharded_data_tensor.payload_relay(p.colo_attr.saved_grad)
        else:
            p.colo_attr.saved_grad.set_null()

        p.data = self.master_params[p].payload

        # we need to allocate new memory for keep_not_shard paramters
        # in order to use copy, otherwise, the sizes of tensor is not compatible
        if p.colo_attr.data_payload.numel() != p.data.numel():
            p.colo_attr.data_payload_reset(
                torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))

        # TODO() optimize this line CPU (fp32) -> GPU (fp16)
        p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach())
        p.colo_attr.set_data_none()

        if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
            # We gather full fp16 param here
            p.colo_attr.sharded_data_tensor.is_sharded = True    # since only gradient is sharded, we should set to True
            self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)

        self.master_params[p].trans_state(TensorState.HOLD)

    def state_dict(self):
        optim_state_dict = super().state_dict()
        scaler_state_dict = self.grad_scaler.state_dict()
        optim_state_dict['scaler'] = scaler_state_dict
        return optim_state_dict

    def load_state_dict(self, *args, **kwargs):
        if 'scaler' not in args[0]:
            self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
        else:
            scaler_state_dict = args[0].pop('scaler')
            self.grad_scaler.load_state_dict(scaler_state_dict)
        super().load_state_dict(*args, **kwargs)
        for group in self.optim.param_groups:
            for p in group['params']:
                state = self.optim.state[p]
                for k, v in state.items():
                    if isinstance(v, Tensor):
                        state[k] = v.to(dtype=self.master_params[p].dtype, device=self.master_params[p].device)