# mirror of https://github.com/hpcaitech/ColossalAI
from typing import Optional

import torch

from colossalai.kernel.triton import gptq_fused_linear_triton
|
|
|
|
|
|
class CaiGPTQLinearOp(torch.nn.Module):
    """GPTQ quantized linear layer backed by a fused Triton kernel.

    Wraps :func:`gptq_fused_linear_triton`, which dequantizes group-wise
    quantized weights (``gptq_quant_bits`` bits per value, groups of
    ``gptq_group_size``) and performs the matmul in a single kernel,
    optionally fusing bias add, residual add, and an activation.
    """

    def __init__(self, gptq_group_size: int, gptq_quant_bits: int):
        """Store quantization parameters.

        Args:
            gptq_group_size: Number of weight elements sharing one
                scale/zero-point group.
            gptq_quant_bits: Bit-width of the quantized weights.
        """
        super().__init__()
        self.group_size = gptq_group_size
        self.bits = gptq_quant_bits
        # Largest representable quantized value (e.g. 15 for 4-bit).
        self.maxq = 2**self.bits - 1
        # Placeholder tensor passed to the kernel when bias/residual are
        # absent: the kernel signature always takes tensor arguments, and
        # the add_bias/add_residual flags gate whether they are used.
        self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())

    def forward(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scales: torch.Tensor,
        weight_zeros: torch.Tensor,
        g_idx: Optional[torch.Tensor] = None,
        act_type: int = 0,
        bias: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        qkv_fused: bool = False,
    ) -> torch.Tensor:
        """Run the fused quantized linear transform.

        Args:
            input: Activation tensor; flattened to 2-D for the kernel and
                reshaped back using its first two dims, so a 3-D
                (batch, seq, hidden) layout is expected — TODO confirm
                against callers.
            weight: Packed quantized weight tensor.
            weight_scales: Per-group dequantization scales.
            weight_zeros: Per-group zero points.
            g_idx: Optional group-index tensor mapping rows to groups
                (act-order GPTQ); forwarded to the kernel as-is.
            act_type: Activation selector forwarded to the kernel
                (0 presumably means "no activation" — verify in kernel).
            bias: Optional bias to fuse into the matmul.
            residual: Optional residual tensor to fuse into the output.
            qkv_fused: If True, the kernel produces a fused Q/K/V output
                and the result is viewed with a leading dim of 3.

        Returns:
            The kernel output, viewed as
            ``(3, batch, seq, out_features)`` when ``qkv_fused`` else
            ``(batch, seq, out_features)``.
        """
        # Substitute the placeholder tensor and clear the flag when an
        # optional operand is missing; the kernel reads the flags.
        add_bias = bias is not None
        if bias is None:
            bias = self.empty_tensor

        add_residual = residual is not None
        if residual is None:
            residual = self.empty_tensor

        # Kernel operates on a 2-D (tokens, hidden) view.
        x = input.view(-1, input.shape[-1])

        out = gptq_fused_linear_triton(
            x,
            weight,
            weight_scales,
            weight_zeros,
            bias,
            residual,
            self.bits,
            self.maxq,
            self.group_size,
            qkv_fused,
            add_bias,
            add_residual,
            act_type=act_type,
            g_idx=g_idx,
        )
        if qkv_fused:
            # Fused QKV: leading dim 3 separates the Q, K, V projections.
            out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
        else:
            out = out.view(input.shape[0], input.shape[1], weight.shape[-1])

        return out