from ..base_extension import _Extension


class FlashAttentionDaoCudaExtension(_Extension):
    def __init__(self):
        super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10)

    def is_available(self) -> bool:
        # this extension wraps the third-party flash-attn package, so it is only
        # available when flash-attn can be imported and a CUDA device is present
        try:
            import torch

            from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func  # noqa
            from flash_attn.bert_padding import index_first_axis, pad_input  # noqa

            cuda_available = torch.cuda.is_available()
        except Exception:
            cuda_available = False
        return cuda_available

    def assert_compatible(self) -> bool:
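        # no additional compatibility check is performed here; availability is
        # determined at runtime in is_available() by importing flash-attn and
        # probing for a CUDA device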
        pass

    def build_aot(self) -> None:
        raise NotImplementedError(
            "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'."
        )

    def build_jit(self) -> None:
        raise NotImplementedError(
            "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'"
        )

    def load(self):
        from typing import Optional

        import torch
        from einops import rearrange
        from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func
        from flash_attn.bert_padding import index_first_axis, pad_input

        def _unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor):
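            # flatten the batch and sequence dims, then gather only the
            # non-padding tokens: [B, S, ...] -> [T, ...]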
            return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices)

        def flash_attention(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            dropout_p: float = 0.0,
            scale: Optional[float] = None,
            attention_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
            cu_seqlens_q: Optional[torch.Tensor] = None,
            cu_seqlens_kv: Optional[torch.Tensor] = None,
            max_seqlen_q: Optional[int] = None,
            max_seqlen_kv: Optional[int] = None,
            q_indices: Optional[torch.Tensor] = None,
            kv_indices: Optional[torch.Tensor] = None,
        ):
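            """Dispatch to flash-attn's dense or variable-length kernel.

            Expects q, k and v in [B, N, S, D] layout and returns the output in
            the same layout. When cu_seqlens_q is given, padding tokens are
            removed via q_indices/kv_indices and flash_attn_varlen_kvpacked_func
            is used; otherwise the dense flash_attn_func path is taken.
            attention_mask is accepted but not used directly in this implementation.
            """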
            # [B, N, S, D] -> [B, S, N, D]
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
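            # keep batch size and padded query length for re-padding the output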
            b, s_q = q.shape[:2]
            if cu_seqlens_q is not None:
                # padded / padded causal
                # unpad input: [B, S, N, D] -> [T, N, D]
                q = _unpad_input(q, q_indices)
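                # pack K and V into [B, S, 2, N, D], then unpad to the
                # [T, 2, N, D] layout expected by the kvpacked kernel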
                kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices)
                attn_output = flash_attn_varlen_kvpacked_func(
                    q,
                    kv,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    dropout_p=dropout_p,
                    softmax_scale=scale,
                    causal=is_causal,
                )
                # pad output: [T, N, D] -> [B, S, N, D]
                attn_output = pad_input(attn_output, q_indices, b, s_q)
            else:
                # dense path: no padding mask, causal or full attention
                attn_output = flash_attn_func(
                    q,
                    k,
                    v,
                    dropout_p=dropout_p,
                    softmax_scale=scale,
                    causal=is_causal,
                )
            # [B, S, N, D] -> [B, N, S, D]
            return attn_output.transpose(1, 2)

        return flash_attention
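

# Minimal usage sketch (illustrative only, not part of this module): a caller
# might check availability before loading the kernel. The tensors q, k, v below
# are hypothetical and must be CUDA tensors in [B, N, S, D] layout (flash-attn
# typically requires fp16 or bf16).
#
#   ext = FlashAttentionDaoCudaExtension()
#   if ext.is_available():
#       flash_attention = ext.load()
#       out = flash_attention(q, k, v, is_causal=True)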