import math
import time
import warnings

import numpy as np
import pytest
import torch
import torch.nn as nn
import transformers
from packaging import version

try:
    import triton
    import triton.language as tl

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

try:
    from auto_gptq.modeling._utils import autogptq_post_init
    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
    from exllama_kernels import prepare_buffers, set_tuning_params

    from colossalai.inference.quant.gptq import CaiQuantLinear

    HAS_AUTO_GPTQ = True
except ImportError:
    HAS_AUTO_GPTQ = False
    print("please install AutoGPTQ from https://github.com/PanQiWei/AutoGPTQ")

HAS_GPTQ_CUDA = False
try:
    from colossalai.kernel.op_builder.gptq import GPTQBuilder

    gptq_cuda = GPTQBuilder().load()
    HAS_GPTQ_CUDA = True
except ImportError:
    warnings.warn("CUDA gptq is not installed")
    HAS_GPTQ_CUDA = False

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
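
# The module-level globals below are the exllama-style workspace shared by init_buffer():
# temp_state holds reordered activations for act-order layers, temp_dq holds dequantized
# weights for the cuBLAS path, and the size trackers grow to the largest layer seen so far.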
max_inner_outer_dim = 1
max_input_len = 1
max_dq_buffer_size = 1
gptq_temp_dq_buffer = None
gptq_temp_state_buffer = None


def init_buffer(cai_linear, use_act_order=False):
    global max_dq_buffer_size
    global max_input_len
    global max_inner_outer_dim
    global gptq_temp_dq_buffer
    global gptq_temp_state_buffer

    max_dq_buffer_size = max(max_dq_buffer_size, cai_linear.qweight.numel() * 8)

    if use_act_order:
        max_inner_outer_dim = max(max_inner_outer_dim, cai_linear.infeatures, cai_linear.outfeatures)
        max_input_len = 4096

    # The temp_state buffer is required to reorder X in the act-order case.
    # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
    gptq_temp_state_buffer = torch.zeros(
        (max_input_len, max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
    )
    gptq_temp_dq_buffer = torch.zeros(
        (1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
    )

    gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer)

    # Using the defaults from the exllama repo here.
    matmul_recons_thd = 8
    matmul_fused_remap = False
    matmul_no_half2 = False
    gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
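

# A minimal usage sketch for init_buffer (added illustration, not part of the original test):
# it assumes the colossalai CUDA gptq kernel and auto-gptq are installed and a GPU is
# available, and only checks that the shared workspace is (re)allocated in half precision.
@pytest.mark.skipif(not HAS_GPTQ_CUDA or not HAS_AUTO_GPTQ, reason="requires the CUDA gptq kernel and auto-gptq")
def test_init_buffer_sketch():
    # Hypothetical layer sizes chosen to mirror test_gptq_linear below (4-bit, group size 128).
    cai_linear = CaiQuantLinear(4, 128, 1024, 1024, True)
    init_buffer(cai_linear, use_act_order=True)
    # init_buffer sizes temp_dq to at least qweight.numel() * 8 elements and keeps both
    # scratch buffers in fp16 on the current CUDA device.
    assert gptq_temp_state_buffer.dtype == torch.float16
    assert gptq_temp_dq_buffer.numel() >= cai_linear.qweight.numel() * 8

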
@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
    reason="requires triton, CUDA > 11.4, and auto-gptq to be installed",
)
def test_gptq_linear():
    infeature = 1024
    outfeature = 1024
    group_size = 128
    wbits = 4

    inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device())
    batch_inps = torch.randn(1, 16, infeature).to(torch.float16).to(torch.cuda.current_device())

    device = torch.device("cuda:0")

    # Reference implementation: AutoGPTQ's exllama-based quantized linear layer.
    linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=wbits)
    linear = linear_class(
        bits=4,
        group_size=group_size,
        infeatures=infeature,
        outfeatures=outfeature,
        bias=False,
    )

    torch.manual_seed(42)
    linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
    linear.scales = linear.scales + 0.002
    linear = linear.to(device)

    # Colossal-AI quantized linear layer, initialized with the same quantized weights and scales.
    cai_linear = CaiQuantLinear(wbits, group_size, infeature, outfeature, True)
    cai_linear.qweight.data.copy_(linear.qweight)
    cai_linear.scales = cai_linear.scales + 0.002
    cai_linear = cai_linear.to(device)

    linear = autogptq_post_init(linear, use_act_order=False)

    max_inner_outer_dim = max(infeature, outfeature)
    max_dq_buffer_size = linear.infeatures * linear.outfeatures
    max_input_len = 2048
    buffers = {
        "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
        "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device),
    }
    prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])

    # Using the defaults from the exllama repo here.
    matmul_recons_thd = 8
    matmul_fused_remap = False
    matmul_no_half2 = False
    set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

    with torch.no_grad():
        gptq_out = linear(inps)
        batch_gptq_out = linear(batch_inps)
        torch.cuda.synchronize()
        cai_out = cai_linear(inps)
        torch.cuda.synchronize()
        batch_cai_out = cai_linear(batch_inps)
        torch.cuda.synchronize()

    # Outputs from the two implementations should match within the given (loose, fp16) tolerance.
    assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-01)
    assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-01)


if __name__ == "__main__":
    test_gptq_linear()