# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ
import math
import warnings
from typing import List, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import ParallelModule

from .gptq_op import CaiGPTQLinearOp

HAS_GPTQ_CUDA = False
try:
    from colossalai.kernel.op_builder.gptq import GPTQBuilder

    gptq_cuda = GPTQBuilder().load()
    HAS_GPTQ_CUDA = True
except ImportError:
    warnings.warn("CUDA gptq is not installed")
    HAS_GPTQ_CUDA = False


class CaiQuantLinear(nn.Module):
    """GPTQ-quantized linear layer with optional tensor-parallel sharding."""

    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
        super().__init__()
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize if groupsize != -1 else infeatures

        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
        self.register_buffer(
            'qzeros',
            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
        self.register_buffer('scales',
                             torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
        if row_split:
            # Row-split shards offset g_idx by this rank's slice of infeatures.
            self.register_buffer(
                'g_idx',
                torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)],
                             dtype=torch.int32))
        else:
            self.register_buffer('g_idx',
                                 torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))

        if bias:
            self.register_buffer('bias', torch.zeros((outfeatures,), dtype=torch.float16))
        else:
            self.bias = None

        self.gptq_linear = CaiGPTQLinearOp(groupsize, bits)

        self.q4 = None
        self.empty_tensor = torch.empty((1, 1), device="meta")
        self.tp_size = tp_size
        self.tp_rank = tp_rank
        self.row_split = row_split
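
    # Buffer-shape arithmetic sketch (illustrative numbers, not part of the
    # class): for bits=4, groupsize=128, infeatures=4096, outfeatures=11008
    # the buffers registered above hold
    #   qweight: (4096 // 32 * 4, 11008) = (512, 11008)   int32
    #   qzeros:  (ceil(4096 / 128), 11008 // 32 * 4) = (32, 1376)   int32
    #   scales:  (32, 11008)   float16
    # i.e. eight 4-bit weights per int32 word and one (scale, zero) per group.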

    def pack(self, linear, scales, zeros, g_idx=None):
        g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
            [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        half_scales = scales.clone().half()
        self.scales = scales.clone().half()
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        pbits = 32
        ptype = torch.int32
        unsign_type = np.uint32
        sign_type = np.int32

        # Quantize each input column against its group's scale and zero point.
        intweight = []
        for idx in range(self.infeatures):
            q_col = torch.round(
                (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)
            intweight.append(q_col[:, None])
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(unsign_type)
        qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type)

        # Pack (32 // bits) quantized rows into each int32 row, low bits first.
        i = 0
        row = 0
        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (pbits // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += pbits // self.bits
                row += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")
        qweight = qweight.astype(sign_type)
        self.qweight.data.copy_(torch.from_numpy(qweight).contiguous())

        # Zero points are stored minus one and packed along the output dim.
        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
        zeros -= 1
        zeros = zeros.numpy().astype(unsign_type)
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (pbits // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += pbits // self.bits
                col += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")
        qzeros = qzeros.astype(sign_type)
        self.qzeros.data.copy_(torch.from_numpy(qzeros))

        # Drop g_idx when it matches the default contiguous grouping.
        if torch.equal(self.g_idx.to(g_idx.device), g_idx):
            self.g_idx = None
        else:
            self.g_idx = g_idx
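
    # Packing round-trip sketch for bits=4 (illustrative, mirrors the shifts
    # in pack() above): each int32 word holds 32 // 4 = 8 values,
    # least-significant bits first.
    #
    #   bits, n = 4, 8
    #   vals = list(range(n))
    #   word = 0
    #   for j, v in enumerate(vals):
    #       word |= v << (bits * j)
    #   assert [(word >> (bits * j)) & 0xF for j in range(n)] == vals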

    def init_q4(self):
        assert self.qweight.device.type == "cuda"
        self.q4_width = self.qweight.shape[1]
        if self.g_idx is not None:
            # A g_idx matching this shard's default grouping carries no extra
            # information, so drop it before building the q4 handle.
            if self.row_split and torch.equal(
                    self.g_idx,
                    torch.tensor(
                        [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
                        dtype=torch.int32,
                        device=self.g_idx.device)):
                self.g_idx = None
            elif torch.equal(
                    self.g_idx,
                    torch.tensor([i // self.groupsize for i in range(self.infeatures)],
                                 dtype=torch.int32,
                                 device=self.g_idx.device)):
                self.g_idx = None

        if self.g_idx is not None:
            g_idx = self.g_idx.to("cpu")
        else:
            g_idx = self.empty_tensor

        self.q4 = gptq_cuda.make_q4(self.qweight, self.qzeros, self.scales, g_idx, torch.cuda.current_device())
        torch.cuda.synchronize()
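
    # g_idx sketch: for infeatures=8, groupsize=4 the trivial (non act-order)
    # index is [0, 0, 0, 0, 1, 1, 1, 1]; only a permuted, act-order g_idx
    # survives the checks above, otherwise make_q4 receives the empty meta
    # tensor in its place.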

    def forward(self, x):
        outshape = x.shape[:-1] + (self.outfeatures,)

        if HAS_GPTQ_CUDA and self.bits == 4:
            # Fast path: the fused CUDA q4 kernel.
            if self.q4 is None:
                self.init_q4()

            x = x.view(-1, x.shape[-1])
            output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device)
            gptq_cuda.q4_matmul(x.half(), self.q4, output)
            if self.bias is not None and (not self.row_split or self.tp_size == 1):
                output.add_(self.bias)
        else:
            # Fallback: CaiGPTQLinearOp. For row-split tensor parallelism the
            # bias is added once after the all-reduce instead (see
            # RowCaiQuantLinear.forward).
            if self.bias is not None and (not self.row_split or self.tp_size == 1):
                bias = self.bias
            else:
                bias = None
            output = self.gptq_linear(
                x,
                self.qweight,
                self.scales,
                self.qzeros,
                g_idx=self.g_idx,
                bias=bias,
            )
        return output.view(outshape)
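
# Hedged usage sketch (assumes the ColossalAI GPTQ CUDA kernels are built and
# `scales`/`zeros` come from an external GPTQ quantizer; the names below are
# illustrative, not part of this module):
#
#   fp_linear = nn.Linear(4096, 4096, bias=False)
#   qlinear = CaiQuantLinear(bits=4, groupsize=128,
#                            infeatures=4096, outfeatures=4096, bias=False)
#   qlinear.pack(fp_linear, scales, zeros)
#   qlinear = qlinear.cuda()
#   y = qlinear(x.half().cuda())    # q4_matmul fast path when available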


def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
    qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
    qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
    scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
    g_idx = gptq_linear.g_idx
    if gptq_linear.bias is not None:
        bias = gptq_linear.bias.split(gptq_linear.out_features // split_num, dim=-1)

    cai_split_out_features = cai_linear.outfeatures // split_num
    zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num

    # Copy this rank's column slice of every split chunk into the shard.
    for i in range(split_num):
        cai_linear.qweight[:, i * cai_split_out_features:(i + 1) * cai_split_out_features] = \
            qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * cai_split_out_features]
        cai_linear.qzeros[:, i * zero_split_block:(i + 1) * zero_split_block] = \
            qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block]
        cai_linear.scales[:, i * cai_split_out_features:(i + 1) * cai_split_out_features] = \
            scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * cai_split_out_features]
        if cai_linear.bias is not None:
            cai_linear.bias[i * cai_split_out_features:(i + 1) * cai_split_out_features] = \
                bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) * cai_split_out_features]

    cai_linear.g_idx.copy_(g_idx)
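
# Column-split arithmetic sketch (illustrative numbers): with out_features=8192,
# bits=4, tp_size=2 and split_num=1, each rank's shard keeps qweight/scales
# columns [tp_rank*4096, (tp_rank+1)*4096) and qzeros columns
# [tp_rank*512, (tp_rank+1)*512), since 8192 // (32 // 4) // 2 = 512.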


def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
    qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
    qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
    scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
    g_idxs = gptq_linear.g_idx.split(gptq_linear.in_features // split_num, dim=0)

    cai_split_in_features = cai_linear.infeatures // (32 // cai_linear.bits) // split_num
    zero_split_block = cai_linear.infeatures // cai_linear.groupsize // split_num
    idx_split_features = cai_linear.infeatures // split_num

    # Copy this rank's row slice of every split chunk into the shard.
    for i in range(split_num):
        cai_linear.qweight[i * cai_split_in_features:(i + 1) * cai_split_in_features, :] = \
            qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) * cai_split_in_features, :]
        cai_linear.qzeros[i * zero_split_block:(i + 1) * zero_split_block, :] = \
            qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block, :]
        cai_linear.scales[i * zero_split_block:(i + 1) * zero_split_block, :] = \
            scales[i][tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block, :]
        cai_linear.g_idx[i * idx_split_features:(i + 1) * idx_split_features] = \
            g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) * idx_split_features]
    if cai_linear.bias is not None:
        cai_linear.bias.copy_(gptq_linear.bias)
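
# Row-split arithmetic sketch (illustrative numbers): with in_features=4096,
# bits=4, groupsize=128, tp_size=2 and split_num=1, each rank's shard has
# infeatures=2048 and copies qweight rows [tp_rank*256, (tp_rank+1)*256)
# (2048 // (32 // 4) = 256) and qzeros/scales rows [tp_rank*16, (tp_rank+1)*16)
# (2048 // 128 = 16).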


class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
    """Row tensor-parallel variant: shards in_features and all-reduces the output."""

    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
        super().__init__(bits,
                         groupsize,
                         infeatures,
                         outfeatures,
                         bias,
                         tp_size=tp_size,
                         tp_rank=tp_rank,
                         row_split=row_split)
        self.process_group = None

    @staticmethod
    def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
                           **kwargs) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, \
                f'Expected only one process group, got {len(process_group)}.'
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        if in_features < tp_size:
            return module

        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features:{in_features} is not an integer multiple of tensor parallel size: {tp_size}!")
        linear_1d = RowCaiQuantLinear(module.bits,
                                      module.group_size,
                                      module.in_features // tp_size,
                                      module.out_features,
                                      module.bias is not None,
                                      tp_size=tp_size,
                                      tp_rank=tp_rank,
                                      row_split=True)
        linear_1d.process_group = process_group

        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        return linear_1d

    def forward(self, x):
        output = super().forward(x)
        if self.tp_size > 1:
            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
            # Bias is added exactly once, after the partial sums are reduced.
            if self.bias is not None:
                output.add_(self.bias)
        return output


class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
    """Column tensor-parallel variant: shards out_features across ranks."""

    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
        super().__init__(bits,
                         groupsize,
                         infeatures,
                         outfeatures,
                         bias,
                         tp_size=tp_size,
                         tp_rank=tp_rank,
                         row_split=row_split)
        self.process_group = None

    @staticmethod
    def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
                           **kwargs) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes; the column split divides out_features across ranks
        out_features = module.out_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, \
                f'Expected only one process group, got {len(process_group)}.'
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        if out_features < tp_size:
            return module

        if out_features % tp_size != 0:
            raise ValueError(
                f"The size of out_features:{out_features} is not an integer multiple of tensor parallel size: {tp_size}!")
        linear_1d = ColCaiQuantLinear(module.bits,
                                      module.group_size,
                                      module.in_features,
                                      module.out_features // tp_size,
                                      module.bias is not None,
                                      tp_size=tp_size,
                                      tp_rank=tp_rank)
        linear_1d.process_group = process_group

        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        return linear_1d
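
# Hedged tensor-parallel sketch (assumes torch.distributed is initialized and
# `gptq_linear_up`/`gptq_linear_down` are packed GPTQ modules exposing
# bits / group_size / in_features / out_features; names are illustrative):
#
#   pg = dist.group.WORLD
#   col = ColCaiQuantLinear.from_native_module(gptq_linear_up, pg)    # shards out_features
#   row = RowCaiQuantLinear.from_native_module(gptq_linear_down, pg)  # shards in_features,
#                                                                     # all-reduces in forward()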