ColossalAI/colossalai/legacy/inference/quant/gptq/cai_gptq/gptq_op.py

59 lines
1.6 KiB
Python
Raw Normal View History

import torch
from colossalai.kernel.triton import gptq_fused_linear_triton
class CaiGPTQLinearOp(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around ``gptq_fused_linear_triton``.

    The module stores the quantization configuration (group size, bit width
    and the derived maximum quantized value) and forwards it, together with
    the per-call tensors, to the Triton kernel.
    """

    def __init__(self, gptq_group_size, gptq_quant_bits):
        """Store the quantization configuration.

        Args:
            gptq_group_size: Group size forwarded to the kernel on every call.
            gptq_quant_bits: Number of bits per quantized weight; ``maxq`` is
                derived from it as ``2**bits - 1``.
        """
        super().__init__()
        self.group_size = gptq_group_size
        self.bits = gptq_quant_bits
        self.maxq = 2**self.bits - 1
        # Placeholder handed to the kernel in place of a missing bias/residual;
        # the accompanying add_bias/add_residual flags are False in that case.
        # NOTE(review): this is a plain attribute, not a registered buffer, so
        # it stays on the CUDA device that was current at construction time.
        self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())

    def forward(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scales: torch.Tensor,
        weight_zeros: torch.Tensor,
        g_idx: torch.Tensor = None,
        act_type=0,
        bias: torch.Tensor = None,
        residual: torch.Tensor = None,
        qkv_fused=False,
    ):
        """Run the fused GPTQ linear kernel on ``input``.

        ``input`` is flattened to 2-D over its last dimension before the
        kernel call, and the result is viewed back to the leading dimensions
        of ``input``. Any input rank is accepted; for 3-D inputs the result
        shape is identical to the previous hard-coded
        ``(input.shape[0], input.shape[1], weight.shape[-1])``.

        Args:
            input: Activation tensor; its last dimension is the feature axis.
            weight: Weight tensor passed to the kernel; ``weight.shape[-1]``
                is used as the size of the output feature axis.
            weight_scales: Scales tensor passed through to the kernel.
            weight_zeros: Zero-points tensor passed through to the kernel.
            g_idx: Optional tensor forwarded as the kernel's ``g_idx``.
            act_type: Forwarded as the kernel's ``act_type``.
            bias: Optional bias; when ``None`` a placeholder is passed and
                ``add_bias`` is set to ``False``.
            residual: Optional residual; when ``None`` a placeholder is passed
                and ``add_residual`` is set to ``False``.
            qkv_fused: Forwarded to the kernel; when true the result is viewed
                with an extra leading dimension of size 3.

        Returns:
            The kernel output viewed as ``(*input.shape[:-1], weight.shape[-1])``,
            or ``(3, *input.shape[:-1], weight.shape[-1])`` when ``qkv_fused``.
        """
        add_bias = bias is not None
        if not add_bias:
            bias = self.empty_tensor

        add_residual = residual is not None
        if not add_residual:
            residual = self.empty_tensor

        # Remember every leading dimension so the output can be restored for
        # inputs of any rank, not only (batch, seq, hidden).
        leading_shape = tuple(input.shape[:-1])
        x = input.view(-1, input.shape[-1])

        out = gptq_fused_linear_triton(
            x,
            weight,
            weight_scales,
            weight_zeros,
            bias,
            residual,
            self.bits,
            self.maxq,
            self.group_size,
            qkv_fused,
            add_bias,
            add_residual,
            act_type=act_type,
            g_idx=g_idx,
        )

        out_features = weight.shape[-1]
        if qkv_fused:
            return out.view(3, *leading_shape, out_features)
        return out.view(*leading_shape, out_features)