import warnings

import torch


class GPTQManager:
    """Tracks GPTQ quantization settings and the CUDA scratch buffers needed by the quantized kernels."""

    def __init__(self, quant_config, max_input_len: int = 1):
        self.max_dq_buffer_size = 1
        self.max_inner_outer_dim = 1
        self.bits = quant_config.bits
        self.use_act_order = quant_config.desc_act
        self.max_input_len = max_input_len
        self.gptq_temp_state_buffer = None
        self.gptq_temp_dq_buffer = None
        self.quant_config = quant_config

    def post_init_gptq_buffer(self, model: torch.nn.Module) -> None:
        """Scan `model` for CaiQuantLinear layers, size the scratch buffers, and hand them to the CUDA extension."""
        from .cai_gptq import CaiQuantLinear

        HAS_GPTQ_CUDA = False
        try:
            from colossalai.kernel.op_builder.gptq import GPTQBuilder

            gptq_cuda = GPTQBuilder().load()
            HAS_GPTQ_CUDA = True
        except ImportError:
            warnings.warn("CUDA gptq is not installed")
            HAS_GPTQ_CUDA = False

        for name, submodule in model.named_modules():
            if isinstance(submodule, CaiQuantLinear):
                self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)

                if self.use_act_order:
                    self.max_inner_outer_dim = max(
                        self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
                    )
                self.bits = submodule.bits

        if not (HAS_GPTQ_CUDA and self.bits == 4):
            return

        max_input_len = 1
        if self.use_act_order:
            max_input_len = self.max_input_len

        # The temp_state buffer is required to reorder X in the act-order case.
        # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
        self.gptq_temp_state_buffer = torch.zeros(
            (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
        )
        self.gptq_temp_dq_buffer = torch.zeros(
            (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
        )

        gptq_cuda.prepare_buffers(
            torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
        )
        # Using the defaults from the exllama repo here.
        matmul_recons_thd = 8
        matmul_fused_remap = False
        matmul_no_half2 = False
        gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

        torch.cuda.empty_cache()
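
# A hypothetical usage sketch (not part of the original module). It assumes a GPTQ
# quantization config exposing `bits` and `desc_act`, and a model whose linear layers
# have already been replaced with `CaiQuantLinear` modules:
#
#     from types import SimpleNamespace
#
#     quant_config = SimpleNamespace(bits=4, desc_act=False)  # stand-in for a real GPTQ config
#     gptq_manager = GPTQManager(quant_config, max_input_len=2048)
#     gptq_manager.post_init_gptq_buffer(model)  # sizes and registers the CUDA scratch buffers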