from typing import Dict, Iterator, Optional, Tuple, Union

import torch
from torch import nn

from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup, ShardSpec

from .utils import InsertPostInitMethodToModuleSubClasses

# A helper that yields named parameters, including replicas of shared parameters.


def _named_params_with_replica(
    module: nn.Module,
    prefix: str = '',
    recurse: bool = True,
) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
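    """Yield ``(name, parameter)`` pairs for ``module``.

    Unlike ``nn.Module.named_parameters``, shared parameters are not
    deduplicated: a parameter referenced by several submodules is yielded
    once per referencing submodule. Parameters that have already been
    converted to ``ColoTensor`` are yielded as well.
    """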
    modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]

    for mod_prefix, mod in modules:
        for name, val in mod._parameters.items():
            if val is None:
                continue
            name = mod_prefix + ('.' if mod_prefix else '') + name
            yield name, val


def ColoModulize(module):
    """
    Replacing the parameters() and named_parameters() with our customized ones
    """

    module._colo_visited = True


class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
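    """A context manager that converts the parameters of every module constructed
    inside it into ``ColoParameter`` s placed on the given device and dtype.

    A minimal usage sketch (``MyModel`` is a placeholder for any ``nn.Module``)::

        with ColoInitContext(device=torch.device('cuda:0')):
            model = MyModel()
        # all parameters of ``model`` are now ColoParameter instances
    """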

    def __init__(self,
                 device: torch.device = torch.device('cpu'),
                 dtype: torch.dtype = torch.float,
                 default_pg: Optional[ProcessGroup] = None,
                 default_dist_spec=None):
        """
        Args:
            device (torch.device): the device on which initialized parameters will reside. Defaults to torch.device('cpu').
            dtype (torch.dtype): the dtype of initialized parameters. Defaults to torch.float.
            default_pg (ProcessGroup): the default process group for all initialized parameters.
            default_dist_spec: the default distribution specification (e.g. a ``ShardSpec``) applied to initialized parameters.
        """
        super().__init__()
        self._device = device
        self._dtype = dtype

        self._register_colo_modules()
        self._default_pg = default_pg
        self._default_dist_spec = default_dist_spec

    def _register_colo_modules(self):
        register_colo_module(torch.nn.Linear, ColoLinear())
        register_colo_module(torch.nn.Embedding, ColoEmbedding())

    def _pre_context_exec(self):
        pass

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
        Called at the end of the constructor of each module built inside the context;
        converts the module's (non-ColoTensor) parameters into ColoParameters.
        FIXME(fjr) The module may be passed to this function multiple times?
        """

        if hasattr(module, '_colo_visited'):
            return

        name_list = []
        for name, param in _named_params_with_replica(module):
            if isinstance(param, ColoTensor):
                continue

            split = name.rfind('.')
            if split >= 0:    # param in submodule
                module_name = name[:split]
                param_name = name[split + 1:]
            else:
                module_name = ''    # param in current module
                param_name = name
            name_list.append((module_name, param_name))

        # Map each replaced torch.Tensor to its ColoTensor so that parameters
        # sharing the same underlying tensor are converted only once.
        replaced_tensors = dict()
        for module_name, param_name in name_list:
            submodule = module.get_submodule(module_name)
            param = submodule.get_parameter(param_name)
            if param in replaced_tensors:
                colo_param = replaced_tensors[param]
            else:
                # Detaching the tensor is necessary for optimizers.
                requires_grad = param.requires_grad

                # param is the global tensor.
                colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype),
                                           requires_grad=requires_grad)

                # If a default distribution spec is given, shard the param during
                # initialization, which reduces the memory footprint of the model.
                # NOTE: embeddings usually cannot be sharded correctly, so the
                # try/except below skips params the default spec cannot shard.
                if self._default_pg is not None:
                    colo_param.set_process_group(self._default_pg)

                if self._default_dist_spec is not None:
                    try:
                        colo_param.set_dist_spec(self._default_dist_spec)
                    except Exception:
                        # the default spec cannot shard this param (e.g. an embedding);
                        # leave it unsharded
                        pass

                replaced_tensors[param] = colo_param
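            # Replace the original nn.Parameter on the owning submodule with the
            # ColoParameter and record the submodule so shared (tied) parameters
            # can be tracked via ``shared_param_modules``.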
            delattr(submodule, param_name)
            setattr(submodule, param_name, colo_param)
            colo_param.shared_param_modules.append(submodule)

        module.to(self._device)
        ColoModulize(module)