from typing import Callable

import torch
import torch.distributed as dist
from packaging import version
from torch import Tensor
from torch.distributed.fsdp._common_utils import _no_dispatch_record_stream
from torch.distributed.utils import _p_assert


def _all_gather_flat_param(
    self,
    padded_unsharded_flat_param: Tensor,
) -> Tensor:
    """
    All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.

    Then switch to use the all-gathered tensor.
    """
    _p_assert(
        hasattr(self, "process_group") and hasattr(self, "world_size"),
        "Expects a process group and world size to have been set via `shard()`",
    )
    sharded_flat_param = self.flat_param.data
    expected_numel = sharded_flat_param.numel() * self.world_size
    _p_assert(
        padded_unsharded_flat_param.numel() == expected_numel,
        f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
    )

    # Use the fake process group (which performs no real communication) when fake
    # all-gather is enabled; otherwise use the handle's regular process group.
    pg = self._fake_process_group if self._use_fake_all_gather else self.process_group

    # HACK this should be handled by C10D
    if sharded_flat_param.is_cpu:  # type: ignore[attr-defined]
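        # For CPU tensors, emulate ``all_gather_into_tensor`` with ``all_gather`` into
        # chunk views of the padded output tensor.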
        tensor_list = list(torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg)))
        dist.all_gather(tensor_list, sharded_flat_param, group=pg)
    else:
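        # With no hook registered, fall back to the standard all-gather; otherwise the
        # user-provided hook is responsible for the unshard communication.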
        if self._comm_hook is None:
            dist.all_gather_into_tensor(
                padded_unsharded_flat_param,
                sharded_flat_param,
                pg,
            )
        else:
            # Delegate the all-gather to the registered hook, passing the state captured
            # at registration time (see ``register_params_comm_hook``).
            self._comm_hook(
                self._comm_hook_state, padded_unsharded_flat_param, sharded_flat_param, pg
            )

    if self._offload_params:
        # In case of offloading, `flat_param.data` (i.e. sharded param) is
        # created on the pre-unshard stream. We need to hand it over to the
        # unshard stream for all-gather
        _no_dispatch_record_stream(
            sharded_flat_param,
            self._device_handle.current_stream(),  # unshard_stream
        )
    return padded_unsharded_flat_param


def register_params_comm_hook(self, state: object, hook: Callable):
    """Register a communication hook for FlatParamHandle.

    This provides a flexible hook with which users can specify how FSDP all-gathers (unshards)
    parameters across multiple workers.

    .. warning ::
        The FSDP communication hook should be registered only once, before running the
        initial forward pass.

    Args:
        state (object): Passed to the hook to maintain any state information during the training process.
        hook (Callable): Callable with the signature
                        ``hook(state, padded_unsharded_flat_param, sharded_flat_param, group) -> None``.
                        During unsharding, it is invoked in place of the default
                        ``dist.all_gather_into_tensor``: ``sharded_flat_param`` is this rank's shard
                        of the handle's flat parameter, ``padded_unsharded_flat_param`` is the
                        pre-allocated destination tensor that must hold the gathered (padded) flat
                        parameter when the hook returns, and ``group`` is the process group to
                        communicate over. The callable performs all necessary processing and
                        returns ``None``.
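
    Example (a minimal sketch; ``noop_params_all_gather_hook`` is a hypothetical hook that
    reproduces the default all-gather, and ``fsdp_model`` is an already-wrapped root FSDP
    instance)::

        >>> import torch.distributed as dist
        >>> def noop_params_all_gather_hook(
        ...     state, padded_unsharded_flat_param, sharded_flat_param, group
        ... ):
        ...     # A real hook could, e.g., cast or compress the shard before communication.
        ...     dist.all_gather_into_tensor(
        ...         padded_unsharded_flat_param, sharded_flat_param, group
        ...     )
        >>> fsdp_model.register_params_comm_hook(state=None, hook=noop_params_all_gather_hook)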

    """
    if not self.check_is_root():
        raise AssertionError("register_comm_hook can only be called on a root instance.")

    # if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
    #     raise AssertionError(
    #         f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}"
    #     )
    if self._handle._comm_hook is not None:
        raise AssertionError("A communication hook is already registered")
    if not callable(hook):
        raise ValueError(f"The communication hook must be callable but got {hook}")
    self._handle._comm_hook = hook
    self._handle._comm_hook_state = state


def patch_fsdp_params_comm_hook():
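    """Monkey-patch FSDP so that a parameter all-gather communication hook can be registered.

    This adds ``_comm_hook``/``_comm_hook_state`` slots to ``FlatParamHandle``, replaces its
    ``_all_gather_flat_param`` with the hook-aware version above, and exposes
    ``register_params_comm_hook`` on ``FullyShardedDataParallel``. Requires torch >= 2.2.0.
    """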
    if version.parse(torch.__version__) >= version.parse("2.2.0"):
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.fsdp._flat_param import FlatParamHandle

        FlatParamHandle._comm_hook = None
        FlatParamHandle._comm_hook_state = None
        FlatParamHandle._all_gather_flat_param = _all_gather_flat_param
        FSDP.register_params_comm_hook = register_params_comm_hook
    else:
        raise RuntimeError("The fsdp_params_comm_hook patch requires torch >= 2.2.0.")
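

# Example usage (a minimal sketch; assumes ``torch.distributed`` is already initialized and
# ``MyModel`` / ``my_params_all_gather_hook`` are user-defined placeholders, not part of this
# module):
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#   patch_fsdp_params_comm_hook()
#   fsdp_model = FSDP(MyModel().cuda())
#   fsdp_model.register_params_comm_hook(state=None, hook=my_params_all_gather_hook)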