from typing import Tuple

import torch
from torch.nn import functional as F


# This function is adapted from src/transformers/models/llama/modeling_llama.py
# in huggingface transformers repository
# https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/llama/modeling_llama.py#L273
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from (bsz, num_key_value_heads, seq_len, head_dim) to (bsz, num_attention_heads, seq_len, head_dim)
    """
    if n_rep == 1:
        return hidden_states
    bsz, num_key_value_heads, seq_len, head_dim = hidden_states.shape
    hidden_states = hidden_states[:, :, None, :, :].expand(bsz, num_key_value_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)


def create_attention_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"):
    assert q_len <= kv_len

    causal_mask = torch.full((q_len, q_len), fill_value=float("-inf"), device=device).triu(diagonal=1)

    padding_mask = torch.zeros((bsz, 1, q_len, kv_len), dtype=torch.float32, device=device)
    for i in range(bsz):
        cur_seq_len = kv_lengths[i].item()
        assert cur_seq_len <= kv_len
        padding_mask[i, :, :, : kv_len - cur_seq_len] = float("-inf")

    padding_mask[:, :, -q_len:, -q_len:] += causal_mask

    return padding_mask


# Attention calculation adapted from HuggingFace transformers repository
# src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350
def torch_attn_ref(
    q: torch.Tensor,  # [bsz, num_heads, q_len, head_dim]
    k: torch.Tensor,  # [bsz, num_heads, kv_len, head_dim]
    v: torch.Tensor,  # [bsz, num_heads, kv_len, head_dim]
    attention_mask: torch.Tensor,  # [bsz, 1, q_len, kv_len]
    bsz: int,
    q_len: int,
    kv_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
) -> torch.Tensor:
    assert q.shape[-1] == k.shape[-1] == v.shape[-1] == head_dim

    # repeat kv for GQA and MQA
    # k/v won't change if kv_group_num is 1
    assert num_heads % num_kv_heads == 0, "Number of heads is not multiple of kv heads"
    kv_group_num = num_heads // num_kv_heads
    k = repeat_kv(k, kv_group_num)
    v = repeat_kv(v, kv_group_num)

    qk = torch.matmul(q, k.transpose(2, 3))
    attn_scores = qk / (head_dim**0.5)

    assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores"
    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}"
            )
        attn_scores = attn_scores + attention_mask

    attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype)
    out = torch.matmul(attn_weights, v)
    if out.size() != (bsz, num_heads, q_len, head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)}, but is" f" {out.size()}"
        )
    out = out.transpose(1, 2).contiguous()
    out = out.view(-1, out.size(-2), out.size(-1))
    # out [bsz * q_len, num_heads, head_dim]
    return out


def mock_alloc_block_table_and_kvcache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    context_lengths: torch.Tensor,
    num_seqs: int,
    max_num_blocks_per_seq: int,
    block_size: int,
) -> torch.Tensor:
    """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
    block_id = 0
    block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
    num_tokens_processed = 0
    for i, seq_len in enumerate(context_lengths.tolist()):
        right_bound = (seq_len + block_size - 1) // block_size  # open bound
        block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
        # Manually fill kv caches by copying from k and v
        for i in range(right_bound):
            if i == right_bound - 1:
                allocated_locs = seq_len % block_size or block_size
            else:
                allocated_locs = block_size
            k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
            v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
            k_cache[block_id, :, :, :allocated_locs] = k_block
            v_cache[block_id, :, :, :allocated_locs] = v_block

            num_tokens_processed += allocated_locs
            block_id += 1

    return block_tables


def mock_alloc_block_table_and_kvcache_v2(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    context_lengths: torch.Tensor,
    num_seqs: int,
    max_num_blocks_per_seq: int,
    block_size: int,
) -> torch.Tensor:
    """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
    block_id = 0
    block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
    num_tokens_processed = 0
    for i, seq_len in enumerate(context_lengths.tolist()):
        right_bound = (seq_len + block_size - 1) // block_size  # open bound
        block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
        # Manually fill kv caches by copying from k and v
        for i in range(right_bound):
            if i == right_bound - 1:
                allocated_locs = seq_len % block_size or block_size
            else:
                allocated_locs = block_size
            k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
            v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
            k_cache[block_id, :, :allocated_locs, :] = k_block
            v_cache[block_id, :, :allocated_locs, :] = v_block

            num_tokens_processed += allocated_locs
            block_id += 1

    return block_tables


def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None:
    # Allocate 1 token on the block table for each seqs in block tables.
    # It won't change provided context_lengths.
    # Consider max_block_id as the last physical block allocated
    # NOTE It assumes all the blocks preceding this block have been allocated
    max_block_id = torch.max(block_tables).item()
    # the indices on each block table representing the cache block to be allocated one more token
    alloc_local_block_indices = context_lengths // block_size
    # offsets of the token to be allocated on the target block (for each seq)
    alloc_block_offsets = context_lengths % block_size

    require_new_block = alloc_block_offsets == 0
    new_block_ids = torch.arange(
        max_block_id + 1,
        max_block_id + 1 + require_new_block.sum(),
        dtype=block_tables.dtype,
        device=block_tables.device,
    )

    if new_block_ids.numel():
        new_block_alloc_local_indices = alloc_local_block_indices[require_new_block]
        block_tables[require_new_block, new_block_alloc_local_indices] = new_block_ids


def generate_caches_and_block_tables(
    k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
    # Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths
    # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
    _, num_kv_heads, head_dim = k_unpad.shape
    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    # Mock allocation on block tables as well as blocked kv caches
    block_tables = mock_alloc_block_table_and_kvcache(
        k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
    )
    return k_cache, v_cache, block_tables


def generate_caches_and_block_tables_v2(
    k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
    # Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths
    # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
    _, num_kv_heads, head_dim = k_unpad.shape
    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    # Mock allocation on block tables as well as blocked kv caches
    block_tables = mock_alloc_block_table_and_kvcache_v2(
        k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
    )
    return k_cache, v_cache, block_tables


def convert_kv_unpad_to_padded(
    k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int
) -> torch.Tensor:
    # Rebuild (batched) k/v with padding to be used by torch attention
    # input k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
    # returns k/v padded    [bsz, num_kv_heads, max_seq_len, head_dim]
    _, num_kv_heads, head_dim = k_unpad.shape
    k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k_unpad.dtype, device=k_unpad.device)
    prev_len_sum = 0
    for i, seq_len in enumerate(kv_seq_lengths.tolist()):
        # left-side padding
        k_torch[i, -seq_len:, :, :] = k_unpad[prev_len_sum : prev_len_sum + seq_len]
        prev_len_sum += seq_len
    k_torch = k_torch.transpose(1, 2)
    return k_torch