import math
import time

import numpy as np
import pytest
import torch
import torch.nn as nn
import transformers
from packaging import version

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

try:
    from auto_gptq.modeling._utils import autogptq_post_init
    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
    from exllama_kernels import prepare_buffers, set_tuning_params

    from colossalai.inference.quant.gptq import CaiQuantLinear
    HAS_AUTO_GPTQ = True
except:
    HAS_AUTO_GPTQ = False
    print("please install AutoGPTQ from https://github.com/PanQiWei/AutoGPTQ")

import warnings

HAS_GPTQ_CUDA = False
try:
    from colossalai.kernel.op_builder.gptq import GPTQBuilder
    gptq_cuda = GPTQBuilder().load()
    HAS_GPTQ_CUDA = True
except ImportError:
    warnings.warn('CUDA gptq is not installed')
    HAS_GPTQ_CUDA = False

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')

max_inner_outer_dim = 1
max_input_len = 1
max_dq_buffer_size = 1
gptq_temp_dq_buffer = None
gptq_temp_state_buffer = None


def init_buffer(cai_linear, use_act_order=False):
    global max_dq_buffer_size
    global max_input_len
    global max_dq_buffer_size
    global max_inner_outer_dim
    global gptq_temp_dq_buffer
    global gptq_temp_state_buffer

    max_dq_buffer_size = max(max_dq_buffer_size, cai_linear.qweight.numel() * 8)

    if use_act_order:
        max_inner_outer_dim = max(max_inner_outer_dim, cai_linear.infeatures, cai_linear.outfeatures)

    if use_act_order:
        max_input_len = 4096
    # The temp_state buffer is required to reorder X in the act-order case.
    # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
    gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim),
                                         dtype=torch.float16,
                                         device=torch.cuda.current_device())
    gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device())

    gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer)
    # Using the default from exllama repo here.
    matmul_recons_thd = 8
    matmul_fused_remap = False
    matmul_no_half2 = False
    gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
                    reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq")
def test_gptq_linear():

    infeature = 1024
    outfeature = 1024
    group_size = 128
    wbits = 4

    inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device())
    batch_inps = torch.randn(1, 16, infeature).to(torch.float16).to(torch.cuda.current_device())

    device = torch.device("cuda:0")

    linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=wbits)

    linear = linear_class(
        bits=4,
        group_size=group_size,
        infeatures=infeature,
        outfeatures=outfeature,
        bias=False,
    )

    torch.manual_seed(42)

    linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
    linear.scales = linear.scales + 0.002

    linear = linear.to(device)

    cai_linear = CaiQuantLinear(wbits, group_size, infeature, outfeature, True)
    cai_linear.qweight.data.copy_(linear.qweight)
    cai_linear.scales = cai_linear.scales + 0.002
    cai_linear = cai_linear.to(device)

    linear = autogptq_post_init(linear, use_act_order=False)

    max_inner_outer_dim = max(infeature, outfeature)
    max_dq_buffer_size = linear.infeatures * linear.outfeatures
    max_input_len = 2048
    buffers = {
        "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
        "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
    }

    prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])

    # Using the default from exllama repo here.
    matmul_recons_thd = 8
    matmul_fused_remap = False
    matmul_no_half2 = False
    set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

    with torch.no_grad():
        gptq_out = linear(inps)
        batch_gptq_out = linear(batch_inps)
        torch.cuda.synchronize()
        cai_out = cai_linear(inps)
        torch.cuda.synchronize()

        batch_cai_out = cai_linear(batch_inps)
        torch.cuda.synchronize()

    assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-01)
    assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-01)


if __name__ == "__main__":

    test_gptq_linear()