mirror of https://github.com/hpcaitech/ColossalAI
[feature] add gptq for inference (#4754)
* [gptq] add gptq kernel (#4416)
  * add gptq
  * refactor code
  * fix tests
  * replace auto-gptq
  * rename inference/quant
  * refactor test
  * add auto-gptq as an option
  * reset requirements
  * change assert and check auto-gptq
  * add import warnings
  * change test flash attn version
  * remove example
  * change requirements of flash_attn
  * modify tests
  * [skip ci] change requirements-test
* [gptq] faster gptq cuda kernel (#4494)
  * [skip ci] add cuda kernels
  * add license
  * [skip ci] fix max_input_len
  * format files & change test size
  * [skip ci]
* [gptq] add gptq tensor parallel (#4538)
  * add gptq tensor parallel
  * add gptq tp
  * delete print
  * add test gptq check
  * add test auto gptq check
* [gptq] combine gptq and kv cache manager (#4706)
  * combine gptq and kv cache manager
  * add init bits
  * delete useless code
  * add model path
  * delete useless print and update test
  * delete useless import
  * move option gptq to shard config
  * change replace linear to shardformer
  * update bloom policy
  * delete useless code
  * fix import bug and delete useless code
  * change colossalai/gptq to colossalai/quant/gptq
  * update import linear for tests
  * delete useless code and mv gptq_kernel to kernel directory
  * fix triton kernel
  * add triton import
parent 1e0e080837
commit 946ab56c48

LICENSE  +49
@@ -428,3 +428,52 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

    ---------------- LICENSE FOR AutoGPTQ ----------------

From AutoGPTQ:

MIT License

Copyright (c) 2023 潘其威(William)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

    ---------------- LICENSE FOR exllama ----------------

From exllama:

MIT License

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,4 @@
from .cai_gptq import HAS_AUTO_GPTQ

if HAS_AUTO_GPTQ:
    from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear
@@ -0,0 +1,13 @@
import warnings

HAS_AUTO_GPTQ = False
try:
    import auto_gptq
    HAS_AUTO_GPTQ = True
except ImportError:
    warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ')
    HAS_AUTO_GPTQ = False

if HAS_AUTO_GPTQ:
    from .cai_quant_linear import CaiQuantLinear, ColCaiQuantLinear, RowCaiQuantLinear
    from .gptq_op import CaiGPTQLinearOp
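The package gates everything behind an optional auto-gptq import and exposes the HAS_AUTO_GPTQ flag for callers. A minimal sketch of how downstream code might guard on that flag before touching the GPTQ classes; the helper and its error message are illustrative, not part of this commit:

    from colossalai.inference.quant.gptq import HAS_AUTO_GPTQ

    def require_auto_gptq():
        # Fail fast with a clear message when the optional dependency is missing,
        # instead of hitting an ImportError deep inside model sharding.
        if not HAS_AUTO_GPTQ:
            raise ImportError("GPTQ inference needs auto-gptq: "
                              "https://github.com/PanQiWei/AutoGPTQ")

    require_auto_gptq()
    from colossalai.inference.quant.gptq import CaiGPTQLinearOp, CaiQuantLinear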
@@ -0,0 +1,354 @@
# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ

import math
import warnings
from typing import List, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import ParallelModule

from .gptq_op import CaiGPTQLinearOp

HAS_GPTQ_CUDA = False
try:
    from colossalai.kernel.op_builder.gptq import GPTQBuilder
    gptq_cuda = GPTQBuilder().load()
    HAS_GPTQ_CUDA = True
except ImportError:
    warnings.warn('CUDA gptq is not installed')
    HAS_GPTQ_CUDA = False


class CaiQuantLinear(nn.Module):

    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
        super().__init__()
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize if groupsize != -1 else infeatures

        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
        self.register_buffer(
            'qzeros',
            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
        self.register_buffer('scales',
                             torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
        if row_split:
            self.register_buffer(
                'g_idx',
                torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)],
                             dtype=torch.int32))
        else:
            self.register_buffer('g_idx',
                                 torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))

        if bias:
            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
        else:
            self.bias = None

        self.gptq_linear = CaiGPTQLinearOp(groupsize, bits)

        self.q4 = None
        self.empty_tensor = torch.empty((1, 1), device="meta")
        self.tp_size = tp_size
        self.tp_rank = tp_rank
        self.row_split = row_split

    def pack(self, linear, scales, zeros, g_idx=None):

        g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
            [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        half_scales = scales.clone().half()
        # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape)
        self.scales = scales.clone().half()
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        wn = 8
        pbits = 32
        ptype = torch.int32
        unsign_type = np.uint32
        sign_type = np.int32

        intweight = []
        for idx in range(self.infeatures):
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, None])
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(unsign_type)
        qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type)

        i = 0
        row = 0

        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (pbits // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += pbits // self.bits
                row += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")
        qweight = qweight.astype(sign_type)
        qweight1 = torch.from_numpy(qweight)
        qweight1 = qweight1.contiguous()    #.to("cuda")
        self.qweight.data.copy_(qweight1)

        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
        zeros -= 1
        zeros = zeros.numpy().astype(unsign_type)
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (pbits // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += pbits // self.bits
                col += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")
        qzeros = qzeros.astype(sign_type)
        qzeros = torch.from_numpy(qzeros)
        qzeros = qzeros
        self.qzeros.data.copy_(qzeros)

        if torch.equal(self.g_idx.to(g_idx.device), g_idx):
            self.g_idx = None
        else:
            self.g_idx = g_idx

    def init_q4(self):
        assert self.qweight.device.type == "cuda"
        self.q4_width = self.qweight.shape[1]
        if self.g_idx is not None:
            if self.row_split and torch.equal(
                    self.g_idx,
                    torch.tensor(
                        [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
                        dtype=torch.int32,
                        device=self.g_idx.device)):
                self.g_idx = None
            elif torch.equal(
                    self.g_idx,
                    torch.tensor([i // self.groupsize for i in range(self.infeatures)],
                                 dtype=torch.int32,
                                 device=self.g_idx.device)):
                self.g_idx = None

        if self.g_idx is not None:
            g_idx = self.g_idx.to("cpu")
        else:
            g_idx = self.empty_tensor

        self.q4 = gptq_cuda.make_q4(self.qweight, self.qzeros, self.scales, g_idx, torch.cuda.current_device())
        torch.cuda.synchronize()

    def forward(self, x):
        outshape = x.shape[:-1] + (self.outfeatures,)

        if HAS_GPTQ_CUDA and self.bits == 4:

            if self.q4 is None:
                self.init_q4()

            x = x.view(-1, x.shape[-1])
            output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device)
            gptq_cuda.q4_matmul(x.half(), self.q4, output)
            if self.bias is not None and (not self.row_split or self.tp_size == 1):
                output.add_(self.bias)
        else:
            if self.bias is not None and (not self.row_split or self.tp_size == 1):
                bias = self.bias
            else:
                bias = None
            output = self.gptq_linear(
                x,
                self.qweight,
                self.scales,
                self.qzeros,
                g_idx=self.g_idx,
                bias=bias,
            )
        return output.view(outshape)


def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):

    qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
    qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
    scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
    g_idx = gptq_linear.g_idx
    if gptq_linear.bias is not None:
        bias = gptq_linear.bias.split(gptq_linear.out_features // split_num, dim=-1)

    cai_split_out_features = cai_linear.outfeatures // split_num
    zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num

    for i in range(split_num):
        cai_linear.qweight[:, i * cai_split_out_features:(i + 1) *
                           cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
                                                                 cai_split_out_features]
        cai_linear.qzeros[:, i * zero_split_block:(i + 1) *
                          zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block]
        cai_linear.scales[:, i * cai_split_out_features:(i + 1) *
                          cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
                                                              cai_split_out_features]
        if cai_linear.bias is not None:
            cai_linear.bias[i * cai_split_out_features:(i + 1) *
                            cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) *
                                                              cai_split_out_features]

    cai_linear.g_idx.copy_(g_idx)


def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):

    qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
    qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
    scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
    g_idxs = gptq_linear.g_idx.split(gptq_linear.in_features // split_num, dim=0)

    cai_split_in_features = cai_linear.infeatures // (32 // cai_linear.bits) // split_num
    zero_split_block = cai_linear.infeatures // cai_linear.groupsize // split_num
    idx_split_features = cai_linear.infeatures // split_num

    for i in range(split_num):
        cai_linear.qweight[i * cai_split_in_features:(i + 1) *
                           cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) *
                                                                   cai_split_in_features, :]
        cai_linear.qzeros[i * zero_split_block:(i + 1) *
                          zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) *
                                                           zero_split_block, :]
        cai_linear.scales[i * zero_split_block:(i + 1) *
                          zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) *
                                                           zero_split_block, :]
        cai_linear.g_idx[i * idx_split_features:(i + 1) *
                         idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) *
                                                         idx_split_features]
    if cai_linear.bias is not None:
        cai_linear.bias.copy_(gptq_linear.bias)


class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):

    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):

        super().__init__(bits,
                         groupsize,
                         infeatures,
                         outfeatures,
                         bias,
                         tp_size=tp_size,
                         tp_rank=tp_rank,
                         row_split=row_split)
        self.process_group = None

    @staticmethod
    def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
                           **kwargs) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, \
                f'Expected only one process group, got {len(process_group)}.'
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        if in_features < tp_size:
            return module

        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
        linear_1d = RowCaiQuantLinear(module.bits,
                                      module.group_size,
                                      module.in_features // tp_size,
                                      module.out_features,
                                      module.bias is not None,
                                      tp_size=tp_size,
                                      tp_rank=tp_rank,
                                      row_split=True)
        linear_1d.process_group = process_group

        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        return linear_1d

    def forward(self, x):
        output = super().forward(x)
        if self.tp_size > 1:
            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
            if self.bias is not None:
                output.add_(self.bias)
        return output


class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):

    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):

        super().__init__(bits,
                         groupsize,
                         infeatures,
                         outfeatures,
                         bias,
                         tp_size=tp_size,
                         tp_rank=tp_rank,
                         row_split=row_split)
        self.process_group = None

    @staticmethod
    def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
                           **kwargs) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, \
                f'Expected only one process group, got {len(process_group)}.'
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        if in_features < tp_size:
            return module

        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
        linear_1d = ColCaiQuantLinear(module.bits,
                                      module.group_size,
                                      module.in_features,
                                      module.out_features // tp_size,
                                      module.bias is not None,
                                      tp_size=tp_size,
                                      tp_rank=tp_rank)
        linear_1d.process_group = process_group

        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        return linear_1d
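For reference, a minimal sketch of turning a dense fp16 layer into a CaiQuantLinear via the pack() path above. The scale and zero-point tensors here are fabricated placeholders chosen only to show the expected shapes ([out_features, in_features // group_size]); they are not the result of a real GPTQ calibration:

    import torch
    import torch.nn as nn
    from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear

    bits, group_size, in_f, out_f = 4, 128, 4096, 4096
    dense = nn.Linear(in_f, out_f, bias=True)

    # Placeholder quantization parameters: one (scale, zero point) per group and output channel.
    n_groups = in_f // group_size
    scales = torch.ones(out_f, n_groups)        # per-group scales
    zeros = torch.full((out_f, n_groups), 8)    # per-group integer zero points

    qlinear = CaiQuantLinear(bits, group_size, in_f, out_f, bias=True)
    qlinear.pack(dense, scales, zeros)          # packs the rounded int4 weights into int32 buffers
    qlinear = qlinear.cuda()

    x = torch.randn(1, 16, in_f, dtype=torch.float16, device="cuda")
    y = qlinear(x)                              # exllama CUDA path when available, Triton op otherwise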
@@ -0,0 +1,58 @@
import torch

from colossalai.kernel.triton import gptq_fused_linear_triton


class CaiGPTQLinearOp(torch.nn.Module):

    def __init__(self, gptq_group_size, gptq_quant_bits):
        super(CaiGPTQLinearOp, self).__init__()
        self.group_size = gptq_group_size
        self.bits = gptq_quant_bits
        self.maxq = 2**self.bits - 1
        self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())

    def forward(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scales: torch.Tensor,
        weight_zeros: torch.Tensor,
        g_idx: torch.Tensor = None,
        act_type=0,
        bias: torch.Tensor = None,
        residual: torch.Tensor = None,
        qkv_fused=False,
    ):
        add_bias = True
        if bias is None:
            bias = self.empty_tensor
            add_bias = False

        add_residual = True
        if residual is None:
            residual = self.empty_tensor
            add_residual = False
        x = input.view(-1, input.shape[-1])

        out = gptq_fused_linear_triton(
            x,
            weight,
            weight_scales,
            weight_zeros,
            bias,
            residual,
            self.bits,
            self.maxq,
            self.group_size,
            qkv_fused,
            add_bias,
            add_residual,
            act_type=act_type,
            g_idx=g_idx,
        )
        if qkv_fused:
            out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
        else:
            out = out.view(input.shape[0], input.shape[1], weight.shape[-1])

        return out
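Both CaiQuantLinear.pack() above and the CUDA/Triton kernels rely on the same storage layout: each 32-bit integer holds 32 // bits quantized values, packed starting from the low-order bits. A small standalone illustration of that packing and unpacking arithmetic (plain Python, values chosen arbitrarily):

    bits = 4
    vals = [1, 2, 3, 4, 5, 6, 7, 8]          # 32 // bits = 8 values fit in one int32 word

    packed = 0
    for j, v in enumerate(vals):
        packed |= v << (bits * j)             # same shift pattern as pack(): bits * (j - i)

    unpacked = [(packed >> (bits * j)) & ((1 << bits) - 1) for j in range(32 // bits)]
    assert unpacked == vals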
@@ -1,6 +1,7 @@
from typing import Any, Callable, List, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from transformers import BloomForCausalLM, LlamaForCausalLM
from transformers.generation import GenerationConfig

@@ -68,6 +69,13 @@ class TPInferEngine:
        self.tp_size = -1    # to be set with given shard config in self.prepare_shard_config
        self.cache_manager = None

        self.max_dq_buffer_size = 1
        self.max_inner_outer_dim = 1
        self.gptq_temp_state_buffer = None
        self.gptq_temp_dq_buffer = None
        self.bits = -1
        self.use_act_order = False

        self.shard_config = shard_config
        self.model = None
        # optimize the original model by sharding with ShardFormer

@@ -81,6 +89,50 @@ class TPInferEngine:
            self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
        )

    def _post_init_gptq_buffer(self, model: nn.Module) -> None:
        from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear
        HAS_GPTQ_CUDA = False
        try:
            from colossalai.kernel.op_builder.gptq import GPTQBuilder
            gptq_cuda = GPTQBuilder().load()
            HAS_GPTQ_CUDA = True
        except ImportError:
            warnings.warn('CUDA gptq is not installed')
            HAS_GPTQ_CUDA = False

        for name, submodule in model.named_modules():
            if isinstance(submodule, CaiQuantLinear):
                self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)

                if self.use_act_order:
                    self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures,
                                                   submodule.outfeatures)
                self.bits = submodule.bits
        if not (HAS_GPTQ_CUDA and self.bits == 4):
            return

        max_input_len = 1
        if self.use_act_order:
            max_input_len = self.max_input_len
        # The temp_state buffer is required to reorder X in the act-order case.
        # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
        self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim),
                                                  dtype=torch.float16,
                                                  device=torch.cuda.current_device())
        self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size),
                                               dtype=torch.float16,
                                               device=torch.cuda.current_device())

        gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer,
                                  self.gptq_temp_dq_buffer)
        # Using the default from exllama repo here.
        matmul_recons_thd = 8
        matmul_fused_remap = False
        matmul_no_half2 = False
        gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

        torch.cuda.empty_cache()

    def _optimize_model(self, model: nn.Module) -> None:
        """
        Optimize the original model by sharding with ShardFormer.

@@ -129,6 +181,10 @@ class TPInferEngine:
        assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
        policy = get_autopolicy(model, inference_only=True)
        self.model, _ = shardformer.optimize(model, policy)

        if self.shard_config.inference_gptq:
            self._post_init_gptq_buffer(model)

        self.model = self.model.cuda()

    @property
@@ -3,6 +3,9 @@ from functools import partial
import torch
from torch.nn import LayerNorm

import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy

from ..modeling.bloom import BloomInferenceForwards

@@ -35,6 +38,35 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy):
        from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel

        policy = super().module_policy()
        if self.shard_config.inference_gptq:
            from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
            policy[BloomBlock] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attention.query_key_value",
                        target_module=ColCaiQuantLinear,
                        kwargs={'split_num': 3}),
                    SubModuleReplacementDescription(
                        suffix="self_attention.dense",
                        target_module=RowCaiQuantLinear,
                        kwargs={'split_num': 1}),
                    SubModuleReplacementDescription(
                        suffix="self_attention.attention_dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.dense_h_to_4h",
                        target_module=ColCaiQuantLinear,
                        kwargs={'split_num': 1}),
                    SubModuleReplacementDescription(
                        suffix="mlp.dense_4h_to_h",
                        target_module=RowCaiQuantLinear,
                        kwargs={'split_num': 1}),
                ])
        # NOTE set inference mode to shard config
        self.shard_config._infer()
@@ -3,6 +3,8 @@ from functools import partial
import torch
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

from colossalai.shardformer.layer import VocabParallelEmbedding1D
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

@@ -34,6 +36,55 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):

    def module_policy(self):
        policy = super().module_policy()

        if self.shard_config.inference_gptq:
            from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear

            decoder_attribute_replacement = {
                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
            }
            policy[LlamaDecoderLayer] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=ColCaiQuantLinear,
                        kwargs={'split_num': 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=ColCaiQuantLinear,
                        kwargs={'split_num': 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=ColCaiQuantLinear,
                        kwargs={'split_num': 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=RowCaiQuantLinear,
                        kwargs={'split_num': 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.gate_proj",
                        target_module=ColCaiQuantLinear,
                        kwargs={'split_num': 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.up_proj",
                        target_module=ColCaiQuantLinear,
                        kwargs={'split_num': 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.down_proj",
                        target_module=RowCaiQuantLinear,
                        kwargs={'split_num': 1},
                    )
                ],
            )

            self.shard_config._infer()

        infer_forward = LlamaInferenceForwards.llama_model_forward
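Putting the pieces together, the GPTQ path is driven by the new shard-config switch: the inference policies above swap each projection for a Col/RowCaiQuantLinear, and TPInferEngine._post_init_gptq_buffer then allocates the exllama temp buffers. A rough usage sketch follows; the import path, the ShardConfig fields other than inference_gptq, and the TPInferEngine argument list are assumptions about the surrounding API, not part of this diff:

    from auto_gptq import AutoGPTQForCausalLM

    from colossalai.inference import TPInferEngine          # assumed import path
    from colossalai.shardformer import ShardConfig

    max_batch_size, max_input_len, max_output_len = 4, 1024, 128

    # Load a 4-bit GPTQ checkpoint with auto-gptq (the path is a placeholder).
    model = AutoGPTQForCausalLM.from_quantized("path/to/llama-7b-gptq", device="cuda:0",
                                               inject_fused_attention=False)

    # inference_gptq is the switch added in this commit; the other fields are assumed.
    shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True,
                               inference_gptq=True)
    engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
    # engine.generate(...) can then be used as with the non-quantized engine.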
@@ -0,0 +1,63 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#include "column_remap.cuh"
#include "util.cuh"

const int SHUF_BLOCKSIZE_X = 256;
const int SHUF_BLOCKSIZE_Y = 16;

__global__ void column_remap_kernel
(
    const half* __restrict__ x,
    half* __restrict__ x_new,
    const int x_width,
    const int x_height,
    const uint32_t* x_map
)
{
    int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
    int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
    if (x_column >= x_width) return;
    //if (x_row >= x_height) return;

    int x_stride = x_width;
    int x_idx = x_row * x_stride + x_column;

    int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
    int x_idx_end = x_row_end * x_stride + x_column;

    int s_column = x_map[x_column];
    int s_idx = x_row * x_stride + s_column;

    while (x_idx < x_idx_end)
    {
        x_new[x_idx] = x[s_idx];
        x_idx += x_stride;
        s_idx += x_stride;
    }
}

// Remap columns in x to correspond to sequential group index before matmul
//
// perform x -> seq_x such that seq_x @ seq_w == x @ w

void column_remap_cuda
(
    const half* x,
    half* x_new,
    const int x_height,
    const int x_width,
    const uint32_t* x_map
)
{
    dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);

    dim3 blocks
    (
        (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
        (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
        1
    );

    column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
}
@@ -0,0 +1,19 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _column_remap_cuh
#define _column_remap_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

void column_remap_cuda
(
    const half* x,
    half* x_new,
    const int x_height,
    const int x_width,
    const uint32_t* x_map
);

#endif
@@ -0,0 +1,58 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _cuda_compat_cuh
#define _cuda_compat_cuh

// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
    unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    do
    {
        assumed = old;
        __half_raw hsum;
        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
        half tmpres = __hadd(hsum, val);
        hsum = __half_raw(tmpres);
        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
        old = atomicCAS(address_as_ui, assumed, old);
    }
    while (assumed != old);
}

// atomicAdd for half2 types

__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
    unsigned int* address_as_ui = (unsigned int*)address;
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    do
    {
        assumed = old;
        half2 old_val = *((half2*)&old);
        half2 new_val = __hadd2(old_val, val);
        old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
    }
    while (assumed != old);
}

//

#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif

#endif
#endif

#endif
@@ -0,0 +1,75 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#define _cuda_buffers_cu
#include "cuda_buffers.cuh"

CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
// __constant__ half2 q4_table[16][256];
// half2 q4_table_host[16][256];
// bool q4_table_init = false;

CudaBuffers::CudaBuffers
(
    int _device,
    int _temp_state_size,
    half* _temp_state,
    half* _temp_dq
) :
    device(_device),
    temp_state_size(_temp_state_size),
    temp_state(_temp_state),
    temp_dq(_temp_dq)
{
    cudaSetDevice(_device);

    cudaStreamCreate(&alt_stream_1);
    cudaStreamCreate(&alt_stream_2);
    cudaStreamCreate(&alt_stream_3);
    cudaEventCreate(&alt_stream_1_done);
    cudaEventCreate(&alt_stream_2_done);
    cudaEventCreate(&alt_stream_3_done);
}

CudaBuffers::~CudaBuffers()
{
    cudaStreamDestroy(alt_stream_1);
    cudaStreamDestroy(alt_stream_2);
    cudaStreamDestroy(alt_stream_3);
    cudaEventDestroy(alt_stream_1_done);
    cudaEventDestroy(alt_stream_2_done);
    cudaEventDestroy(alt_stream_3_done);
}

CudaBuffers* get_buffers(const int device_index)
{
    return g_buffers[device_index];
}

void prepare_buffers_cuda
(
    int _device,
    int _temp_state_size,
    half* _temp_state,
    half* _temp_dq
)
{
    CudaBuffers* buffers = new CudaBuffers
    (
        _device,
        _temp_state_size,
        _temp_state,
        _temp_dq
    );

    g_buffers[_device] = buffers;
}

void cleanup_buffers_cuda()
{
    for (int i = 0; i < CUDA_MAX_DEVICES; i++)
    {
        if (!g_buffers[i]) continue;
        delete g_buffers[i];
        g_buffers[i] = NULL;
    }
}
@@ -0,0 +1,55 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _cuda_buffers_cuh
#define _cuda_buffers_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

const int CUDA_MAX_DEVICES = 16;

// #ifndef _cuda_buffers_cu
// extern __constant__ half2 q4_table[16][256];
// #endif

class CudaBuffers
{
public:
    int device;

    half* temp_state;        // [max_hidden_rows * intermediate_size]
    int temp_state_size;
    half* temp_dq;           // size of largest quant tensor * 8

    cudaStream_t alt_stream_1;
    cudaStream_t alt_stream_2;
    cudaStream_t alt_stream_3;
    cudaEvent_t alt_stream_1_done;
    cudaEvent_t alt_stream_2_done;
    cudaEvent_t alt_stream_3_done;

    CudaBuffers
    (
        int _device,
        int _temp_state_size,
        half* _temp_state,
        half* _temp_dq
    );
    ~CudaBuffers();
};

CudaBuffers* get_buffers(const int device_index);

void prepare_buffers_cuda
(
    int _device,
    int _temp_state_size,
    half* _temp_state,
    half* _temp_dq
);

void cleanup_buffers_cuda();

#endif
@@ -0,0 +1,49 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _hip_compat_cuh
#define _hip_compat_cuh

// Workaround for a bug in hipamd, backported from upstream.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
    return __half_raw{
        static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}

__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
    return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
                      static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
}

#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp

// Workaround for hipify_python using rocblas instead of hipblas.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
                                                               hipblasOperation_t transA,
                                                               hipblasOperation_t transB,
                                                               int m,
                                                               int n,
                                                               int k,
                                                               const half* alpha,
                                                               const half* AP,
                                                               int lda,
                                                               const half* BP,
                                                               int ldb,
                                                               const half* beta,
                                                               half* CP,
                                                               int ldc) {
    return hipblasHgemm(handle, transA, transB, m, n, k,
                        reinterpret_cast<const hipblasHalf *>(alpha),
                        reinterpret_cast<const hipblasHalf *>(AP), lda,
                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
                        reinterpret_cast<const hipblasHalf *>(beta),
                        reinterpret_cast<hipblasHalf *>(CP), ldc);
}

#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm

#endif
@ -0,0 +1,254 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include "util.cuh"
|
||||
#include "tuning.h"
|
||||
#include "cuda_buffers.cuh"
|
||||
#include "q4_matrix.cuh"
|
||||
#include "q4_matmul.cuh"
|
||||
#include "column_remap.cuh"
|
||||
|
||||
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
|
||||
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
|
||||
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
|
||||
|
||||
void check_cuda(cudaError_t ret)
|
||||
{
|
||||
switch (ret)
|
||||
{
|
||||
case cudaSuccess:
|
||||
break;
|
||||
|
||||
case cudaUnspecified:
|
||||
printf(" **** Unspecified error\n");
|
||||
TORCH_CHECK(false, "CUDA error");
|
||||
break;
|
||||
|
||||
default:
|
||||
printf(" **** CUDA error\n"); \
|
||||
printf(" **** %s\n", cudaGetErrorString(ret)); \
|
||||
TORCH_CHECK(false, "CUDA error"); \
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Some decluttering macros
|
||||
|
||||
#define STRINGIFY_(__x) #__x
|
||||
#define STRINGIFY(__x) STRINGIFY_(__x)
|
||||
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
|
||||
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
|
||||
|
||||
#define TORCH_CHECK_DEVICE_INDEX(__index) \
|
||||
do { \
|
||||
TORCH_CHECK(__index >= 0, "no device index"); \
|
||||
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
|
||||
} while(0)
|
||||
|
||||
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
|
||||
do { \
|
||||
TORCH_CHECK_DTYPE(__w, kInt); \
|
||||
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
|
||||
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
|
||||
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
|
||||
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
|
||||
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
|
||||
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
|
||||
} while(0)
|
||||
|
||||
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
|
||||
{
|
||||
int groupsize = w.size(0) * 8 / w_zeros.size(0);
|
||||
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
|
||||
return groupsize;
|
||||
}
|
||||
|
||||
|
||||
// Tuning parameters
|
||||
|
||||
ExLlamaTuning tuningParams;
|
||||
|
||||
void set_tuning_params
|
||||
(
|
||||
int matmul_recons_thd,
|
||||
bool matmul_fused_remap,
|
||||
bool matmul_no_half2
|
||||
)
|
||||
{
|
||||
tuningParams.matmul_recons_thd = matmul_recons_thd;
|
||||
tuningParams.matmul_fused_remap = matmul_fused_remap;
|
||||
tuningParams.matmul_no_half2 = matmul_no_half2;
|
||||
}
|
||||
|
||||
|
||||
// Release all unmanaged objects allocated by the extension
|
||||
|
||||
void cleanup()
|
||||
{
|
||||
cleanup_buffers_cuda();
|
||||
g_q4_free_matrices();
|
||||
}
|
||||
|
||||
|
||||
// Prepare buffers for forward pass
|
||||
|
||||
void prepare_buffers
|
||||
(
|
||||
torch::Device device,
|
||||
torch::Tensor temp_state,
|
||||
torch::Tensor temp_dq
|
||||
)
|
||||
{
|
||||
int device_index = device.index();
|
||||
TORCH_CHECK_DEVICE_INDEX(device_index);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device);
|
||||
|
||||
prepare_buffers_cuda
|
||||
(
|
||||
device_index,
|
||||
// buffer size used for sanity checks
|
||||
temp_state.numel(),
|
||||
(half*) temp_state.data_ptr(),
|
||||
(half*) temp_dq.data_ptr()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
// Create Q4Matrix, return handle
|
||||
|
||||
uintptr_t make_q4
|
||||
(
|
||||
torch::Tensor qweight,
|
||||
torch::Tensor qzeros,
|
||||
torch::Tensor scales,
|
||||
torch::Tensor g_idx,
|
||||
int device
|
||||
)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(qweight, kInt);
|
||||
TORCH_CHECK_DTYPE(qzeros, kInt);
|
||||
TORCH_CHECK_DTYPE(scales, kHalf);
|
||||
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
|
||||
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
|
||||
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
|
||||
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
|
||||
|
||||
int width = qweight.size(1);
|
||||
int height = qweight.size(0) * 8;
|
||||
int groups = qzeros.size(0);
|
||||
|
||||
Q4Matrix* m = new Q4Matrix
|
||||
(
|
||||
height,
|
||||
width,
|
||||
groups,
|
||||
|
||||
(uint32_t*) qweight.data_ptr(),
|
||||
(uint32_t*) qzeros.data_ptr(),
|
||||
(half*) scales.data_ptr(),
|
||||
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
|
||||
|
||||
device
|
||||
);
|
||||
|
||||
g_q4_keep_matrix(m);
|
||||
return reinterpret_cast<uintptr_t> (m);
|
||||
}
|
||||
|
||||
|
||||
// Matmul half @ quant -> half
|
||||
|
||||
void q4_matmul
|
||||
(
|
||||
torch::Tensor x,
|
||||
uintptr_t w,
|
||||
torch::Tensor out
|
||||
)
|
||||
{
|
||||
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
|
||||
|
||||
TORCH_CHECK_DTYPE(x, kHalf);
|
||||
TORCH_CHECK_DTYPE(out, kHalf);
|
||||
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
|
||||
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
int x_height = x.size(0);
|
||||
|
||||
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
|
||||
{
|
||||
q4_matmul_cuda
|
||||
(
|
||||
&tuningParams,
|
||||
(half*) x.data_ptr(),
|
||||
x_height,
|
||||
wm,
|
||||
(half*) out.data_ptr()
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
q4_matmul_recons_cuda
|
||||
(
|
||||
&tuningParams,
|
||||
(half*) x.data_ptr(),
|
||||
x_height,
|
||||
wm,
|
||||
(half*) out.data_ptr(),
|
||||
at::cuda::getCurrentCUDABlasHandle()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Remap columns in half tensor
|
||||
|
||||
void column_remap
|
||||
(
|
||||
torch::Tensor x,
|
||||
torch::Tensor x_new,
|
||||
torch::Tensor x_map
|
||||
)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(x, kHalf);
|
||||
TORCH_CHECK_DTYPE(x_new, kHalf);
|
||||
TORCH_CHECK_DTYPE(x_map, kInt);
|
||||
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
|
||||
|
||||
int height = x.size(0);
|
||||
int width = x.size(1);
|
||||
|
||||
TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
column_remap_cuda
|
||||
(
|
||||
(half*) x.data_ptr(),
|
||||
(half*) x_new.data_ptr(),
|
||||
height,
|
||||
width,
|
||||
(uint32_t*) x_map.data_ptr()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
|
||||
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
|
||||
m.def("cleanup", &cleanup, "cleanup");
|
||||
m.def("make_q4", &make_q4, "make_q4");
|
||||
m.def("q4_matmul", &q4_matmul, "q4_matmul");
|
||||
}
|
|
@ -0,0 +1,294 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _matrix_cuh
|
||||
#define _matrix_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
class MatrixView_half
|
||||
{
|
||||
public:
|
||||
const half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
|
||||
};
|
||||
|
||||
class MatrixView_half_rw
|
||||
{
|
||||
public:
|
||||
half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
|
||||
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
|
||||
};
|
||||
|
||||
class MatrixView_q4_row
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_column
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (row & 0x07) * 4;
|
||||
return (data[row / 8 * width + column] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
|
||||
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
|
||||
};
|
||||
|
||||
// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
|
||||
|
||||
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
|
||||
|
||||
__device__ __forceinline__ half2 dot_product_8
|
||||
(
|
||||
const half2 acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half2 v_scale_2,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count
|
||||
)
|
||||
{
|
||||
const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half2 result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half2 v_01 = __halves2half2(v_0, v_1);
|
||||
half2 v_23 = __halves2half2(v_2, v_3);
|
||||
half2 v_45 = __halves2half2(v_4, v_5);
|
||||
half2 v_67 = __halves2half2(v_6, v_7);
|
||||
|
||||
// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
|
||||
// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
|
||||
// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
|
||||
// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
|
||||
|
||||
half2 tmp = __hmul2(*h_ptr++, v_01);
|
||||
tmp = __hfma2(*h_ptr++, v_23, tmp);
|
||||
tmp = __hfma2(*h_ptr++, v_45, tmp);
|
||||
tmp = __hfma2(*h_ptr++, v_67, tmp);
|
||||
result = __hfma2(v_scale_2, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ half dot_product_8_h
|
||||
(
|
||||
const half acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half v_scale,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, h_column);
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half tmp = __hmul(*h_ptr++, v_0);
|
||||
tmp = __hfma(*h_ptr++, v_1, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_2, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_3, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_4, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_5, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_6, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_7, tmp);
|
||||
result = __hfma(v_scale, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
|
||||
|
||||
__device__ __forceinline__ half2 dot_product_8_x_map
|
||||
(
|
||||
const half2 acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half2 v_scale_2,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, 0);
|
||||
const uint32_t* x_map_ptr = x_map + h_column;
|
||||
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half2 result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half2 v_01 = __halves2half2(v_0, v_1);
        half2 v_23 = __halves2half2(v_2, v_3);
        half2 v_45 = __halves2half2(v_4, v_5);
        half2 v_67 = __halves2half2(v_6, v_7);

        half h_0 = h_ptr[*x_map_ptr++];
        half h_1 = h_ptr[*x_map_ptr++];
        half h_2 = h_ptr[*x_map_ptr++];
        half h_3 = h_ptr[*x_map_ptr++];
        half h_4 = h_ptr[*x_map_ptr++];
        half h_5 = h_ptr[*x_map_ptr++];
        half h_6 = h_ptr[*x_map_ptr++];
        half h_7 = h_ptr[*x_map_ptr++];

        half2 h_01 = __halves2half2(h_0, h_1);
        half2 h_23 = __halves2half2(h_2, h_3);
        half2 h_45 = __halves2half2(h_4, h_5);
        half2 h_67 = __halves2half2(h_6, h_7);

        half2 tmp = __hmul2(h_01, v_01);
        tmp = __hfma2(h_23, v_23, tmp);
        tmp = __hfma2(h_45, v_45, tmp);
        tmp = __hfma2(h_67, v_67, tmp);
        result = __hfma2(v_scale_2, tmp, result);
    }

    return result;
}

__device__ __forceinline__ half dot_product_8_x_map_h
(
    const half acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,        // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,           // divisible by 8
    const int v_column,
    const half v_scale,
    const uint32_t v_zero,     // + 1 (!!)
    const int count,
    const uint32_t* x_map
)
{
    const half* h_ptr = h_.item_ptr(h_row, 0);
    const uint32_t* x_map_ptr = x_map + h_column;
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
        result = __hfma(v_scale, tmp, result);
    }

    return result;
}

#endif
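The dot_product_8* helpers above all decode eight 4-bit weights from each 32-bit word of the quantized matrix before the fused multiply-adds. For reference, the same unpacking in plain Python (an illustrative sketch with made-up values, not part of the kernel sources):

import struct  # not required; shown only to emphasise v_read is a raw 32-bit word

def dequant_word(v_read: int, v_zero: int, v_scale: float) -> list:
    # Mirror the shift/mask ladder in dot_product_8*: nibble i sits at bits [4*i, 4*i + 4).
    # v_zero already carries the "+ 1 (!!)" offset noted in the parameter lists.
    return [(((v_read >> (4 * i)) & 0x0F) - v_zero) * v_scale for i in range(8)]

packed = sum(i << (4 * i) for i in range(8))        # nibbles 0..7 packed into one 32-bit word
print(dequant_word(packed, v_zero=1, v_scale=0.5))  # -> [-0.5, 0.0, 0.5, ..., 3.0]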
@ -0,0 +1,260 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include "q4_matmul.cuh"
|
||||
#include "column_remap.cuh"
|
||||
#include "util.cuh"
|
||||
#include "matrix.cuh"
|
||||
#include "cu_compat.cuh"
|
||||
#include "cuda_buffers.cuh"
|
||||
#if defined(USE_ROCM)
|
||||
#include "hip_compat.cuh"
|
||||
#endif
|
||||
|
||||
const int THREADS_X = 32; // Block size and thread count along columns in w and out
|
||||
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
|
||||
|
||||
typedef void (*fp_q4_matmul_kernel)
|
||||
(
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
half*,
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const uint32_t*,
|
||||
bool
|
||||
);
|
||||
|
||||
template<bool use_half2, bool use_groupsize, bool use_x_map>
|
||||
__global__ void q4_matmul_kernel
|
||||
(
|
||||
const half* __restrict__ x,
|
||||
const uint32_t* __restrict__ w,
|
||||
half* __restrict__ out,
|
||||
const half* __restrict__ w_scales,
|
||||
const uint32_t* __restrict__ w_zeros,
|
||||
const int height,
|
||||
const int dim,
|
||||
const int width,
|
||||
const int groupsize,
|
||||
const int block_size_z,
|
||||
const uint32_t* __restrict__ x_map,
|
||||
bool no_zero
|
||||
)
|
||||
{
|
||||
// Start of block
|
||||
|
||||
int x_column = block_size_z * blockIdx.z;
|
||||
int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
|
||||
|
||||
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
|
||||
int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
|
||||
|
||||
int iterations = (x_column_end - x_column) / 8;
|
||||
|
||||
// Views
|
||||
|
||||
MatrixView_half x_(x, height, dim);
|
||||
MatrixView_half w_scales_(w_scales, dim / groupsize, width);
|
||||
MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
|
||||
MatrixView_q4_column w_(w, dim, width);
|
||||
MatrixView_half_rw out_(out, height, width);
|
||||
|
||||
// Zero output
|
||||
|
||||
if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
|
||||
{
|
||||
*((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Loop over part of x row (and w column)
|
||||
|
||||
half2 acc = {};
|
||||
half acc_h = {};
|
||||
|
||||
if constexpr (use_groupsize)
|
||||
{
|
||||
// For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
|
||||
// could be slightly faster
|
||||
|
||||
for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
|
||||
{
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
}
|
||||
else
|
||||
{
|
||||
half w_scale = w_scales_.item(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
|
||||
|
||||
for (int k = x_column; k < x_column + iterations * 8; k += 8)
|
||||
{
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
int group = k / groupsize;
|
||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
int group = k / groupsize;
|
||||
half w_scale = w_scales_.item(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add to block result
|
||||
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
half result = __hadd(__low2half(acc), __high2half(acc));
|
||||
atomicAdd(out_.item_ptr(x_row, w_column), result);
|
||||
}
|
||||
else
|
||||
{
|
||||
atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
|
||||
}
|
||||
}
|
||||
|
||||
fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
|
||||
{
|
||||
// <bool use_half2, bool use_groupsize, bool use_x_map>
|
||||
if (tuningParams->matmul_no_half2) {
|
||||
if (block_size_z % groupsize == 0) {
|
||||
if (x_map) return q4_matmul_kernel<false, true, true >;
|
||||
else return q4_matmul_kernel<false, true, false>;
|
||||
} else {
|
||||
if (x_map) return q4_matmul_kernel<false, false, true >;
|
||||
else return q4_matmul_kernel<false, false, false>;
|
||||
}
|
||||
} else {
|
||||
if (block_size_z % groupsize == 0)
|
||||
{
|
||||
if (x_map) return q4_matmul_kernel<true, true, true >;
|
||||
else return q4_matmul_kernel<true, true, false>;
|
||||
} else {
|
||||
if (x_map) return q4_matmul_kernel<true, false, true >;
|
||||
else return q4_matmul_kernel<true, false, false>;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Compute y = x @ w
|
||||
|
||||
void q4_matmul_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
const Q4Matrix* w,
|
||||
half* out,
|
||||
bool no_zero,
|
||||
cudaStream_t alt_stream
|
||||
)
|
||||
{
|
||||
int height = x_height;
|
||||
int dim = w->height;
|
||||
int width = w->width;
|
||||
|
||||
cudaSetDevice(w->device);
|
||||
|
||||
uint32_t* x_map = w->cuda_x_map;
|
||||
const half* x_mapped = x;
|
||||
if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
|
||||
{
|
||||
CudaBuffers* buffers = get_buffers(w->device);
|
||||
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
|
||||
x_mapped = buffers->temp_state;
|
||||
x_map = NULL;
|
||||
}
|
||||
|
||||
int block_size_z;
|
||||
if (w->width == 4096) block_size_z = 384; // 7B
|
||||
else if (w->width == 11008) block_size_z = 256;
|
||||
else if (w->width == 5120) block_size_z = 384; // 13B
|
||||
else if (w->width == 13824) block_size_z = 256;
|
||||
else if (w->width == 6656) block_size_z = 256; // 33B
|
||||
else if (w->width == 17920) block_size_z = 128;
|
||||
else block_size_z = 256;
|
||||
|
||||
//if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
|
||||
|
||||
dim3 threads(THREADS_X, THREADS_Y, 1);
|
||||
|
||||
dim3 blocks
|
||||
(
|
||||
(width + threads.x - 1) / threads.x,
|
||||
(height + threads.y - 1) / threads.y,
|
||||
(dim + block_size_z - 1) / block_size_z
|
||||
);
|
||||
|
||||
fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
|
||||
|
||||
kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
|
||||
}
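The launch shape chosen above tiles the output with 32x1 thread blocks (THREADS_X x THREADS_Y) and slices the reduction dimension into chunks of block_size_z along blockIdx.z. The same arithmetic as a small Python sketch (the helper name and defaults are illustrative only):

def q4_matmul_grid(height: int, dim: int, width: int, block_size_z: int = 256,
                   threads_x: int = 32, threads_y: int = 1):
    # Matches the dim3 blocks(...) computation in q4_matmul_cuda.
    blocks = (
        (width + threads_x - 1) // threads_x,
        (height + threads_y - 1) // threads_y,
        (dim + block_size_z - 1) // block_size_z,
    )
    return blocks, (threads_x, threads_y, 1)

# e.g. a 16 x 4096 activation against a 4096-wide weight (7B-style shapes, block_size_z = 384)
print(q4_matmul_grid(height=16, dim=4096, width=4096, block_size_z=384))  # -> ((128, 16, 11), (32, 1, 1))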
|
||||
|
||||
void q4_matmul_recons_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
Q4Matrix* w,
|
||||
half* out,
|
||||
const cublasHandle_t handle,
|
||||
bool no_zero
|
||||
)
|
||||
{
|
||||
int height = x_height;
|
||||
int dim = w->height;
|
||||
int width = w->width;
|
||||
|
||||
cudaSetDevice(w->device);
|
||||
CudaBuffers* buffers = get_buffers(w->device);
|
||||
|
||||
const half* x_mapped = x;
|
||||
if (w->cuda_x_map)
|
||||
{
|
||||
TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
|
||||
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
|
||||
x_mapped = buffers->temp_state;
|
||||
}
|
||||
|
||||
w->reconstruct(buffers->temp_dq);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
|
||||
const float alpha = 1.0f;
|
||||
const float beta = no_zero ? 1.0f : 0.0f;
|
||||
cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
|
||||
x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
|
||||
#else
|
||||
const half alpha = __float2half(1.0f);
|
||||
const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
|
||||
cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
|
||||
#endif
|
||||
}
|
|
@@ -0,0 +1,43 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _q4_matmul_cuh
#define _q4_matmul_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>

#include "q4_matrix.cuh"
#include "tuning.h"

// Workaround for hipify_python using rocblas instead of hipblas.
#if defined(USE_ROCM)
#include <hipblas/hipblas.h>
#define rocblas_handle hipblasHandle_t
#endif

void q4_matmul_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const int x_height,
    const Q4Matrix* w,
    half* out,
    bool no_zero = false,
    cudaStream_t alt_stream = NULL
);

void q4_matmul_recons_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const int x_height,
    Q4Matrix* w,
    half* out,
    const cublasHandle_t handle,
    bool no_zero = false
);

#endif
|
|
@ -0,0 +1,225 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include "q4_matrix.cuh"
|
||||
#include <vector>
|
||||
#include "util.cuh"
|
||||
#include "matrix.cuh"
|
||||
|
||||
using namespace std;
|
||||
|
||||
const int UNSHUF_BLOCKSIZE_X = 64;
|
||||
|
||||
const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column
|
||||
const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows
|
||||
|
||||
vector<Q4Matrix*> g_q4_matrices;
|
||||
|
||||
void g_q4_keep_matrix(Q4Matrix* m)
|
||||
{
|
||||
g_q4_matrices.push_back(m);
|
||||
}
|
||||
|
||||
void g_q4_free_matrices()
|
||||
{
|
||||
for (const auto& m : g_q4_matrices) delete m;
|
||||
g_q4_matrices.clear();
|
||||
}
|
||||
|
||||
Q4Matrix::Q4Matrix
|
||||
(
|
||||
const int _height,
|
||||
const int _width,
|
||||
const int _groups,
|
||||
|
||||
uint32_t* _qweight,
|
||||
uint32_t* _qzeros,
|
||||
half* _scales,
|
||||
uint32_t* _g_idx,
|
||||
|
||||
const int _device
|
||||
) :
|
||||
height(_height),
|
||||
width(_width),
|
||||
groups(_groups),
|
||||
device(_device)
|
||||
{
|
||||
cudaSetDevice(device);
|
||||
|
||||
cuda_qweight = _qweight;
|
||||
cuda_qzeros = _qzeros;
|
||||
cuda_scales = _scales;
|
||||
|
||||
groupsize = height / groups;
|
||||
|
||||
if (_g_idx) make_sequential(_g_idx);
|
||||
}
|
||||
|
||||
Q4Matrix::~Q4Matrix()
|
||||
{
|
||||
}
|
||||
|
||||
// Make sequential
|
||||
|
||||
__global__ void make_sequential_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ w,
|
||||
uint32_t* __restrict__ w_new,
|
||||
const uint32_t* __restrict__ x_map,
|
||||
const int w_height,
|
||||
const int w_width
|
||||
)
|
||||
{
|
||||
const uint64_t* w2 = (uint64_t*) w;
|
||||
uint64_t* w_new2 = (uint64_t*) w_new;
|
||||
int w2_stride = w_width >> 1;
|
||||
|
||||
int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
|
||||
if (w2_column >= w2_stride) return;
|
||||
|
||||
int w_new2_row = blockIdx.y;
|
||||
|
||||
int x_map_idx = w_new2_row << 3;
|
||||
|
||||
uint64_t dst = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++)
|
||||
{
|
||||
int source_row = x_map[x_map_idx++];
|
||||
|
||||
int w2_row = source_row >> 3;
|
||||
int w2_subrow = source_row & 0x07;
|
||||
int w2_row_shift = w2_subrow << 2;
|
||||
int wnew2_row_shift = i << 2;
|
||||
|
||||
uint64_t src = w2[w2_row * w2_stride + w2_column];
|
||||
src >>= w2_row_shift;
|
||||
src &= 0x0000000f0000000f;
|
||||
src <<= wnew2_row_shift;
|
||||
dst |= src;
|
||||
}
|
||||
|
||||
w_new2[w_new2_row * w2_stride + w2_column] = dst;
|
||||
}
|
||||
|
||||
void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
|
||||
{
|
||||
uint32_t* cuda_new_qweight = NULL;
|
||||
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
|
||||
cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch
|
||||
|
||||
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
|
||||
// Group histogram
|
||||
|
||||
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
|
||||
|
||||
// Group map
|
||||
|
||||
for (int i = 0, acc = 0; i < groups; i++)
|
||||
{
|
||||
short tmp = cpu_g_idx_map[i];
|
||||
cpu_g_idx_map[i] = acc;
|
||||
acc += tmp;
|
||||
}
|
||||
|
||||
// X map (inverse)
|
||||
|
||||
for (int row = 0; row < height; row++)
|
||||
{
|
||||
uint32_t target_group = cpu_g_idx[row];
|
||||
uint32_t target_row = cpu_g_idx_map[target_group];
|
||||
cpu_g_idx_map[target_group]++;
|
||||
cpu_x_map_inv[row] = target_row;
|
||||
}
|
||||
|
||||
// X map
|
||||
|
||||
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
|
||||
|
||||
// Move to CUDA
|
||||
|
||||
cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
|
||||
|
||||
// Rearrange rows in w
|
||||
|
||||
dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
|
||||
dim3 blocks
|
||||
(
|
||||
(width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
|
||||
height / 8,
|
||||
1
|
||||
);
|
||||
|
||||
make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
|
||||
|
||||
// Replace qweights
|
||||
|
||||
cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
|
||||
|
||||
// Cleanup
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
cudaFree(cuda_new_qweight);
|
||||
free(cpu_g_idx_map);
|
||||
free(cpu_x_map);
|
||||
free(cpu_x_map_inv);
|
||||
}
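On the host side, make_sequential is a counting sort of rows by group index: a histogram over g_idx, an exclusive prefix sum into per-group start offsets, a stable placement of each row inside its group, and an inversion of that permutation. A plain-Python sketch of the same construction (illustrative only; it assumes every g_idx value lies in [0, groups)):

def build_x_maps(g_idx, groups):
    # Counting sort of rows by group: returns (x_map, x_map_inv) exactly as
    # make_sequential builds cpu_x_map / cpu_x_map_inv.
    hist = [0] * groups
    for g in g_idx:                      # group histogram
        hist[g] += 1
    start, acc = [0] * groups, 0
    for i in range(groups):              # exclusive prefix sum -> first row of each group
        start[i], acc = acc, acc + hist[i]
    x_map_inv = [0] * len(g_idx)
    for row, g in enumerate(g_idx):      # stable placement inside each group
        x_map_inv[row] = start[g]
        start[g] += 1
    x_map = [0] * len(g_idx)
    for row, target in enumerate(x_map_inv):
        x_map[target] = row              # invert: sequential position -> original row
    return x_map, x_map_inv

print(build_x_maps([1, 0, 1, 0], groups=2))  # -> ([1, 3, 0, 2], [2, 0, 3, 1])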
|
||||
|
||||
__global__ void reconstruct_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ w,
|
||||
half* __restrict__ out, // (y)
|
||||
const half* __restrict__ w_scales,
|
||||
const uint32_t* __restrict__ w_zeros,
|
||||
const int height,
|
||||
const int width,
|
||||
const int groupsize
|
||||
)
|
||||
{
|
||||
// Start of block
|
||||
|
||||
int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
|
||||
int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
|
||||
if (column >= width) return;
|
||||
|
||||
// Views
|
||||
|
||||
MatrixView_q4_column w_(w, height, width);
|
||||
MatrixView_half_rw out_(out, height, width);
|
||||
MatrixView_half w_scales_(w_scales, height / groupsize, width);
|
||||
MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
|
||||
|
||||
// Groupsize version
|
||||
|
||||
int group = row / groupsize;
|
||||
|
||||
half w_scale = w_scales_.item(group, column);
|
||||
uint32_t w_zero = w_zeros_.item(group, column) + 1;
|
||||
|
||||
uint32_t w_read = w_.item_uint32_t(row, column);
|
||||
half* out_ptr = out_.item_ptr(row, column);
|
||||
|
||||
#pragma unroll
|
||||
for (int s = 0; s < 32; s += 4)
|
||||
{
|
||||
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
|
||||
*out_ptr = w_item; out_ptr += out_.width;
|
||||
}
|
||||
}
|
||||
|
||||
void Q4Matrix::reconstruct(half* out)
|
||||
{
|
||||
dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
|
||||
|
||||
dim3 blocks
|
||||
(
|
||||
(width + threads.x - 1) / threads.x,
|
||||
(height / 8 + threads.y - 1) / threads.y,
|
||||
1
|
||||
);
|
||||
|
||||
reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
|
||||
}
|
|
@@ -0,0 +1,53 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _q4_matrix_cuh
#define _q4_matrix_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

class Q4Matrix
{
public:

    int device;

    int height;
    int width;
    int groups;
    int groupsize;

    uint32_t* cuda_qweight = NULL;
    uint32_t* cuda_qzeros = NULL;
    half* cuda_scales = NULL;
    uint32_t* cuda_x_map = NULL;

    Q4Matrix
    (
        const int _height,
        const int _width,
        const int _groups,

        uint32_t* _qweight,
        uint32_t* _qzeros,
        half* _scales,
        uint32_t* _g_idx,

        const int _device
    );

    ~Q4Matrix();

    void reconstruct(half* out);

private:

    void make_sequential(const uint32_t* cpu_g_idx);

};

void g_q4_keep_matrix(Q4Matrix* m);
void g_q4_free_matrices();

#endif

@@ -0,0 +1,13 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _tuning_h
#define _tuning_h

struct ExLlamaTuning
{
    int matmul_recons_thd;
    bool matmul_fused_remap;
    bool matmul_no_half2;
};

#endif

@@ -0,0 +1,33 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _util_cuh
#define _util_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

#if defined(USE_ROCM)
#define cudaUnspecified hipErrorUnknown
#else
#define cudaUnspecified cudaErrorApiFailureBase
#endif

// React to failure on return code != cudaSuccess

#define _cuda_check(fn) \
do { \
    {_cuda_err = fn;} \
    if (_cuda_err != cudaSuccess) goto _cuda_fail; \
} while(false)

// React to failure on return code == 0

#define _alloc_check(fn) \
do { \
    if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
    else _cuda_err = cudaSuccess; \
} while(false)

#endif

@@ -6,6 +6,7 @@ try:
    from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
    from .copy_kv_cache_dest import copy_kv_cache_to_dest
    from .fused_layernorm import layer_norm
    from .gptq_triton import gptq_fused_linear_triton
    from .rms_norm import rmsnorm_forward
    from .rotary_embedding_kernel import rotary_embedding_fwd
    from .softmax import softmax

@@ -20,6 +21,7 @@ try:
        "copy_kv_cache_to_dest",
        "rotary_embedding_fwd",
        "token_attention_fwd",
        "gptq_fused_linear_triton",
    ]

except ImportError:
@@ -0,0 +1,541 @@
# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ

import math

import torch
import triton
import triton.language as tl
from auto_gptq.nn_modules.triton_utils import custom_autotune


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def cosh(x):
    exp_x = tl.exp(x)
    return (exp_x + 1.0 / exp_x) * 0.5


# a Triton implementation of the most used activations
# See for instance http://arxiv.org/abs/1606.08415 for an overview


# ReLU
@triton.jit
def relu(x):
    """
    ReLU_ activation function

    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
    """
    return tl.where(x >= 0, x, 0.0)


@triton.jit
def squared_relu(x):
    """
    Squared ReLU activation, as proposed in the Primer_ paper.

    .. _Primer: https://arxiv.org/abs/2109.08668
    """
    x_sq = x * x
    return tl.where(x > 0.0, x_sq, 0.0)


@triton.jit
def star_relu(x):
    """
    Star ReLU activation, as proposed in the "MetaFormer Baselines for Vision"_ paper.

    .. _ "MetaFormer Baselines for Vision": https://arxiv.org/pdf/2210.13452.pdf
    """
    x_sq = x * x
    return 0.8944 * tl.where(x > 0.0, x_sq, 0.0) - 0.4472


# Leaky ReLU
@triton.jit
def leaky_relu(x):
    """
    LeakyReLU_ activation

    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
    """
    return tl.where(x >= 0.0, x, 0.01 * x)


# Constant of the tanh-based GeLU approximation used below: sqrt(2 / pi)
_kAlpha = math.sqrt(2.0 / math.pi)


@triton.jit
def gelu(x):
    """
    GeLU_ activation - Gaussian error linear unit

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x)))


@triton.jit
def smelu(x):
    """
    SmeLU_ activation - Smooth ReLU with beta=2.0

    .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf
    """
    beta = 2.0

    relu = tl.where(x >= beta, x, 0.0)
    return tl.where(tl.abs(x) <= beta, (x + beta) * (x + beta) / (4.0 * beta), relu)


@triton.jit
def silu(x):
    return x * tl.sigmoid(x)


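For quick reference while reading the two kernels below: the ACT_TYPE constexpr selects among the helpers above, with 0 leaving the accumulator untouched. A small illustrative mapping (not part of the original file):

# ACT_TYPE convention used by cai_gptq_matmul_248_kernel / cai_gptq_idx_matmul_248_kernel:
# 0 = identity, 1 = relu, 2 = gelu, 3 = silu.
ACT_TYPE_TO_FN = {0: None, 1: relu, 2: gelu, 3: silu}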
@custom_autotune.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4
|
||||
),
|
||||
],
|
||||
key=["M", "N", "K"],
|
||||
nearest_power_of_two=True,
|
||||
prune_configs_by={
|
||||
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
|
||||
"perf_model": None,
|
||||
"top_k": None,
|
||||
},
|
||||
)
|
||||
@triton.jit
|
||||
def cai_gptq_matmul_248_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
scales_ptr,
|
||||
zeros_ptr,
|
||||
bias_ptr,
|
||||
residual_ptr,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
maxq,
|
||||
gptq_group_size,
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_scales,
|
||||
stride_zeros,
|
||||
QKV_FUSED: tl.constexpr,
|
||||
ADD_BIAS: tl.constexpr,
|
||||
ADD_RESIDUAL: tl.constexpr,
|
||||
ACT_TYPE: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Compute the matrix multiplication C = A x B.
|
||||
A is of shape (M, K) float16
|
||||
B is of shape (K//8, N) int32
|
||||
C is of shape (M, N) float16
|
||||
scales is of shape (G, N) float16
|
||||
zeros is of shape (G, N) float16
|
||||
"""
|
||||
infearure_per_bits = 32 // bits
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
NK = K
|
||||
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K)
|
||||
qkv_offset = pid // (num_pid_m * num_pid_n)
|
||||
pid = pid % (num_pid_m * num_pid_n)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
# offs_bk = offs_k + qkv_offset * NK
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
|
||||
a_mask = offs_am[:, None] < M
|
||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||
b_ptrs = (
|
||||
b_ptr
|
||||
+ qkv_offset * N * NK // infearure_per_bits
|
||||
+ ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
# g_ptrs = g_ptr + offs_k
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :]
|
||||
zeros_ptrs = (
|
||||
zeros_ptr
|
||||
+ qkv_offset * NK * N // gptq_group_size // infearure_per_bits
|
||||
+ (offs_bn[None, :] // infearure_per_bits)
|
||||
)
|
||||
|
||||
shifter = (offs_k % infearure_per_bits) * bits
|
||||
zeros_shifter = (offs_bn % infearure_per_bits) * bits
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
g_idx_base = tl.arange(0, BLOCK_SIZE_K)
|
||||
g_idx_base = g_idx_base // gptq_group_size
|
||||
g_idx = g_idx_base
|
||||
# tl.device_print("gidx, ", g_idx)
|
||||
|
||||
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = zeros + 1
|
||||
|
||||
for k in range(0, num_pid_k):
|
||||
# g_idx = tl.load(g_ptrs)
|
||||
# if (k + 1) * BLOCK_SIZE_K > currend_group_end:
|
||||
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = zeros + 1
|
||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
||||
b = (b - zeros).to(tl.float16) * scales # Scale and shift
|
||||
accumulator += tl.dot(a, b)
|
||||
|
||||
a_ptrs += BLOCK_SIZE_K
|
||||
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
|
||||
g_idx = g_idx_base + ((k + 1) * BLOCK_SIZE_K) // gptq_group_size
|
||||
# if (k + 2) * BLOCK_SIZE_K > currend_group_end:
|
||||
|
||||
c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
|
||||
|
||||
if ADD_BIAS:
|
||||
bias_mask = offs_bn < N
|
||||
offs_bn += qkv_offset * N
|
||||
bias_ptrs = bias_ptr + stride_cn * offs_bn
|
||||
bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
accumulator += bias[None, :]
|
||||
|
||||
if ACT_TYPE == 1:
|
||||
accumulator = relu(accumulator)
|
||||
elif ACT_TYPE == 2:
|
||||
accumulator = gelu(accumulator)
|
||||
elif ACT_TYPE == 3:
|
||||
accumulator = silu(accumulator)
|
||||
|
||||
if ADD_RESIDUAL:
|
||||
residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
|
||||
res = tl.load(residual_ptrs, mask=c_mask, other=0.0)
|
||||
accumulator += res
|
||||
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
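Stripped of the pointer arithmetic, the kernel above unpacks a tile of B, subtracts the per-group zero point (offset by +1), scales it, and accumulates an ordinary matmul. Below is a PyTorch reference of that dequantization, written as a readability-first sketch under the shape conventions in the docstring; it is not part of this file, and the helper name is illustrative.

import torch

def dequant_matmul_ref(a, qweight, scales, qzeros, g_idx, bits=4):
    """Reference for the packed layout above: a (M, K) fp16, qweight (K // 8, N) int32,
    scales (G, N) fp16, qzeros (G, N // 8) int32, g_idx (K,) long holding each input
    row's group (k // group_size when no act-order reordering is used)."""
    maxq = (1 << bits) - 1
    pack = 32 // bits
    K, N = a.shape[1], qweight.shape[1]
    k = torch.arange(K, device=a.device)
    n = torch.arange(N, device=a.device)
    # Unpack the weights: row k lives in word k // 8, at nibble k % 8.
    w = (qweight[k // pack] >> ((k % pack) * bits)[:, None]) & maxq
    # Unpack the zero points (packed along N) and apply the same +1 offset as the kernel.
    z = ((qzeros[:, n // pack] >> ((n % pack) * bits)[None, :]) & maxq) + 1
    b = (w - z[g_idx]).to(torch.float16) * scales[g_idx]
    return a @ b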
|
||||
|
||||
|
||||
@custom_autotune.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4
|
||||
),
|
||||
],
|
||||
key=["M", "N", "K"],
|
||||
nearest_power_of_two=True,
|
||||
prune_configs_by={
|
||||
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
|
||||
"perf_model": None,
|
||||
"top_k": None,
|
||||
},
|
||||
)
|
||||
@triton.jit
|
||||
def cai_gptq_idx_matmul_248_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
scales_ptr,
|
||||
zeros_ptr,
|
||||
idx_ptr,
|
||||
bias_ptr,
|
||||
residual_ptr,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
maxq,
|
||||
gptq_group_size,
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_scales,
|
||||
stride_zeros,
|
||||
QKV_FUSED: tl.constexpr,
|
||||
ADD_BIAS: tl.constexpr,
|
||||
ADD_RESIDUAL: tl.constexpr,
|
||||
ACT_TYPE: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Compute the matrix multiplication C = A x B.
|
||||
A is of shape (M, K) float16
|
||||
B is of shape (K//8, N) int32
|
||||
C is of shape (M, N) float16
|
||||
scales is of shape (G, N) float16
|
||||
zeros is of shape (G, N) float16
|
||||
"""
|
||||
infearure_per_bits = 32 // bits
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
NK = K
|
||||
|
||||
# if QKV_FUSED:
|
||||
# NK = K//3
|
||||
# else:
|
||||
# NK = K
|
||||
# NK = K
|
||||
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K)
|
||||
qkv_offset = pid // (num_pid_m * num_pid_n)
|
||||
pid = pid % (num_pid_m * num_pid_n)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
# offs_bk = offs_k + qkv_offset * NK
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
|
||||
a_mask = offs_am[:, None] < M
|
||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||
b_ptrs = (
|
||||
b_ptr
|
||||
+ qkv_offset * N * NK // infearure_per_bits
|
||||
+ ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
# g_ptrs = g_ptr + offs_k
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :]
|
||||
zeros_ptrs = (
|
||||
zeros_ptr
|
||||
+ qkv_offset * NK * N // gptq_group_size // infearure_per_bits
|
||||
+ (offs_bn[None, :] // infearure_per_bits)
|
||||
)
|
||||
|
||||
shifter = (offs_k % infearure_per_bits) * bits
|
||||
zeros_shifter = (offs_bn % infearure_per_bits) * bits
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
g_ptrs = idx_ptr + offs_k
|
||||
g_idx = tl.load(g_ptrs)
|
||||
# tl.device_print("gidx, ", g_idx)
|
||||
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
|
||||
|
||||
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
for k in range(0, num_pid_k):
|
||||
g_idx = tl.load(g_ptrs)
|
||||
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = zeros + 1
|
||||
|
||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
||||
b = (b - zeros).to(tl.float16) * scales # Scale and shift
|
||||
accumulator += tl.dot(a, b)
|
||||
|
||||
a_ptrs += BLOCK_SIZE_K
|
||||
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
|
||||
g_ptrs += BLOCK_SIZE_K
|
||||
|
||||
c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
|
||||
|
||||
if ADD_BIAS:
|
||||
bias_mask = offs_bn < N
|
||||
offs_bn += qkv_offset * N
|
||||
bias_ptrs = bias_ptr + stride_cn * offs_bn
|
||||
bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
accumulator += bias[None, :]
|
||||
|
||||
if ACT_TYPE == 1:
|
||||
accumulator = relu(accumulator)
|
||||
elif ACT_TYPE == 2:
|
||||
accumulator = gelu(accumulator)
|
||||
elif ACT_TYPE == 3:
|
||||
accumulator = silu(accumulator)
|
||||
|
||||
if ADD_RESIDUAL:
|
||||
residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
|
||||
res = tl.load(residual_ptrs, mask=c_mask, other=0.0)
|
||||
accumulator += res
|
||||
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
def gptq_fused_linear_triton(
|
||||
input,
|
||||
qweight,
|
||||
scales,
|
||||
qzeros,
|
||||
bias,
|
||||
residual,
|
||||
bits,
|
||||
maxq,
|
||||
gptq_group_size,
|
||||
qkv_fused,
|
||||
add_bias,
|
||||
add_residual,
|
||||
g_idx=None,
|
||||
act_type=0,
|
||||
):
|
||||
# print("gptq fused ", qkv_fused, add_bias, add_residual)
|
||||
assert input.is_cuda, "input is not in cuda"
|
||||
assert qweight.is_cuda, "qweight is not in cuda"
|
||||
assert scales.is_cuda, "scales is not in cuda"
|
||||
assert qzeros.is_cuda, "qzeros is not in cuda"
|
||||
|
||||
with torch.cuda.device(input.device):
|
||||
if qkv_fused:
|
||||
grid = lambda META: (
|
||||
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"])
|
||||
* 3,
|
||||
)
|
||||
output = torch.empty((input.shape[0] * 3, qweight.shape[1]), device=input.device, dtype=torch.float16)
|
||||
else:
|
||||
grid = lambda META: (
|
||||
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
|
||||
# print("dtype, ", qweight.dtype, output.dtype, scales.dtype, qzeros.dtype, bias.dtype, residual.dtype)
|
||||
if g_idx is None:
|
||||
cai_gptq_matmul_248_kernel[grid](
|
||||
input,
|
||||
qweight,
|
||||
output,
|
||||
scales,
|
||||
qzeros,
|
||||
bias,
|
||||
residual,
|
||||
input.shape[0],
|
||||
qweight.shape[1],
|
||||
input.shape[1],
|
||||
bits,
|
||||
maxq,
|
||||
gptq_group_size,
|
||||
input.stride(0),
|
||||
input.stride(1),
|
||||
qweight.stride(0),
|
||||
qweight.stride(1),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
scales.stride(0),
|
||||
qzeros.stride(0),
|
||||
QKV_FUSED=qkv_fused,
|
||||
ADD_BIAS=add_bias,
|
||||
ADD_RESIDUAL=add_residual,
|
||||
ACT_TYPE=act_type,
|
||||
)
|
||||
else:
|
||||
cai_gptq_idx_matmul_248_kernel[grid](
|
||||
input,
|
||||
qweight,
|
||||
output,
|
||||
scales,
|
||||
qzeros,
|
||||
g_idx,
|
||||
bias,
|
||||
residual,
|
||||
input.shape[0],
|
||||
qweight.shape[1],
|
||||
input.shape[1],
|
||||
bits,
|
||||
maxq,
|
||||
gptq_group_size,
|
||||
input.stride(0),
|
||||
input.stride(1),
|
||||
qweight.stride(0),
|
||||
qweight.stride(1),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
scales.stride(0),
|
||||
qzeros.stride(0),
|
||||
QKV_FUSED=qkv_fused,
|
||||
ADD_BIAS=add_bias,
|
||||
ADD_RESIDUAL=add_residual,
|
||||
ACT_TYPE=act_type,
|
||||
)
|
||||
if qkv_fused:
|
||||
return output.view(3, input.shape[0], qweight.shape[1])
|
||||
else:
|
||||
return output
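As a usage sketch, a plain (non-QKV-fused) call looks roughly like the following; the tensors are random stand-ins shaped per the kernel docstring, maxq follows the usual GPTQ convention of 2**bits - 1, and a dummy residual tensor is passed even though add_residual is False.

import torch

M, K, N, bits, group_size = 16, 1024, 1024, 4, 128
x = torch.randn(M, K, dtype=torch.float16, device="cuda")
qweight = torch.randint(0, 2**31 - 1, (K // 8, N), dtype=torch.int32, device="cuda")
scales = 0.01 * torch.rand(K // group_size, N, dtype=torch.float16, device="cuda")
qzeros = torch.randint(0, 2**31 - 1, (K // group_size, N // 8), dtype=torch.int32, device="cuda")
bias = torch.zeros(N, dtype=torch.float16, device="cuda")
residual = torch.zeros(M, N, dtype=torch.float16, device="cuda")  # unused when add_residual=False

out = gptq_fused_linear_triton(
    x, qweight, scales, qzeros, bias, residual,
    bits=bits, maxq=2**bits - 1, gptq_group_size=group_size,
    qkv_fused=False, add_bias=True, add_residual=False,
)  # -> (M, N) float16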
|
|
@@ -32,10 +32,13 @@ class ShardConfig:
    enable_fused_normalization: bool = False
    enable_flash_attention: bool = False
    enable_jit_fused: bool = False
    enable_sequence_parallelism: bool = False
    enable_sequence_overlap: bool = False
    enable_all_optimization: bool = False
    inference_only: bool = False
    inference_gptq: bool = False
    enable_sequence_parallelism: bool = False
    enable_sequence_overlap: bool = False
    # pipeline_parallel_size: int
    # data_parallel_size: int
    # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']

    @property
@ -0,0 +1,123 @@
|
|||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
||||
from auto_gptq.nn_modules.qlinear import GeneralQuantLinear
|
||||
from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.tensor_parallel.engine import TPInferEngine
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
|
||||
|
||||
def print_perf_stats(latency_set, config, bs, warmup=3):
|
||||
# trim warmup queries
|
||||
latency_set = list(latency_set)
|
||||
latency_set = latency_set[warmup:]
|
||||
count = len(latency_set)
|
||||
|
||||
if count > 0:
|
||||
latency_set.sort()
|
||||
avg = sum(latency_set) / count
|
||||
num_layers = getattr(config, "num_layers", config.num_hidden_layers)
|
||||
num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
|
||||
num_bytes = 2 # float16
|
||||
|
||||
print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
|
||||
print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
|
||||
print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12))
|
||||
print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs))
|
||||
|
||||
|
||||
def bench_bloom(args):
|
||||
|
||||
pretrained_model_dir = args.path
|
||||
quantized_model_dir = args.quantized_path
|
||||
max_batch_size = args.batch_size
|
||||
max_input_len = args.input_len
|
||||
max_output_len = args.output_len
|
||||
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(pretrained_model_dir)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# load quantized model to the first GPU
|
||||
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir,
|
||||
device=torch.cuda.current_device(),
|
||||
inject_fused_attention=False)
|
||||
|
||||
model = model.half()
|
||||
|
||||
model_config = model.config
|
||||
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
|
||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
|
||||
|
||||
input_tokens = {
|
||||
"input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'),
|
||||
"attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda')
|
||||
}
|
||||
|
||||
# init TPInferEngine and shard the original model
|
||||
# To benchmark torch original, comment out the line of optimizing model
|
||||
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False,
|
||||
inference_only=True,
|
||||
inference_gptq=True)
|
||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||
|
||||
# prepare data for generation
|
||||
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
|
||||
input_tokens = {
|
||||
"input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)),
|
||||
"attention_mask": torch.ones((max_batch_size, max_input_len))
|
||||
}
|
||||
for t in input_tokens:
|
||||
if torch.is_tensor(input_tokens[t]):
|
||||
input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
|
||||
# print(f" input_tokens[{t}].shape: {input_tokens[t].shape}")
|
||||
|
||||
iters = 10
|
||||
times = []
|
||||
for i in range(iters):
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
outputs = infer_engine.generate(input_tokens, **generate_kwargs)
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
out_len = outputs.shape[1]
|
||||
print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s")
|
||||
times.append((end - start) / (out_len - max_input_len))
|
||||
|
||||
print_perf_stats(times, model_config, max_batch_size)
|
||||
|
||||
|
||||
def check_bloom(rank, world_size, port, args):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
bench_bloom(args)
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_bloom(args):
|
||||
spawn(check_bloom, args.tp_size, args=args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
|
||||
parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True)
|
||||
parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
|
||||
parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
|
||||
parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
test_bloom(args)
|
|
@ -0,0 +1,135 @@
|
|||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
||||
from auto_gptq.nn_modules.qlinear import GeneralQuantLinear
|
||||
from torch import distributed as dist
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, TextGenerationPipeline
|
||||
|
||||
import colossalai
|
||||
from colossalai.gptq import CaiQuantLinear
|
||||
from colossalai.gptq.gptq_tp import replace_autogptq_linear
|
||||
from colossalai.inference.tensor_parallel.engine import TPInferEngine
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
|
||||
|
||||
def init_to_get_rotary(self, base=10000):
|
||||
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
|
||||
if not hasattr(self.config, "rope_scaling"):
|
||||
rope_scaling_factor = 1.0
|
||||
else:
|
||||
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
|
||||
if hasattr(self.config, "max_sequence_length"):
|
||||
max_seq_len = self.config.max_sequence_length
|
||||
elif hasattr(self.config, "max_position_embeddings"):
|
||||
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
|
||||
else:
|
||||
max_seq_len = 2048 * rope_scaling_factor
|
||||
base = float(base)
|
||||
inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) /
|
||||
self.config.head_dim_))
|
||||
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
|
||||
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
|
||||
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
|
||||
return
|
||||
|
||||
|
||||
def print_perf_stats(latency_set, config, bs, warmup=3):
|
||||
# trim warmup queries
|
||||
latency_set = list(latency_set)
|
||||
latency_set = latency_set[warmup:]
|
||||
count = len(latency_set)
|
||||
|
||||
if count > 0:
|
||||
latency_set.sort()
|
||||
avg = sum(latency_set) / count
|
||||
num_layers = getattr(config, "num_layers", config.num_hidden_layers)
|
||||
num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
|
||||
num_bytes = 2
|
||||
|
||||
print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
|
||||
print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
|
||||
print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12))
|
||||
print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs))
|
||||
|
||||
|
||||
def run_llama_test(args):
|
||||
pretrained_model_dir = args.path
|
||||
quantized_model_dir = args.quantized_path
|
||||
max_batch_size = args.batch_size
|
||||
max_input_len = args.input_len
|
||||
max_output_len = args.output_len
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
# load quantized model to the first GPU
|
||||
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir,
|
||||
device=torch.cuda.current_device(),
|
||||
inject_fused_attention=False)
|
||||
|
||||
init_to_get_rotary(model.model.model, base=10000)
|
||||
|
||||
model_config = model.config
|
||||
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False,
|
||||
inference_only=True,
|
||||
inference_gptq=True)
|
||||
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
|
||||
|
||||
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
|
||||
|
||||
input_tokens = {
|
||||
"input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'),
|
||||
"attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda')
|
||||
}
|
||||
|
||||
iters = 10
|
||||
times = []
|
||||
|
||||
for i in range(iters):
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
outputs = infer_engine.generate(input_tokens, **generate_kwargs)
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
out_len = outputs.shape[1]
|
||||
print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s")
|
||||
times.append((end - start) / (out_len - max_input_len))
|
||||
|
||||
print_perf_stats(times, model_config, max_batch_size)
|
||||
|
||||
|
||||
def check_llama(rank, world_size, port, args):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_llama_test(args)
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_llama(args):
|
||||
spawn(check_llama, args.tp_size, args=args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
|
||||
parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True)
|
||||
parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
|
||||
parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
|
||||
parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
test_llama(args)
|
|
@@ -0,0 +1,52 @@
import os
import torch
import re

from .builder import Builder
from .utils import append_nvcc_threads, get_cuda_cc_flag


class GPTQBuilder(Builder):

    NAME = "cu_gptq"
    PREBUILT_IMPORT_PATH = "colossalai._C.cu_gptq"

    def __init__(self):
        super().__init__(name=GPTQBuilder.NAME,
                         prebuilt_import_path=GPTQBuilder.PREBUILT_IMPORT_PATH)

    def include_dirs(self):
        ret = [self.csrc_abs_path("gptq"), self.get_cuda_home_include()]
        return ret

    def sources_files(self):
        ret = [
            self.csrc_abs_path(fname) for fname in [
                'gptq/linear_gptq.cpp',
                'gptq/column_remap.cu',
                'gptq/cuda_buffers.cu',
                'gptq/q4_matmul.cu',
                'gptq/q4_matrix.cu'
            ]
        ]
        return ret

    def cxx_flags(self):
        return ['-O3'] + self.version_dependent_macros

    def nvcc_flags(self):
        extra_cuda_flags = ['-v',
            '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
            '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK', "-lcublas", "-std=c++17"
        ]

        for arch in torch.cuda.get_arch_list():
            res = re.search(r'sm_(\d+)', arch)
            if res:
                arch_cap = res[1]
                if int(arch_cap) >= 80:
                    extra_cuda_flags.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])

        ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags
        return append_nvcc_threads(ret)
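The builder is consumed like the other ColossalAI op builders; the unit test added later in this commit loads it JIT-style, roughly as follows:

from colossalai.kernel.op_builder.gptq import GPTQBuilder

gptq_cuda = GPTQBuilder().load()  # JIT-compiles the gptq/*.cu sources listed above on first use
gptq_cuda.set_tuning_params(8, False, False)  # matmul_recons_thd, matmul_fused_remap, matmul_no_half2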

@@ -18,3 +18,4 @@ SentencePiece
ninja
flash_attn==2.0.5
datasets
# auto-gptq is not listed here: it does not support torch 1.12 yet
@ -0,0 +1,150 @@
|
|||
import math
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from packaging import version
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
try:
|
||||
from auto_gptq.modeling._utils import autogptq_post_init
|
||||
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
|
||||
from exllama_kernels import prepare_buffers, set_tuning_params
|
||||
|
||||
from colossalai.inference.quant.gptq import CaiQuantLinear
|
||||
HAS_AUTO_GPTQ = True
|
||||
except:
|
||||
HAS_AUTO_GPTQ = False
|
||||
print("please install AutoGPTQ from https://github.com/PanQiWei/AutoGPTQ")
|
||||
|
||||
import warnings
|
||||
|
||||
HAS_GPTQ_CUDA = False
|
||||
try:
|
||||
from colossalai.kernel.op_builder.gptq import GPTQBuilder
|
||||
gptq_cuda = GPTQBuilder().load()
|
||||
HAS_GPTQ_CUDA = True
|
||||
except ImportError:
|
||||
warnings.warn('CUDA gptq is not installed')
|
||||
HAS_GPTQ_CUDA = False
|
||||
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
|
||||
|
||||
max_inner_outer_dim = 1
|
||||
max_input_len = 1
|
||||
max_dq_buffer_size = 1
|
||||
gptq_temp_dq_buffer = None
|
||||
gptq_temp_state_buffer = None
|
||||
|
||||
|
||||
def init_buffer(cai_linear, use_act_order=False):
|
||||
global max_dq_buffer_size
|
||||
global max_input_len
|
||||
global max_dq_buffer_size
|
||||
global max_inner_outer_dim
|
||||
global gptq_temp_dq_buffer
|
||||
global gptq_temp_state_buffer
|
||||
|
||||
max_dq_buffer_size = max(max_dq_buffer_size, cai_linear.qweight.numel() * 8)
|
||||
|
||||
if use_act_order:
|
||||
max_inner_outer_dim = max(max_inner_outer_dim, cai_linear.infeatures, cai_linear.outfeatures)
|
||||
|
||||
if use_act_order:
|
||||
max_input_len = 4096
|
||||
# The temp_state buffer is required to reorder X in the act-order case.
|
||||
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
|
||||
gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim),
|
||||
dtype=torch.float16,
|
||||
device=torch.cuda.current_device())
|
||||
gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device())
|
||||
|
||||
gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer)
|
||||
# Using the default from exllama repo here.
|
||||
matmul_recons_thd = 8
|
||||
matmul_fused_remap = False
|
||||
matmul_no_half2 = False
|
||||
gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
|
||||
reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq")
|
||||
def test_gptq_linear():
|
||||
|
||||
infeature = 1024
|
||||
outfeature = 1024
|
||||
group_size = 128
|
||||
wbits = 4
|
||||
|
||||
inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device())
|
||||
batch_inps = torch.randn(1, 16, infeature).to(torch.float16).to(torch.cuda.current_device())
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=wbits)
|
||||
|
||||
linear = linear_class(
|
||||
bits=4,
|
||||
group_size=group_size,
|
||||
infeatures=infeature,
|
||||
outfeatures=outfeature,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
|
||||
linear.scales = linear.scales + 0.002
|
||||
|
||||
linear = linear.to(device)
|
||||
|
||||
cai_linear = CaiQuantLinear(wbits, group_size, infeature, outfeature, True)
|
||||
cai_linear.qweight.data.copy_(linear.qweight)
|
||||
cai_linear.scales = cai_linear.scales + 0.002
|
||||
cai_linear = cai_linear.to(device)
|
||||
|
||||
linear = autogptq_post_init(linear, use_act_order=False)
|
||||
|
||||
max_inner_outer_dim = max(infeature, outfeature)
|
||||
max_dq_buffer_size = linear.infeatures * linear.outfeatures
|
||||
max_input_len = 2048
|
||||
buffers = {
|
||||
"temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
|
||||
"temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
|
||||
}
|
||||
|
||||
prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
|
||||
|
||||
# Using the default from exllama repo here.
|
||||
matmul_recons_thd = 8
|
||||
matmul_fused_remap = False
|
||||
matmul_no_half2 = False
|
||||
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
|
||||
|
||||
with torch.no_grad():
|
||||
gptq_out = linear(inps)
|
||||
batch_gptq_out = linear(batch_inps)
|
||||
torch.cuda.synchronize()
|
||||
cai_out = cai_linear(inps)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
batch_cai_out = cai_linear(batch_inps)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-01)
|
||||
assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-01)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
test_gptq_linear()