#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
from torch import distributed as dist

from colossalai.communication import ring_forward
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd


class RingQK(torch.autograd.Function):
    """
    Calculate QK in a ring-exchange style
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length):
        # save tensor for backward
        ctx.save_for_backward(sub_q, sub_k)
        ctx.sub_seq_length = sub_seq_length

        # create local segment of attention score
        attention_score = torch.empty(batch_size * num_attention_heads,
                                      sub_seq_length,
                                      sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
                                      dtype=sub_q.dtype,
                                      device=get_current_device())

        # compute local QK^T
        part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        start_idx = local_rank * sub_seq_length
        end_idx = (local_rank + 1) * sub_seq_length
        attention_score[:, :, start_idx:end_idx] = part_a

        # compute QK^T in ring-all-reduce style
        for i in range(local_world_size - 1):
            sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
            start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length)
            part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
            attention_score[:, :, start_idx:end_idx] = part_a

        return attention_score

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        sub_q, sub_k, = ctx.saved_tensors
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)

        # calculate gradient of sub_k
        grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q)

        dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
        grad_k = grad_k[:, local_rank * ctx.sub_seq_length:(local_rank + 1) * ctx.sub_seq_length]
        grad_k /= local_world_size

        # calculate gradient for sub_q
        grad_q = torch.zeros_like(
            sub_q,
            dtype=sub_q.dtype,
            device=get_current_device(),
        )

        # compute with local sub_k
        start_idx, end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length)
        grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k)

        # compute QK^T in ring-all-reduce style
        for i in range(local_world_size - 1):
            sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
            start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)
            grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k)

        grad_q /= local_world_size

        return grad_q, grad_k, None, None, None


class RingAV(torch.autograd.Function):
    """
    Calculate AV in a ring-exchange style
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads, attention_head_size, sub_seq_length):
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length)

        sub_attention_result = torch.zeros(batch_size * num_attention_heads,
                                           sub_seq_length,
                                           attention_head_size,
                                           device=get_current_device(),
                                           dtype=attention_score.dtype)

        # save tensors for backward
        ctx.save_for_backward(attention_score, sub_v)
        ctx.sub_seq_length = sub_seq_length

        # compute local AV
        part_av = torch.matmul(attention_score[:, :, local_start_idx:local_end_idx], sub_v)
        sub_attention_result += part_av

        # compute AV in ring - all - reduce style
        for i in range(local_world_size - 1):
            sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
            start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length)

            # compute QK^T
            part_av = torch.matmul(attention_score[:, :, start_idx:end_idx], sub_v)
            sub_attention_result += part_av
        return sub_attention_result

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        local_start_idx, local_end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length)
        attention_scores, sub_v = ctx.saved_tensors

        # calculate gradient of v
        grad_v = torch.matmul(attention_scores.transpose(2, 1), grad_output)
        dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE))
        grad_v = grad_v[:, local_start_idx:local_end_idx]
        grad_v /= local_world_size

        # calculate gradient for attention score
        grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device())

        # compute with local sub_k
        grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1))

        # compute QK^T in ring-all-reduce style
        for i in range(local_world_size - 1):
            sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
            start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)

            # compute grad_q
            grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1))

        return grad_attention_score, grad_v, None, None, None, None