# Mirror of https://github.com/hpcaitech/ColossalAI
# You cannot select more than 25 topics.
# Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
# 145 lines
# 4.7 KiB
# 145 lines
# 4.7 KiB
import pytest
|
|
import torch
|
|
from packaging import version
|
|
|
|
# Probe for triton. The import must actually be attempted; otherwise the ImportError branch
# is unreachable and HAS_TRITON would always be True.
try:
    import triton  # noqa: F401

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")
|
|
|
|
# Probe for AutoGPTQ, the exllama kernels and ColossalAI's CaiQuantLinear; the test below needs all of them.
# ``except Exception`` (not a bare ``except``) still tolerates non-ImportError failures when loading
# native extensions, but no longer swallows KeyboardInterrupt / SystemExit.
try:
    from auto_gptq.modeling._utils import autogptq_post_init
    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
    from exllama_kernels import prepare_buffers, set_tuning_params

    from colossalai.inference.quant.gptq import CaiQuantLinear

    HAS_AUTO_GPTQ = True
except Exception:
    HAS_AUTO_GPTQ = False
    print("please install AutoGPTQ from https://github.com/PanQiWei/AutoGPTQ")
|
|
|
|
import warnings

# Try to build/load ColossalAI's CUDA GPTQ extension. On success ``gptq_cuda`` is bound at module
# level and HAS_GPTQ_CUDA is True; on ImportError a warning is emitted and ``gptq_cuda`` stays unbound.
try:
    from colossalai.kernel.op_builder.gptq import GPTQBuilder

    gptq_cuda = GPTQBuilder().load()
except ImportError:
    warnings.warn("CUDA gptq is not installed")
    HAS_GPTQ_CUDA = False
else:
    HAS_GPTQ_CUDA = True
|
|
|
|
# True when torch was built against CUDA newer than 11.4. ``torch.version.cuda`` is None on CPU-only
# builds, and ``version.parse(None)`` would raise at import time; the guard turns that into a clean skip.
TRITON_CUDA_SUPPORT = torch.version.cuda is not None and version.parse(torch.version.cuda) > version.parse("11.4")
|
|
|
|
# Module-level scratch-buffer state read and updated by init_buffer(): the three sizes start at 1,
# and the buffers themselves are not allocated until init_buffer() runs.
max_inner_outer_dim = max_input_len = max_dq_buffer_size = 1
gptq_temp_dq_buffer = gptq_temp_state_buffer = None
|
|
|
|
|
|
def init_buffer(cai_linear, use_act_order=False):
    """Size, allocate and register the scratch buffers used by the CUDA GPTQ extension.

    ``max_dq_buffer_size`` and ``max_inner_outer_dim`` are module-level globals grown with ``max()``
    across calls. Every call reallocates both scratch tensors on the current CUDA device, passes them
    to ``gptq_cuda.prepare_buffers`` and applies exllama's default tuning parameters.

    Args:
        cai_linear: quantized linear layer. Its ``qweight`` determines the dequantization buffer
            size. When ``use_act_order`` is true, its ``infeatures`` / ``outfeatures`` also
            determine the width of the reorder buffer.
        use_act_order: whether the layer uses act-order. This also sets ``max_input_len`` to 4096.

    Raises:
        RuntimeError: if the CUDA GPTQ extension failed to load when this module was imported.
    """
    global max_dq_buffer_size
    global max_input_len
    global max_inner_outer_dim
    global gptq_temp_dq_buffer
    global gptq_temp_state_buffer

    # Fail with a clear message instead of a NameError on ``gptq_cuda`` further down.
    if not HAS_GPTQ_CUDA:
        raise RuntimeError("CUDA gptq is not installed; cannot initialize GPTQ buffers")

    # NOTE(review): the factor 8 presumably reflects eight packed 4-bit values per int32 qweight
    # element. Confirm before using other bit widths.
    max_dq_buffer_size = max(max_dq_buffer_size, cai_linear.qweight.numel() * 8)

    if use_act_order:
        max_inner_outer_dim = max(max_inner_outer_dim, cai_linear.infeatures, cai_linear.outfeatures)
        max_input_len = 4096

    device = torch.device(torch.cuda.current_device())
    # The temp_state buffer is required to reorder X in the act-order case.
    # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
    gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device)
    gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)

    gptq_cuda.prepare_buffers(device, gptq_temp_state_buffer, gptq_temp_dq_buffer)

    # Using the default from exllama repo here.
    matmul_recons_thd = 8
    matmul_fused_remap = False
    matmul_no_half2 = False
    gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
|
|
|
|
|
|
@pytest.mark.skipif(
    not torch.cuda.is_available() or not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
    reason="requires a CUDA device, CUDA version higher than 11.4, triton and auto-gptq",
)
def test_gptq_linear():
    """Compare CaiQuantLinear against AutoGPTQ's non-triton QuantLinear on identical packed weights.

    Both layers receive the same ``qweight`` and the same scale offset. Their outputs for a
    single-token input and a 16-token batch must agree within a loose fp16 tolerance.
    """
    infeature = 1024
    outfeature = 1024
    group_size = 128
    wbits = 4

    # Seed before any random tensor is drawn, so that batch_inps (and qweight) are reproducible.
    torch.manual_seed(42)

    # Use a single device for inputs, layers and kernel buffers.
    device = torch.device("cuda", torch.cuda.current_device())

    inps = torch.ones(1, 1, infeature).to(torch.float16).to(device)
    batch_inps = torch.randn(1, 16, infeature).to(torch.float16).to(device)

    linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=wbits)
    linear = linear_class(
        bits=wbits,
        group_size=group_size,
        infeatures=infeature,
        outfeatures=outfeature,
        bias=False,
    )

    # Random packed weights; copied into cai_linear below so both layers share them.
    linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
    # Both layers get the same scale offset so they stay in sync.
    linear.scales = linear.scales + 0.002
    linear = linear.to(device)

    cai_linear = CaiQuantLinear(wbits, group_size, infeature, outfeature, True)
    cai_linear.qweight.data.copy_(linear.qweight)
    cai_linear.scales = cai_linear.scales + 0.002
    cai_linear = cai_linear.to(device)

    linear = autogptq_post_init(linear, use_act_order=False)

    # Scratch buffers registered with the exllama kernels via prepare_buffers().
    max_inner_outer_dim = max(infeature, outfeature)
    max_dq_buffer_size = linear.infeatures * linear.outfeatures
    max_input_len = 2048
    buffers = {
        "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
        "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device),
    }

    prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])

    # Using the default from exllama repo here.
    matmul_recons_thd = 8
    matmul_fused_remap = False
    matmul_no_half2 = False
    set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

    with torch.no_grad():
        gptq_out = linear(inps)
        batch_gptq_out = linear(batch_inps)
        torch.cuda.synchronize()
        cai_out = cai_linear(inps)
        torch.cuda.synchronize()

        batch_cai_out = cai_linear(batch_inps)
        torch.cuda.synchronize()

    assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-01)
    assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-01)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this check directly, without going through pytest.
    test_gptq_linear()
|