from dataclasses import dataclass
from typing import Iterable, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange

from colossalai.utils.device import get_current_device


class Unpad(torch.autograd.Function):
    """
    Gather selected rows out of a padded tensor after merging its first two dims.

    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
        """Merge dims 0 and 1 of ``tensor`` and return the rows picked by ``indices``.

        ``tensor`` must have at least 3 dims (``[b, s, ...]``); the result is
        ``[ntokens, ...]``.
        """
        assert tensor.ndim >= 3
        ctx.save_for_backward(indices)
        ctx.bsz = tensor.shape[0]
        # [b, s, ...] -> [b*s, ...]
        flat = tensor.flatten(0, 1)
        ctx.shape = flat.shape
        return flat[indices]

    @staticmethod
    def backward(ctx, grad_output):
        """Scatter ``grad_output`` into zeros of the merged shape and restore ``[b, s, ...]``."""
        indices, = ctx.saved_tensors
        # Same dtype/device as grad_output; rows not in ``indices`` get zero gradient.
        full_grad = grad_output.new_zeros(ctx.shape)
        full_grad[indices] = grad_output
        # [b*s, ...] -> [b, s, ...]; s is derived explicitly so zero-length sequences reshape fine.
        seq_len = ctx.shape[0] // ctx.bsz
        full_grad = full_grad.reshape(ctx.bsz, seq_len, *ctx.shape[1:])
        return full_grad, None


class Repad(torch.autograd.Function):
    """
    Scatter the rows of a packed ``[ntokens, ...]`` tensor into a zero-filled
    ``[batch_size * seq_len, ...]`` tensor at the row positions given by ``indices``.

    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
        """Place row ``i`` of ``tensor`` at output row ``indices[i]``.

        Args:
            tensor: packed rows, ``[ntokens, ...]``.
            indices: destination row positions in the ``batch_size * seq_len`` output.
            batch_size: number of sequences in the padded output.
            seq_len: padded length of each sequence.

        Returns:
            ``[batch_size * seq_len, ...]`` tensor with the same dtype/device as
            ``tensor``; rows not named in ``indices`` are zero.
        """
        ctx.save_for_backward(indices)
        # [ntokens, ...] -> [b*s, ...]
        out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
        out[indices] = tensor
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Gather the gradient rows at ``indices``; the int arguments get no gradient."""
        (indices,) = ctx.saved_tensors
        # [b*s, ...] -> [ntokens, ...]
        grad = grad_output[indices]
        return grad, None, None, None


@dataclass
class SeqLenInfo:
    """Per-batch sequence-length metadata built from a mask or a dense size.

    Fields (as filled in by ``materialize``):
        seqlens: per-sequence token counts, as a Python list.
        indices: positions of the kept tokens in the flattened input.
        max_seqlen: the largest entry of ``seqlens`` (0 for an empty batch).
        cu_seqlens: int32 cumulative sum of ``seqlens`` with a leading 0.
    """

    seqlens: Iterable[int] = None
    indices: torch.Tensor = None
    max_seqlen: int = None
    cu_seqlens: torch.Tensor = None

    @staticmethod
    def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=None):
        """Build a ``SeqLenInfo`` from ``attn_mask`` or, if it is ``None``, from ``size``.

        Args:
            attn_mask: mask whose nonzero entries mark kept tokens; lengths are
                summed over its last dim. Presumably ``[batch, seq]`` — TODO confirm
                for higher-rank masks.
            size: ``(batch_size, seq_len, ...)``; used only when ``attn_mask`` is
                ``None``, in which case every sequence is full length.
            device: device for ``indices`` and ``cu_seqlens``. ``None`` (default)
                means the current device, resolved at call time.

        Returns:
            ``SeqLenInfo`` with ``seqlens`` as a list of ints and ``max_seqlen`` an int.

        Raises:
            ValueError: if both ``attn_mask`` and ``size`` are ``None``.
        """
        if device is None:
            # Resolve per call: a default evaluated at import time would pin
            # whichever device was current when the module was first imported.
            device = get_current_device()
        if attn_mask is not None:
            indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
            seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
        else:
            if size is None:
                raise ValueError("either attn_mask or size must be provided")
            batch_size, tgt_len = size[0], size[1]
            indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
            # torch.full instead of the legacy torch.LongTensor constructor,
            # which only accepts CPU devices.
            seqlens = torch.full((batch_size,), tgt_len, dtype=torch.int32, device=device)
        cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
        # A single host transfer; taking max over Python ints yields a plain int
        # (matching the ``max_seqlen: int`` field) instead of a 0-d tensor, and
        # avoids iterating the tensor element by element.
        seqlens_list = seqlens.tolist()
        max_seqlen = max(seqlens_list, default=0)
        return SeqLenInfo(seqlens_list, indices, max_seqlen, cu_seqlens)