[feature] add gptq for inference (#4754)

* [gptq] add gptq kernel (#4416)

* add gptq

* refactor code

* fix tests

* replace auto-gptq

* rname inferance/quant

* refactor test

* add auto-gptq as an option

* reset requirements

* change assert and check auto-gptq

* add import warnings

* change test flash attn version

* remove example

* change requirements of flash_attn

* modify tests

* [skip ci] change requirements-test

* [gptq] faster gptq cuda kernel (#4494)

* [skip ci] add cuda kernels

* add license

* [skip ci] fix max_input_len

* format files & change test size

* [skip ci]

* [gptq] add gptq tensor parallel (#4538)

* add gptq tensor parallel

* add gptq tp

* delete print

* add test gptq check

* add test auto gptq check

* [gptq] combine gptq and kv cache manager (#4706)

* combine gptq and kv cache manager

* add init bits

* delete useless code

* add model path

* delete usless print and update test

* delete usless import

* move option gptq to shard config

* change replace linear to shardformer

* update bloom policy

* delete useless code

* fix import bug and delete uselss code

* change colossalai/gptq to colossalai/quant/gptq

* update import linear for tests

* delete useless code and mv gptq_kernel to kernel directory

* fix triton kernel

* add triton import
pull/4778/head
Xu Kai 2023-09-22 11:02:50 +08:00 committed by GitHub
parent 1e0e080837
commit 946ab56c48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 3120 additions and 2 deletions

49
LICENSE
View File

@ -428,3 +428,52 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
---------------- LICENSE FOR AutoGPTQ ----------------
From AutoGPTQ:
MIT License
Copyright (c) 2023 潘其威(William)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
---------------- LICENSE FOR exllama ----------------
From exllama:
MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,4 @@
from .cai_gptq import HAS_AUTO_GPTQ
if HAS_AUTO_GPTQ:
from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear

View File

@ -0,0 +1,13 @@
import warnings
HAS_AUTO_GPTQ = False
try:
import auto_gptq
HAS_AUTO_GPTQ = True
except ImportError:
warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ')
HAS_AUTO_GPTQ = False
if HAS_AUTO_GPTQ:
from .cai_quant_linear import CaiQuantLinear, ColCaiQuantLinear, RowCaiQuantLinear
from .gptq_op import CaiGPTQLinearOp

View File

@ -0,0 +1,354 @@
# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ
import math
import warnings
from typing import List, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import ParallelModule
from .gptq_op import CaiGPTQLinearOp
HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder
gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn('CUDA gptq is not installed')
HAS_GPTQ_CUDA = False
class CaiQuantLinear(nn.Module):
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
super().__init__()
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize if groupsize != -1 else infeatures
self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
self.register_buffer(
'qzeros',
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
self.register_buffer('scales',
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
if row_split:
self.register_buffer(
'g_idx',
torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)],
dtype=torch.int32))
else:
self.register_buffer('g_idx',
torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
else:
self.bias = None
self.gptq_linear = CaiGPTQLinearOp(groupsize, bits)
self.q4 = None
self.empty_tensor = torch.empty((1, 1), device="meta")
self.tp_size = tp_size
self.tp_rank = tp_rank
self.row_split = row_split
def pack(self, linear, scales, zeros, g_idx=None):
g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
[i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
half_scales = scales.clone().half()
# print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape)
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
wn = 8
pbits = 32
ptype = torch.int32
unsign_type = np.uint32
sign_type = np.int32
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:,
None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(unsign_type)
qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (pbits // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += pbits // self.bits
row += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = qweight.astype(sign_type)
qweight1 = torch.from_numpy(qweight)
qweight1 = qweight1.contiguous() #.to("cuda")
self.qweight.data.copy_(qweight1)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
zeros -= 1
zeros = zeros.numpy().astype(unsign_type)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (pbits // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += pbits // self.bits
col += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qzeros = qzeros.astype(sign_type)
qzeros = torch.from_numpy(qzeros)
qzeros = qzeros
self.qzeros.data.copy_(qzeros)
if torch.equal(self.g_idx.to(g_idx.device), g_idx):
self.g_idx = None
else:
self.g_idx = g_idx
def init_q4(self):
assert self.qweight.device.type == "cuda"
self.q4_width = self.qweight.shape[1]
if self.g_idx is not None:
if self.row_split and torch.equal(
self.g_idx,
torch.tensor(
[(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx = None
elif torch.equal(
self.g_idx,
torch.tensor([i // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx = None
if self.g_idx is not None:
g_idx = self.g_idx.to("cpu")
else:
g_idx = self.empty_tensor
self.q4 = gptq_cuda.make_q4(self.qweight, self.qzeros, self.scales, g_idx, torch.cuda.current_device())
torch.cuda.synchronize()
def forward(self, x):
outshape = x.shape[:-1] + (self.outfeatures,)
if HAS_GPTQ_CUDA and self.bits == 4:
if self.q4 is None:
self.init_q4()
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device)
gptq_cuda.q4_matmul(x.half(), self.q4, output)
if self.bias is not None and (not self.row_split or self.tp_size == 1):
output.add_(self.bias)
else:
if self.bias is not None and (not self.row_split or self.tp_size == 1):
bias = self.bias
else:
bias = None
output = self.gptq_linear(
x,
self.qweight,
self.scales,
self.qzeros,
g_idx=self.g_idx,
bias=bias,
)
return output.view(outshape)
def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
g_idx = gptq_linear.g_idx
if gptq_linear.bias is not None:
bias = gptq_linear.bias.split(gptq_linear.out_features // split_num, dim=-1)
cai_split_out_features = cai_linear.outfeatures // split_num
zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num
for i in range(split_num):
cai_linear.qweight[:, i * cai_split_out_features:(i + 1) *
cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
cai_split_out_features]
cai_linear.qzeros[:, i * zero_split_block:(i + 1) *
zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block]
cai_linear.scales[:, i * cai_split_out_features:(i + 1) *
cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
cai_split_out_features]
if cai_linear.bias is not None:
cai_linear.bias[i * cai_split_out_features:(i + 1) *
cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) *
cai_split_out_features]
cai_linear.g_idx.copy_(g_idx)
def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
g_idxs = gptq_linear.g_idx.split(gptq_linear.in_features // split_num, dim=0)
cai_split_in_features = cai_linear.infeatures // (32 // cai_linear.bits) // split_num
zero_split_block = cai_linear.infeatures // cai_linear.groupsize // split_num
idx_split_features = cai_linear.infeatures // split_num
for i in range(split_num):
cai_linear.qweight[i * cai_split_in_features:(i + 1) *
cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) *
cai_split_in_features, :]
cai_linear.qzeros[i * zero_split_block:(i + 1) *
zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) *
zero_split_block, :]
cai_linear.scales[i * zero_split_block:(i + 1) *
zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) *
zero_split_block, :]
cai_linear.g_idx[i * idx_split_features:(i + 1) *
idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) *
idx_split_features]
if cai_linear.bias is not None:
cai_linear.bias.copy_(gptq_linear.bias)
class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
super().__init__(bits,
groupsize,
infeatures,
outfeatures,
bias,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=row_split)
self.process_group = None
@staticmethod
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
tp_rank = dist.get_rank(process_group)
if in_features < tp_size:
return module
if in_features % tp_size != 0:
raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
linear_1d = RowCaiQuantLinear(module.bits,
module.group_size,
module.in_features // tp_size,
module.out_features,
module.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=True)
linear_1d.process_group = process_group
split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
return linear_1d
def forward(self, x):
output = super().forward(x)
if self.tp_size > 1:
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
if self.bias is not None:
output.add_(self.bias)
return output
class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
super().__init__(bits,
groupsize,
infeatures,
outfeatures,
bias,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=row_split)
self.process_group = None
@staticmethod
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
tp_rank = dist.get_rank(process_group)
if in_features < tp_size:
return module
if in_features % tp_size != 0:
raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
linear_1d = ColCaiQuantLinear(module.bits,
module.group_size,
module.in_features,
module.out_features // tp_size,
module.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank)
linear_1d.process_group = process_group
split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
return linear_1d

View File

@ -0,0 +1,58 @@
import torch
from colossalai.kernel.triton import gptq_fused_linear_triton
class CaiGPTQLinearOp(torch.nn.Module):
def __init__(self, gptq_group_size, gptq_quant_bits):
super(CaiGPTQLinearOp, self).__init__()
self.group_size = gptq_group_size
self.bits = gptq_quant_bits
self.maxq = 2**self.bits - 1
self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())
def forward(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scales: torch.Tensor,
weight_zeros: torch.Tensor,
g_idx: torch.Tensor = None,
act_type=0,
bias: torch.Tensor = None,
residual: torch.Tensor = None,
qkv_fused=False,
):
add_bias = True
if bias is None:
bias = self.empty_tensor
add_bias = False
add_residual = True
if residual is None:
residual = self.empty_tensor
add_residual = False
x = input.view(-1, input.shape[-1])
out = gptq_fused_linear_triton(
x,
weight,
weight_scales,
weight_zeros,
bias,
residual,
self.bits,
self.maxq,
self.group_size,
qkv_fused,
add_bias,
add_residual,
act_type=act_type,
g_idx=g_idx,
)
if qkv_fused:
out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
else:
out = out.view(input.shape[0], input.shape[1], weight.shape[-1])
return out

View File

@ -1,6 +1,7 @@
from typing import Any, Callable, List, Optional, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from transformers import BloomForCausalLM, LlamaForCausalLM
from transformers.generation import GenerationConfig
@ -68,6 +69,13 @@ class TPInferEngine:
self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.cache_manager = None
self.max_dq_buffer_size = 1
self.max_inner_outer_dim = 1
self.gptq_temp_state_buffer = None
self.gptq_temp_dq_buffer = None
self.bits = -1
self.use_act_order = False
self.shard_config = shard_config
self.model = None
# optimize the original model by sharding with ShardFormer
@ -81,6 +89,50 @@ class TPInferEngine:
self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
)
def _post_init_gptq_buffer(self, model: nn.Module) -> None:
from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear
HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder
gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn('CUDA gptq is not installed')
HAS_GPTQ_CUDA = False
for name, submodule in model.named_modules():
if isinstance(submodule, CaiQuantLinear):
self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)
if self.use_act_order:
self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures,
submodule.outfeatures)
self.bits = submodule.bits
if not (HAS_GPTQ_CUDA and self.bits == 4):
return
max_input_len = 1
if self.use_act_order:
max_input_len = self.max_input_len
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim),
dtype=torch.float16,
device=torch.cuda.current_device())
self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size),
dtype=torch.float16,
device=torch.cuda.current_device())
gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer,
self.gptq_temp_dq_buffer)
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
torch.cuda.empty_cache()
def _optimize_model(self, model: nn.Module) -> None:
"""
Optimize the original model by sharding with ShardFormer.
@ -129,6 +181,10 @@ class TPInferEngine:
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
policy = get_autopolicy(model, inference_only=True)
self.model, _ = shardformer.optimize(model, policy)
if self.shard_config.inference_gptq:
self._post_init_gptq_buffer(model)
self.model = self.model.cuda()
@property

View File

@ -3,6 +3,9 @@ from functools import partial
import torch
from torch.nn import LayerNorm
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
from ..modeling.bloom import BloomInferenceForwards
@ -35,6 +38,35 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy):
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
policy = super().module_policy()
if self.shard_config.inference_gptq:
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 3}),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=RowCaiQuantLinear,
kwargs={'split_num': 1}),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 1}),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=RowCaiQuantLinear,
kwargs={'split_num': 1}),
])
# NOTE set inference mode to shard config
self.shard_config._infer()

View File

@ -3,6 +3,8 @@ from functools import partial
import torch
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
from colossalai.shardformer.layer import VocabParallelEmbedding1D
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
@ -34,6 +36,55 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
def module_policy(self):
policy = super().module_policy()
if self.shard_config.inference_gptq:
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
}
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=RowCaiQuantLinear,
kwargs={'split_num': 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=RowCaiQuantLinear,
kwargs={'split_num': 1},
)
],
)
self.shard_config._infer()
infer_forward = LlamaInferenceForwards.llama_model_forward

View File

@ -0,0 +1,63 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "column_remap.cuh"
#include "util.cuh"
const int SHUF_BLOCKSIZE_X = 256;
const int SHUF_BLOCKSIZE_Y = 16;
__global__ void column_remap_kernel
(
const half* __restrict__ x,
half* __restrict__ x_new,
const int x_width,
const int x_height,
const uint32_t* x_map
)
{
int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
if (x_column >= x_width) return;
//if (x_row >= x_height) return;
int x_stride = x_width;
int x_idx = x_row * x_stride + x_column;
int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
int x_idx_end = x_row_end * x_stride + x_column;
int s_column = x_map[x_column];
int s_idx = x_row * x_stride + s_column;
while (x_idx < x_idx_end)
{
x_new[x_idx] = x[s_idx];
x_idx += x_stride;
s_idx += x_stride;
}
}
// Remap columns in x to correspond to sequential group index before matmul
//
// perform x -> seq_x such that seq_x @ seq_w == x @ w
void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
)
{
dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
dim3 blocks
(
(x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
(x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
1
);
column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
}

View File

@ -0,0 +1,19 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _column_remap_cuh
#define _column_remap_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
);
#endif

View File

@ -0,0 +1,58 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _cuda_compat_cuh
#define _cuda_compat_cuh
// atomicAdd for half types, to support CC < 7.x
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
}
while (assumed != old);
}
// atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
}
while (assumed != old);
}
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif
#endif
#endif
#endif

View File

@ -0,0 +1,75 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#define _cuda_buffers_cu
#include "cuda_buffers.cuh"
CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
// __constant__ half2 q4_table[16][256];
// half2 q4_table_host[16][256];
// bool q4_table_init = false;
CudaBuffers::CudaBuffers
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
) :
device(_device),
temp_state_size(_temp_state_size),
temp_state(_temp_state),
temp_dq(_temp_dq)
{
cudaSetDevice(_device);
cudaStreamCreate(&alt_stream_1);
cudaStreamCreate(&alt_stream_2);
cudaStreamCreate(&alt_stream_3);
cudaEventCreate(&alt_stream_1_done);
cudaEventCreate(&alt_stream_2_done);
cudaEventCreate(&alt_stream_3_done);
}
CudaBuffers::~CudaBuffers()
{
cudaStreamDestroy(alt_stream_1);
cudaStreamDestroy(alt_stream_2);
cudaStreamDestroy(alt_stream_3);
cudaEventDestroy(alt_stream_1_done);
cudaEventDestroy(alt_stream_2_done);
cudaEventDestroy(alt_stream_3_done);
}
CudaBuffers* get_buffers(const int device_index)
{
return g_buffers[device_index];
}
void prepare_buffers_cuda
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
)
{
CudaBuffers* buffers = new CudaBuffers
(
_device,
_temp_state_size,
_temp_state,
_temp_dq
);
g_buffers[_device] = buffers;
}
void cleanup_buffers_cuda()
{
for (int i = 0; i < CUDA_MAX_DEVICES; i++)
{
if (!g_buffers[i]) continue;
delete g_buffers[i];
g_buffers[i] = NULL;
}
}

View File

@ -0,0 +1,55 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _cuda_buffers_cuh
#define _cuda_buffers_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
const int CUDA_MAX_DEVICES = 16;
// #ifndef _cuda_buffers_cu
// extern __constant__ half2 q4_table[16][256];
// #endif
class CudaBuffers
{
public:
int device;
half* temp_state; // [max_hidden_rows * intermediate_size]
int temp_state_size;
half* temp_dq; // size of largest quant tensor * 8
cudaStream_t alt_stream_1;
cudaStream_t alt_stream_2;
cudaStream_t alt_stream_3;
cudaEvent_t alt_stream_1_done;
cudaEvent_t alt_stream_2_done;
cudaEvent_t alt_stream_3_done;
CudaBuffers
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
);
~CudaBuffers();
};
CudaBuffers* get_buffers(const int device_index);
void prepare_buffers_cuda
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
);
void cleanup_buffers_cuda();
#endif

View File

@ -0,0 +1,49 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _hip_compat_cuh
#define _hip_compat_cuh
// Workaround for a bug in hipamd, backported from upstream.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
return __half_raw{
static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
}
#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp
// Workaround for hipify_python using rocblas instead of hipblas.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
const half* alpha,
const half* AP,
int lda,
const half* BP,
int ldb,
const half* beta,
half* CP,
int ldc) {
return hipblasHgemm(handle, transA, transB, m, n, k,
reinterpret_cast<const hipblasHalf *>(alpha),
reinterpret_cast<const hipblasHalf *>(AP), lda,
reinterpret_cast<const hipblasHalf *>(BP), ldb,
reinterpret_cast<const hipblasHalf *>(beta),
reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm
#endif

View File

@ -0,0 +1,254 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include "util.cuh"
#include "tuning.h"
#include "cuda_buffers.cuh"
#include "q4_matrix.cuh"
#include "q4_matmul.cuh"
#include "column_remap.cuh"
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
void check_cuda(cudaError_t ret)
{
switch (ret)
{
case cudaSuccess:
break;
case cudaUnspecified:
printf(" **** Unspecified error\n");
TORCH_CHECK(false, "CUDA error");
break;
default:
printf(" **** CUDA error\n"); \
printf(" **** %s\n", cudaGetErrorString(ret)); \
TORCH_CHECK(false, "CUDA error"); \
break;
}
}
// Some decluttering macros
#define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x)
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
TORCH_CHECK(__index >= 0, "no device index"); \
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
} while(0)
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
do { \
TORCH_CHECK_DTYPE(__w, kInt); \
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while(0)
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
{
int groupsize = w.size(0) * 8 / w_zeros.size(0);
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
return groupsize;
}
// Tuning parameters
ExLlamaTuning tuningParams;
void set_tuning_params
(
int matmul_recons_thd,
bool matmul_fused_remap,
bool matmul_no_half2
)
{
tuningParams.matmul_recons_thd = matmul_recons_thd;
tuningParams.matmul_fused_remap = matmul_fused_remap;
tuningParams.matmul_no_half2 = matmul_no_half2;
}
// Release all unmanaged objects allocated by the extension
void cleanup()
{
cleanup_buffers_cuda();
g_q4_free_matrices();
}
// Prepare buffers for forward pass
void prepare_buffers
(
torch::Device device,
torch::Tensor temp_state,
torch::Tensor temp_dq
)
{
int device_index = device.index();
TORCH_CHECK_DEVICE_INDEX(device_index);
const at::cuda::OptionalCUDAGuard device_guard(device);
prepare_buffers_cuda
(
device_index,
// buffer size used for sanity checks
temp_state.numel(),
(half*) temp_state.data_ptr(),
(half*) temp_dq.data_ptr()
);
}
// Create Q4Matrix, return handle
uintptr_t make_q4
(
torch::Tensor qweight,
torch::Tensor qzeros,
torch::Tensor scales,
torch::Tensor g_idx,
int device
)
{
TORCH_CHECK_DTYPE(qweight, kInt);
TORCH_CHECK_DTYPE(qzeros, kInt);
TORCH_CHECK_DTYPE(scales, kHalf);
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
int width = qweight.size(1);
int height = qweight.size(0) * 8;
int groups = qzeros.size(0);
Q4Matrix* m = new Q4Matrix
(
height,
width,
groups,
(uint32_t*) qweight.data_ptr(),
(uint32_t*) qzeros.data_ptr(),
(half*) scales.data_ptr(),
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
device
);
g_q4_keep_matrix(m);
return reinterpret_cast<uintptr_t> (m);
}
// Matmul half @ quant -> half
void q4_matmul
(
torch::Tensor x,
uintptr_t w,
torch::Tensor out
)
{
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(out, kHalf);
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
int x_height = x.size(0);
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
{
q4_matmul_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr()
);
}
else
{
q4_matmul_recons_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr(),
at::cuda::getCurrentCUDABlasHandle()
);
}
}
// Remap columns in half tensor
void column_remap
(
torch::Tensor x,
torch::Tensor x_new,
torch::Tensor x_map
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(x_new, kHalf);
TORCH_CHECK_DTYPE(x_map, kInt);
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
int height = x.size(0);
int width = x.size(1);
TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
column_remap_cuda
(
(half*) x.data_ptr(),
(half*) x_new.data_ptr(),
height,
width,
(uint32_t*) x_map.data_ptr()
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
m.def("cleanup", &cleanup, "cleanup");
m.def("make_q4", &make_q4, "make_q4");
m.def("q4_matmul", &q4_matmul, "q4_matmul");
}

View File

@ -0,0 +1,294 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _matrix_cuh
#define _matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
class MatrixView_half
{
public:
const half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
};
class MatrixView_half_rw
{
public:
half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
};
class MatrixView_q4_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
};
class MatrixView_q4_column
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f;
}
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};
// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
__device__ __forceinline__ half2 dot_product_8
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
half2 tmp = __hmul2(*h_ptr++, v_01);
tmp = __hfma2(*h_ptr++, v_23, tmp);
tmp = __hfma2(*h_ptr++, v_45, tmp);
tmp = __hfma2(*h_ptr++, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half* h_ptr = h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(*h_ptr++, v_0);
tmp = __hfma(*h_ptr++, v_1, tmp);
tmp = __hfma(*h_ptr++, v_2, tmp);
tmp = __hfma(*h_ptr++, v_3, tmp);
tmp = __hfma(*h_ptr++, v_4, tmp);
tmp = __hfma(*h_ptr++, v_5, tmp);
tmp = __hfma(*h_ptr++, v_6, tmp);
tmp = __hfma(*h_ptr++, v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
__device__ __forceinline__ half2 dot_product_8_x_map
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
half h_0 = h_ptr[*x_map_ptr++];
half h_1 = h_ptr[*x_map_ptr++];
half h_2 = h_ptr[*x_map_ptr++];
half h_3 = h_ptr[*x_map_ptr++];
half h_4 = h_ptr[*x_map_ptr++];
half h_5 = h_ptr[*x_map_ptr++];
half h_6 = h_ptr[*x_map_ptr++];
half h_7 = h_ptr[*x_map_ptr++];
half2 h_01 = __halves2half2(h_0, h_1);
half2 h_23 = __halves2half2(h_2, h_3);
half2 h_45 = __halves2half2(h_4, h_5);
half2 h_67 = __halves2half2(h_6, h_7);
half2 tmp = __hmul2(h_01, v_01);
tmp = __hfma2(h_23, v_23, tmp);
tmp = __hfma2(h_45, v_45, tmp);
tmp = __hfma2(h_67, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_x_map_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
#endif

View File

@ -0,0 +1,260 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "q4_matmul.cuh"
#include "column_remap.cuh"
#include "util.cuh"
#include "matrix.cuh"
#include "cu_compat.cuh"
#include "cuda_buffers.cuh"
#if defined(USE_ROCM)
#include "hip_compat.cuh"
#endif
const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
typedef void (*fp_q4_matmul_kernel)
(
const half*,
const uint32_t*,
half*,
const half*,
const uint32_t*,
const int,
const int,
const int,
const int,
const int,
const uint32_t*,
bool
);
template<bool use_half2, bool use_groupsize, bool use_x_map>
__global__ void q4_matmul_kernel
(
const half* __restrict__ x,
const uint32_t* __restrict__ w,
half* __restrict__ out,
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int height,
const int dim,
const int width,
const int groupsize,
const int block_size_z,
const uint32_t* __restrict__ x_map,
bool no_zero
)
{
// Start of block
int x_column = block_size_z * blockIdx.z;
int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
int iterations = (x_column_end - x_column) / 8;
// Views
MatrixView_half x_(x, height, dim);
MatrixView_half w_scales_(w_scales, dim / groupsize, width);
MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
MatrixView_q4_column w_(w, dim, width);
MatrixView_half_rw out_(out, height, width);
// Zero output
if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
{
*((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
__syncthreads();
}
// Loop over part of x row (and w column)
half2 acc = {};
half acc_h = {};
if constexpr (use_groupsize)
{
// For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
// could be slightly faster
for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
{
if constexpr (use_half2)
{
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
else
{
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
}
}
else
{
// Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
for (int k = x_column; k < x_column + iterations * 8; k += 8)
{
if constexpr (use_half2)
{
int group = k / groupsize;
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
}
else
{
int group = k / groupsize;
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
}
}
}
// Add to block result
if constexpr (use_half2)
{
half result = __hadd(__low2half(acc), __high2half(acc));
atomicAdd(out_.item_ptr(x_row, w_column), result);
}
else
{
atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
}
}
fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
{
// <bool use_half2, bool use_groupsize, bool use_x_map>
if (tuningParams->matmul_no_half2) {
if (block_size_z % groupsize == 0) {
if (x_map) return q4_matmul_kernel<false, true, true >;
else return q4_matmul_kernel<false, true, false>;
} else {
if (x_map) return q4_matmul_kernel<false, false, true >;
else return q4_matmul_kernel<false, false, false>;
}
} else {
if (block_size_z % groupsize == 0)
{
if (x_map) return q4_matmul_kernel<true, true, true >;
else return q4_matmul_kernel<true, true, false>;
} else {
if (x_map) return q4_matmul_kernel<true, false, true >;
else return q4_matmul_kernel<true, false, false>;
}
}
};
// Compute y = x @ w
void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
const Q4Matrix* w,
half* out,
bool no_zero,
cudaStream_t alt_stream
)
{
int height = x_height;
int dim = w->height;
int width = w->width;
cudaSetDevice(w->device);
uint32_t* x_map = w->cuda_x_map;
const half* x_mapped = x;
if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
{
CudaBuffers* buffers = get_buffers(w->device);
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
x_map = NULL;
}
int block_size_z;
if (w->width == 4096) block_size_z = 384; // 7B
else if (w->width == 11008) block_size_z = 256;
else if (w->width == 5120) block_size_z = 384; // 13B
else if (w->width == 13824) block_size_z = 256;
else if (w->width == 6656) block_size_z = 256; // 33B
else if (w->width == 17920) block_size_z = 128;
else block_size_z = 256;
//if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
dim3 threads(THREADS_X, THREADS_Y, 1);
dim3 blocks
(
(width + threads.x - 1) / threads.x,
(height + threads.y - 1) / threads.y,
(dim + block_size_z - 1) / block_size_z
);
fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
}
void q4_matmul_recons_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
Q4Matrix* w,
half* out,
const cublasHandle_t handle,
bool no_zero
)
{
int height = x_height;
int dim = w->height;
int width = w->width;
cudaSetDevice(w->device);
CudaBuffers* buffers = get_buffers(w->device);
const half* x_mapped = x;
if (w->cuda_x_map)
{
TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
}
w->reconstruct(buffers->temp_dq);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
const float alpha = 1.0f;
const float beta = no_zero ? 1.0f : 0.0f;
cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
#else
const half alpha = __float2half(1.0f);
const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
#endif
}

View File

@ -0,0 +1,43 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _q4_matmul_cuh
#define _q4_matmul_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>
#include "q4_matrix.cuh"
#include "tuning.h"
// Workaround for hipify_python using rocblas instead of hipblas.
#if defined(USE_ROCM)
#include <hipblas/hipblas.h>
#define rocblas_handle hipblasHandle_t
#endif
void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
const Q4Matrix* w,
half* out,
bool no_zero = false,
cudaStream_t alt_stream = NULL
);
void q4_matmul_recons_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
Q4Matrix* w,
half* out,
const cublasHandle_t handle,
bool no_zero = false
);
#endif

View File

@ -0,0 +1,225 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "q4_matrix.cuh"
#include <vector>
#include "util.cuh"
#include "matrix.cuh"
using namespace std;
const int UNSHUF_BLOCKSIZE_X = 64;
const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column
const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows
vector<Q4Matrix*> g_q4_matrices;
void g_q4_keep_matrix(Q4Matrix* m)
{
g_q4_matrices.push_back(m);
}
void g_q4_free_matrices()
{
for (const auto& m : g_q4_matrices) delete m;
g_q4_matrices.clear();
}
Q4Matrix::Q4Matrix
(
const int _height,
const int _width,
const int _groups,
uint32_t* _qweight,
uint32_t* _qzeros,
half* _scales,
uint32_t* _g_idx,
const int _device
) :
height(_height),
width(_width),
groups(_groups),
device(_device)
{
cudaSetDevice(device);
cuda_qweight = _qweight;
cuda_qzeros = _qzeros;
cuda_scales = _scales;
groupsize = height / groups;
if (_g_idx) make_sequential(_g_idx);
}
Q4Matrix::~Q4Matrix()
{
}
// Make sequential
__global__ void make_sequential_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const uint32_t* __restrict__ x_map,
const int w_height,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int x_map_idx = w_new2_row << 3;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
int source_row = x_map[x_map_idx++];
int w2_row = source_row >> 3;
int w2_subrow = source_row & 0x07;
int w2_row_shift = w2_subrow << 2;
int wnew2_row_shift = i << 2;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000f0000000f;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
{
uint32_t* cuda_new_qweight = NULL;
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
// Group histogram
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
// Group map
for (int i = 0, acc = 0; i < groups; i++)
{
short tmp = cpu_g_idx_map[i];
cpu_g_idx_map[i] = acc;
acc += tmp;
}
// X map (inverse)
for (int row = 0; row < height; row++)
{
uint32_t target_group = cpu_g_idx[row];
uint32_t target_row = cpu_g_idx_map[target_group];
cpu_g_idx_map[target_group]++;
cpu_x_map_inv[row] = target_row;
}
// X map
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
// Move to CUDA
cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
// Rearrange rows in w
dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
dim3 blocks
(
(width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
height / 8,
1
);
make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
// Replace qweights
cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
// Cleanup
cudaDeviceSynchronize();
cudaFree(cuda_new_qweight);
free(cpu_g_idx_map);
free(cpu_x_map);
free(cpu_x_map_inv);
}
__global__ void reconstruct_kernel
(
const uint32_t* __restrict__ w,
half* __restrict__ out, // (y)
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int height,
const int width,
const int groupsize
)
{
// Start of block
int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
if (column >= width) return;
// Views
MatrixView_q4_column w_(w, height, width);
MatrixView_half_rw out_(out, height, width);
MatrixView_half w_scales_(w_scales, height / groupsize, width);
MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
// Groupsize version
int group = row / groupsize;
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
uint32_t w_read = w_.item_uint32_t(row, column);
half* out_ptr = out_.item_ptr(row, column);
#pragma unroll
for (int s = 0; s < 32; s += 4)
{
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
*out_ptr = w_item; out_ptr += out_.width;
}
}
void Q4Matrix::reconstruct(half* out)
{
dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
dim3 blocks
(
(width + threads.x - 1) / threads.x,
(height / 8 + threads.y - 1) / threads.y,
1
);
reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
}

View File

@ -0,0 +1,53 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _q4_matrix_cuh
#define _q4_matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
class Q4Matrix
{
public:
int device;
int height;
int width;
int groups;
int groupsize;
uint32_t* cuda_qweight = NULL;
uint32_t* cuda_qzeros = NULL;
half* cuda_scales = NULL;
uint32_t* cuda_x_map = NULL;
Q4Matrix
(
const int _height,
const int _width,
const int _groups,
uint32_t* _qweight,
uint32_t* _qzeros,
half* _scales,
uint32_t* _g_idx,
const int _device
);
~Q4Matrix();
void reconstruct(half* out);
private:
void make_sequential(const uint32_t* cpu_g_idx);
};
void g_q4_keep_matrix(Q4Matrix* m);
void g_q4_free_matrices();
#endif

View File

@ -0,0 +1,13 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _tuning_h
#define _tuning_h
struct ExLlamaTuning
{
int matmul_recons_thd;
bool matmul_fused_remap;
bool matmul_no_half2;
};
#endif

View File

@ -0,0 +1,33 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _util_cuh
#define _util_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#if defined(USE_ROCM)
#define cudaUnspecified hipErrorUnknown
#else
#define cudaUnspecified cudaErrorApiFailureBase
#endif
// React to failure on return code != cudaSuccess
#define _cuda_check(fn) \
do { \
{_cuda_err = fn;} \
if (_cuda_err != cudaSuccess) goto _cuda_fail; \
} while(false)
// React to failure on return code == 0
#define _alloc_check(fn) \
do { \
if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
else _cuda_err = cudaSuccess; \
} while(false)
#endif

View File

@ -6,6 +6,7 @@ try:
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm
from .gptq_triton import gptq_fused_linear_triton
from .rms_norm import rmsnorm_forward
from .rotary_embedding_kernel import rotary_embedding_fwd
from .softmax import softmax
@ -20,6 +21,7 @@ try:
"copy_kv_cache_to_dest",
"rotary_embedding_fwd",
"token_attention_fwd",
"gptq_fused_linear_triton",
]
except ImportError:

View File

@ -0,0 +1,541 @@
# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ
import torch
import triton
import triton.language as tl
from auto_gptq.nn_modules.triton_utils import custom_autotune
@triton.jit
def tanh(x):
# Tanh is just a scaled sigmoid
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def cosh(x):
exp_x = tl.exp(x)
return (exp_x + 1.0 / exp_x) * 0.5
# a Triton implementation of the most used activations
# See for instance http://arxiv.org/abs/1606.08415 for an overview
# ReLU
@triton.jit
def relu(x):
"""
ReLU_ activation function
.. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
"""
return tl.where(x >= 0, x, 0.0)
@triton.jit
def squared_relu(x):
"""
Squared ReLU activation, as proposed in the Primer_ paper.
.. _Primer: https://arxiv.org/abs/2109.08668
"""
x_sq = x * x
return tl.where(x > 0.0, x_sq, 0.0)
@triton.jit
def star_relu(x):
"""
Star ReLU activation, as proposed in the "MetaFormer Baselines for Vision"_ paper.
.. _ "MetaFormer Baselines for Vision": https://arxiv.org/pdf/2210.13452.pdf
"""
x_sq = x * x
return 0.8944 * tl.where(x > 0.0, x_sq, 0.0) - 0.4472
# Leaky ReLU
@triton.jit
def leaky_relu(x):
"""
LeakyReLU_ activation
.. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
"""
return tl.where(x >= 0.0, x, 0.01 * x)
@triton.jit
def gelu(x):
"""
GeLU_ activation - Gaussian error linear unit
.. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
"""
return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x)))
@triton.jit
def smelu(x):
"""
SmeLU_ activation - Smooth ReLU with beta=2.0
.. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf
"""
beta = 2.0
relu = tl.where(x >= beta, x, 0.0)
return tl.where(tl.abs(x) <= beta, (x + beta) * (x + beta) / (4.0 * beta), relu)
@triton.jit
def silu(x):
return x * tl.sigmoid(x)
@custom_autotune.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4
),
],
key=["M", "N", "K"],
nearest_power_of_two=True,
prune_configs_by={
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
"perf_model": None,
"top_k": None,
},
)
@triton.jit
def cai_gptq_matmul_248_kernel(
a_ptr,
b_ptr,
c_ptr,
scales_ptr,
zeros_ptr,
bias_ptr,
residual_ptr,
M,
N,
K,
bits,
maxq,
gptq_group_size,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_scales,
stride_zeros,
QKV_FUSED: tl.constexpr,
ADD_BIAS: tl.constexpr,
ADD_RESIDUAL: tl.constexpr,
ACT_TYPE: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16
B is of shape (K//8, N) int32
C is of shape (M, N) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
NK = K
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K)
qkv_offset = pid // (num_pid_m * num_pid_n)
pid = pid % (num_pid_m * num_pid_n)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
# offs_bk = offs_k + qkv_offset * NK
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = offs_am[:, None] < M
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = (
b_ptr
+ qkv_offset * N * NK // infearure_per_bits
+ ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
# g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :]
zeros_ptrs = (
zeros_ptr
+ qkv_offset * NK * N // gptq_group_size // infearure_per_bits
+ (offs_bn[None, :] // infearure_per_bits)
)
shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
g_idx_base = tl.arange(0, BLOCK_SIZE_K)
g_idx_base = g_idx_base // gptq_group_size
g_idx = g_idx_base
# tl.device_print("gidx, ", g_idx)
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = zeros + 1
for k in range(0, num_pid_k):
# g_idx = tl.load(g_ptrs)
# if (k + 1) * BLOCK_SIZE_K > currend_group_end:
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = zeros + 1
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros).to(tl.float16) * scales # Scale and shift
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g_idx = g_idx_base + ((k + 1) * BLOCK_SIZE_K) // gptq_group_size
# if (k + 2) * BLOCK_SIZE_K > currend_group_end:
c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
if ADD_BIAS:
bias_mask = offs_bn < N
offs_bn += qkv_offset * N
bias_ptrs = bias_ptr + stride_cn * offs_bn
bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
accumulator += bias[None, :]
if ACT_TYPE == 1:
accumulator = relu(accumulator)
elif ACT_TYPE == 2:
accumulator = gelu(accumulator)
elif ACT_TYPE == 3:
accumulator = silu(accumulator)
if ADD_RESIDUAL:
residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
res = tl.load(residual_ptrs, mask=c_mask, other=0.0)
accumulator += res
tl.store(c_ptrs, accumulator, mask=c_mask)
@custom_autotune.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4
),
],
key=["M", "N", "K"],
nearest_power_of_two=True,
prune_configs_by={
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
"perf_model": None,
"top_k": None,
},
)
@triton.jit
def cai_gptq_idx_matmul_248_kernel(
a_ptr,
b_ptr,
c_ptr,
scales_ptr,
zeros_ptr,
idx_ptr,
bias_ptr,
residual_ptr,
M,
N,
K,
bits,
maxq,
gptq_group_size,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_scales,
stride_zeros,
QKV_FUSED: tl.constexpr,
ADD_BIAS: tl.constexpr,
ADD_RESIDUAL: tl.constexpr,
ACT_TYPE: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16
B is of shape (K//8, N) int32
C is of shape (M, N) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
NK = K
# if QKV_FUSED:
# NK = K//3
# else:
# NK = K
# NK = K
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K)
qkv_offset = pid // (num_pid_m * num_pid_n)
pid = pid % (num_pid_m * num_pid_n)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
# offs_bk = offs_k + qkv_offset * NK
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = offs_am[:, None] < M
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = (
b_ptr
+ qkv_offset * N * NK // infearure_per_bits
+ ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
# g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :]
zeros_ptrs = (
zeros_ptr
+ qkv_offset * NK * N // gptq_group_size // infearure_per_bits
+ (offs_bn[None, :] // infearure_per_bits)
)
shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
g_ptrs = idx_ptr + offs_k
g_idx = tl.load(g_ptrs)
# tl.device_print("gidx, ", g_idx)
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
for k in range(0, num_pid_k):
g_idx = tl.load(g_ptrs)
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = zeros + 1
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros).to(tl.float16) * scales # Scale and shift
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g_ptrs += BLOCK_SIZE_K
c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
if ADD_BIAS:
bias_mask = offs_bn < N
offs_bn += qkv_offset * N
bias_ptrs = bias_ptr + stride_cn * offs_bn
bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
accumulator += bias[None, :]
if ACT_TYPE == 1:
accumulator = relu(accumulator)
elif ACT_TYPE == 2:
accumulator = gelu(accumulator)
elif ACT_TYPE == 3:
accumulator = silu(accumulator)
if ADD_RESIDUAL:
residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
res = tl.load(residual_ptrs, mask=c_mask, other=0.0)
accumulator += res
tl.store(c_ptrs, accumulator, mask=c_mask)
def gptq_fused_linear_triton(
input,
qweight,
scales,
qzeros,
bias,
residual,
bits,
maxq,
gptq_group_size,
qkv_fused,
add_bias,
add_residual,
g_idx=None,
act_type=0,
):
# print("gptq fused ", qkv_fused, add_bias, add_residual)
assert input.is_cuda, "input is not in cuda"
assert qweight.is_cuda, "qweight is not in cuda"
assert scales.is_cuda, "scales is not in cuda"
assert qzeros.is_cuda, "qzeros is not in cuda"
with torch.cuda.device(input.device):
if qkv_fused:
grid = lambda META: (
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"])
* 3,
)
output = torch.empty((input.shape[0] * 3, qweight.shape[1]), device=input.device, dtype=torch.float16)
else:
grid = lambda META: (
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
)
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
# print("dtype, ", qweight.dtype, output.dtype, scales.dtype, qzeros.dtype, bias.dtype, residual.dtype)
if g_idx is None:
cai_gptq_matmul_248_kernel[grid](
input,
qweight,
output,
scales,
qzeros,
bias,
residual,
input.shape[0],
qweight.shape[1],
input.shape[1],
bits,
maxq,
gptq_group_size,
input.stride(0),
input.stride(1),
qweight.stride(0),
qweight.stride(1),
output.stride(0),
output.stride(1),
scales.stride(0),
qzeros.stride(0),
QKV_FUSED=qkv_fused,
ADD_BIAS=add_bias,
ADD_RESIDUAL=add_residual,
ACT_TYPE=act_type,
)
else:
cai_gptq_idx_matmul_248_kernel[grid](
input,
qweight,
output,
scales,
qzeros,
g_idx,
bias,
residual,
input.shape[0],
qweight.shape[1],
input.shape[1],
bits,
maxq,
gptq_group_size,
input.stride(0),
input.stride(1),
qweight.stride(0),
qweight.stride(1),
output.stride(0),
output.stride(1),
scales.stride(0),
qzeros.stride(0),
QKV_FUSED=qkv_fused,
ADD_BIAS=add_bias,
ADD_RESIDUAL=add_residual,
ACT_TYPE=act_type,
)
if qkv_fused:
return output.view(3, input.shape[0], qweight.shape[1])
else:
return output

View File

@ -32,10 +32,13 @@ class ShardConfig:
enable_fused_normalization: bool = False
enable_flash_attention: bool = False
enable_jit_fused: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
enable_all_optimization: bool = False
inference_only: bool = False
inference_gptq: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@property

View File

@ -0,0 +1,123 @@
import argparse
import logging
import os
import time
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from auto_gptq.nn_modules.qlinear import GeneralQuantLinear
from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM, LlamaTokenizer
import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
def print_perf_stats(latency_set, config, bs, warmup=3):
# trim warmup queries
latency_set = list(latency_set)
latency_set = latency_set[warmup:]
count = len(latency_set)
if count > 0:
latency_set.sort()
avg = sum(latency_set) / count
num_layers = getattr(config, "num_layers", config.num_hidden_layers)
num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
num_bytes = 2 # float16
print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12))
print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs))
def bench_bloom(args):
pretrained_model_dir = args.path
quantized_model_dir = args.quantized_path
max_batch_size = args.batch_size
max_input_len = args.input_len
max_output_len = args.output_len
tokenizer = BloomTokenizerFast.from_pretrained(pretrained_model_dir)
tokenizer.pad_token = tokenizer.eos_token
# load quantized model to the first GPU
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir,
device=torch.cuda.current_device(),
inject_fused_attention=False)
model = model.half()
model_config = model.config
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
input_tokens = {
"input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'),
"attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda')
}
# init TPInferEngine and shard the original model
# To benchmark torch original, comment out the line of optimizing model
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False,
inference_only=True,
inference_gptq=True)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
# prepare data for generation
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
input_tokens = {
"input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)),
"attention_mask": torch.ones((max_batch_size, max_input_len))
}
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
# print(f" input_tokens[{t}].shape: {input_tokens[t].shape}")
iters = 10
times = []
for i in range(iters):
torch.cuda.synchronize()
start = time.time()
outputs = infer_engine.generate(input_tokens, **generate_kwargs)
torch.cuda.synchronize()
end = time.time()
out_len = outputs.shape[1]
print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s")
times.append((end - start) / (out_len - max_input_len))
print_perf_stats(times, model_config, max_batch_size)
def check_bloom(rank, world_size, port, args):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
bench_bloom(args)
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom(args):
spawn(check_bloom, args.tp_size, args=args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True)
parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
args = parser.parse_args()
test_bloom(args)

View File

@ -0,0 +1,135 @@
import argparse
import logging
import os
import time
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from auto_gptq.nn_modules.qlinear import GeneralQuantLinear
from torch import distributed as dist
from torch.profiler import ProfilerActivity, profile, record_function
from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, TextGenerationPipeline
import colossalai
from colossalai.gptq import CaiQuantLinear
from colossalai.gptq.gptq_tp import replace_autogptq_linear
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
def init_to_get_rotary(self, base=10000):
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
if not hasattr(self.config, "rope_scaling"):
rope_scaling_factor = 1.0
else:
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
if hasattr(self.config, "max_sequence_length"):
max_seq_len = self.config.max_sequence_length
elif hasattr(self.config, "max_position_embeddings"):
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
else:
max_seq_len = 2048 * rope_scaling_factor
base = float(base)
inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) /
self.config.head_dim_))
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
return
def print_perf_stats(latency_set, config, bs, warmup=3):
# trim warmup queries
latency_set = list(latency_set)
latency_set = latency_set[warmup:]
count = len(latency_set)
if count > 0:
latency_set.sort()
avg = sum(latency_set) / count
num_layers = getattr(config, "num_layers", config.num_hidden_layers)
num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
num_bytes = 2
print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12))
print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs))
def run_llama_test(args):
pretrained_model_dir = args.path
quantized_model_dir = args.quantized_path
max_batch_size = args.batch_size
max_input_len = args.input_len
max_output_len = args.output_len
tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
# load quantized model to the first GPU
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir,
device=torch.cuda.current_device(),
inject_fused_attention=False)
init_to_get_rotary(model.model.model, base=10000)
model_config = model.config
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False,
inference_only=True,
inference_gptq=True)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
input_tokens = {
"input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'),
"attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda')
}
iters = 10
times = []
for i in range(iters):
torch.cuda.synchronize()
start = time.time()
outputs = infer_engine.generate(input_tokens, **generate_kwargs)
torch.cuda.synchronize()
end = time.time()
out_len = outputs.shape[1]
print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s")
times.append((end - start) / (out_len - max_input_len))
print_perf_stats(times, model_config, max_batch_size)
def check_llama(rank, world_size, port, args):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_llama_test(args)
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama(args):
spawn(check_llama, args.tp_size, args=args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True)
parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
args = parser.parse_args()
test_llama(args)

52
op_builder/gptq.py Normal file
View File

@ -0,0 +1,52 @@
import os
import torch
import re
from .builder import Builder
from .utils import append_nvcc_threads, get_cuda_cc_flag
class GPTQBuilder(Builder):
NAME = "cu_gptq"
PREBUILT_IMPORT_PATH = "colossalai._C.cu_gptq"
def __init__(self):
super().__init__(name=GPTQBuilder.NAME,
prebuilt_import_path=GPTQBuilder.PREBUILT_IMPORT_PATH)
def include_dirs(self):
ret = [self.csrc_abs_path("gptq"), self.get_cuda_home_include()]
return ret
def sources_files(self):
ret = [
self.csrc_abs_path(fname) for fname in [
'gptq/linear_gptq.cpp',
'gptq/column_remap.cu',
'gptq/cuda_buffers.cu',
'gptq/q4_matmul.cu',
'gptq/q4_matrix.cu'
]
]
return ret
def cxx_flags(self):
return ['-O3'] + self.version_dependent_macros
def nvcc_flags(self):
extra_cuda_flags = ['-v',
'-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK', "-lcublas", "-std=c++17"
]
for arch in torch.cuda.get_arch_list():
res = re.search(r'sm_(\d+)', arch)
if res:
arch_cap = res[1]
if int(arch_cap) >= 80:
extra_cuda_flags.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags
return append_nvcc_threads(ret)

View File

@ -18,3 +18,4 @@ SentencePiece
ninja
flash_attn==2.0.5
datasets
#auto-gptq now not support torch1.12

View File

@ -0,0 +1,150 @@
import math
import time
import numpy as np
import pytest
import torch
import torch.nn as nn
import transformers
from packaging import version
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
try:
from auto_gptq.modeling._utils import autogptq_post_init
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
from exllama_kernels import prepare_buffers, set_tuning_params
from colossalai.inference.quant.gptq import CaiQuantLinear
HAS_AUTO_GPTQ = True
except:
HAS_AUTO_GPTQ = False
print("please install AutoGPTQ from https://github.com/PanQiWei/AutoGPTQ")
import warnings
HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder
gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn('CUDA gptq is not installed')
HAS_GPTQ_CUDA = False
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
max_inner_outer_dim = 1
max_input_len = 1
max_dq_buffer_size = 1
gptq_temp_dq_buffer = None
gptq_temp_state_buffer = None
def init_buffer(cai_linear, use_act_order=False):
global max_dq_buffer_size
global max_input_len
global max_dq_buffer_size
global max_inner_outer_dim
global gptq_temp_dq_buffer
global gptq_temp_state_buffer
max_dq_buffer_size = max(max_dq_buffer_size, cai_linear.qweight.numel() * 8)
if use_act_order:
max_inner_outer_dim = max(max_inner_outer_dim, cai_linear.infeatures, cai_linear.outfeatures)
if use_act_order:
max_input_len = 4096
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim),
dtype=torch.float16,
device=torch.cuda.current_device())
gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device())
gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer)
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq")
def test_gptq_linear():
infeature = 1024
outfeature = 1024
group_size = 128
wbits = 4
inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device())
batch_inps = torch.randn(1, 16, infeature).to(torch.float16).to(torch.cuda.current_device())
device = torch.device("cuda:0")
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=wbits)
linear = linear_class(
bits=4,
group_size=group_size,
infeatures=infeature,
outfeatures=outfeature,
bias=False,
)
torch.manual_seed(42)
linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
linear.scales = linear.scales + 0.002
linear = linear.to(device)
cai_linear = CaiQuantLinear(wbits, group_size, infeature, outfeature, True)
cai_linear.qweight.data.copy_(linear.qweight)
cai_linear.scales = cai_linear.scales + 0.002
cai_linear = cai_linear.to(device)
linear = autogptq_post_init(linear, use_act_order=False)
max_inner_outer_dim = max(infeature, outfeature)
max_dq_buffer_size = linear.infeatures * linear.outfeatures
max_input_len = 2048
buffers = {
"temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
"temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
}
prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
with torch.no_grad():
gptq_out = linear(inps)
batch_gptq_out = linear(batch_inps)
torch.cuda.synchronize()
cai_out = cai_linear(inps)
torch.cuda.synchronize()
batch_cai_out = cai_linear(batch_inps)
torch.cuda.synchronize()
assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-01)
assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-01)
if __name__ == "__main__":
test_gptq_linear()