import math

import torch
from torch.nn import functional as F


def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
    """
    Reference (pure PyTorch) causal multi-head attention for a padded batch.

    adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253

    Args:
        xq, xk, xv: query / key / value tensors, each holding exactly
            ``bs * seqlen * num_head * head_dim`` elements laid out so that
            they can be reshaped to ``(bs, seqlen, num_head, head_dim)``.
        bs: batch size.
        seqlen: sequence length shared by every sample in the batch.
        num_head: number of attention heads.
        head_dim: size of each head; scores are scaled by ``1 / sqrt(head_dim)``.

    Returns:
        Tensor of shape ``(bs * seqlen, num_head, head_dim)`` on the same
        device as the inputs and with the dtype of ``xv``.
    """
    # (bs, seqlen, num_head, head_dim) -> (bs, num_head, seqlen, head_dim).
    # reshape (rather than view) also accepts non-contiguous inputs.
    queries = xq.reshape(bs, seqlen, num_head, head_dim).transpose(1, 2)
    keys = xk.reshape(bs, seqlen, num_head, head_dim).transpose(1, 2)
    values = xv.reshape(bs, seqlen, num_head, head_dim).transpose(1, 2)

    sm_scale = 1 / math.sqrt(head_dim)
    scores = torch.matmul(queries, keys.transpose(2, 3)) * sm_scale

    # Causal mask built on the scores' device (no hard-coded .cuda()) and left
    # at (seqlen, seqlen) so it broadcasts over batch and heads instead of
    # being materialized bs * num_head times.
    positions = torch.arange(seqlen, device=scores.device)
    is_future = positions.unsqueeze(0) > positions.unsqueeze(1)  # key idx > query idx

    # Softmax is computed in fp32 for numerical stability. -inf is safe here:
    # the diagonal is never masked, so every row keeps at least one finite score.
    scores = scores.float().masked_fill(is_future, float("-inf"))
    # Cast back to the value dtype (not a fixed fp16) so the matmul below works
    # for fp16, bf16 and fp32 inputs alike.
    probs = F.softmax(scores, dim=-1).to(dtype=values.dtype)

    output = torch.matmul(probs, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
    return output