"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
"""

import torch
try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func
except ImportError:
    raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention')



def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len):
    """
    Run non-causal flash attention on a packed QKV tensor built from a batch
    of equal-length sequences (no padding).

    Arguments:
        qkv: (batch*seq, 3, nheads, headdim)
        sm_scale: float. The scaling of QK^T before applying softmax.
        batch_size: int. Number of sequences stacked along the first axis.
        seq_len: int. Length of every sequence; also passed as the max length.
    Return:
        out: (total, nheads, headdim).
    """
    # Sequence boundaries [0, seq_len, 2*seq_len, ..., batch_size*seq_len].
    # All sequences share one length, so a fixed-step range is sufficient.
    stop = (batch_size + 1) * seq_len
    boundaries = torch.arange(
        0,
        stop,
        step=seq_len,
        dtype=torch.int32,
        device=qkv.device,
    )
    # The fourth positional argument is the dropout probability, fixed at 0.0.
    return flash_attn_unpadded_qkvpacked_func(
        qkv,
        boundaries,
        seq_len,
        0.0,
        softmax_scale=sm_scale,
        causal=False,
    )


def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen):
    """
    Run non-causal flash attention with separate queries and a packed KV tensor,
    for a batch of sequences where every query sequence has length ``q_seqlen``
    and every key/value sequence has length ``kv_seqlen`` (no padding).

    Arguments:
        q: (batch*q_seqlen, nheads, headdim)
        kv: (batch*kv_seqlen, 2, nheads, headdim)
        sm_scale: float. The scaling of QK^T before applying softmax.
        batch_size: int. Number of sequences stacked along the first axis.
        q_seqlen: int. Length of every query sequence; also the max query length.
        kv_seqlen: int. Length of every key/value sequence; also the max KV length.
    Return:
        out: (batch*q_seqlen, nheads, headdim).
    """
    # Uniform cumulative boundaries [0, L, 2L, ..., batch_size*L] for each side.
    cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen,
                                dtype=torch.int32, device=q.device)
    cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen,
                                dtype=torch.int32, device=kv.device)
    # Dropout is 0.0. softmax_scale and causal are passed by keyword, matching
    # flash_attention_qkv, so the call does not depend on positional order.
    out = flash_attn_unpadded_kvpacked_func(
        q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, 0.0,
        softmax_scale=sm_scale, causal=False,
    )
    return out