[inference] Add smoothquant for llama (#4904)

* [inference] add int8 rotary embedding kernel for smoothquant (#4843)

* [inference] add smoothquant llama attention (#4850)

* add smoothquant llama attention

* remove useless code

* remove useless code

* fix import error

* rename file name

* [inference] add silu linear fusion for smoothquant llama mlp (#4853)

* add silu linear

* update skip condition

* catch smoothquant cuda lib exception

* process exception for tests

* [inference] add llama mlp for smoothquant (#4854)

* add llama mlp for smoothquant

* fix down out scale

* remove duplicate lines

* add llama mlp check

* delete useless code

* [inference] add smoothquant llama (#4861)

* add smoothquant llama

* fix attention accuracy

* fix accuracy

* add kv cache and save pretrained

* refactor example

* delete smooth

* refactor code

* [inference] add smooth function and delete useless code for smoothquant (#4895)

* add smooth function and delete useless code

* update datasets

* remove duplicate import

* delete useless file

* refactor codes (#4902)

* refactor code

* add license

* add torch-int and smoothquant license
Xu Kai 2023-10-16 11:28:44 +08:00 committed by GitHub
parent a0684e7bd6
commit 611a5a80ca
18 changed files with 2962 additions and 0 deletions
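
For orientation, the sketch below shows how the API added in this PR is intended to fit together (calibrate, quantize, generate). The import path, model path, and calibration texts are illustrative assumptions; the actual signatures are in the SmoothLlamaForCausalLM diff further down.

# Hypothetical usage sketch of the smoothquant llama API added in this PR.
# The import path, model path, and calibration texts are placeholders, not part of the diff.
from transformers import AutoTokenizer

from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM  # assumed module path

model_path = "/path/to/llama"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path)
calib_texts = ["a few sentences of calibration text", "covering typical inputs"]  # placeholder data

# Load the fp16 model, collect activation statistics, smooth, and pack int8 weights.
model = SmoothLlamaForCausalLM.from_pretrained(model_path)
model.model.cuda()  # calibration runs forward passes on the model's device
model.quantized(tokenizer, calib_texts, num_samples=len(calib_texts), seq_len=512, alpha=0.5)

# Generate with the int8 kernels and the int8 KV-cache manager.
input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids.cuda()
output = model.generate(input_ids=input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0]))

# A quantized checkpoint can also be saved and reloaded later:
# model.save_quantized("llama-smoothquant", model_basename="llama-7b")
# model = SmoothLlamaForCausalLM.from_quantized("llama-smoothquant", model_basename="llama-7b")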

LICENSE

@@ -477,3 +477,53 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
---------------- LICENSE FOR torch-int ----------------
MIT License
Copyright (c) 2022 Guangxuan Xiao
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
---------------- LICENSE FOR smoothquant ----------------
MIT License
Copyright (c) 2022 MIT HAN Lab
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -0,0 +1,12 @@
try:
import torch_int
HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
raise ImportError(
"Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
)
if HAS_TORCH_INT:
from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP


@@ -0,0 +1,482 @@
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ
# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
import os
import warnings
from abc import abstractmethod
from functools import partial
from os.path import isdir, isfile, join
from typing import Dict, List, Optional, Union
import accelerate
import numpy as np
import torch
import torch.nn as nn
import transformers
from safetensors.torch import save_file as safe_save
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_utils import no_init_weights
from transformers.utils.generic import ContextManagers
from transformers.utils.hub import PushToHubMixin, cached_file
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager
SUPPORTED_MODELS = ["llama"]
class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
layer_type: str = None
def __init__(self, model: PreTrainedModel, quantized: bool = False):
super().__init__()
self.model = model
self.model_type = self.model.config.model_type
self._quantized = quantized
self.config = self.model.config
self.cache_manager = None
self.max_total_token_num = 0
@property
def quantized(self):
return self._quantized
def init_cache_manager(self, max_total_token_num=2048):
if self.config.model_type == "llama":
head_num = self.config.num_key_value_heads
layer_num = self.config.num_hidden_layers
head_dim = self.config.hidden_size // head_num
self.cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
self.max_total_token_num = max_total_token_num
def init_batch_state(self, max_output_len=256, **kwargs):
input_ids = kwargs["input_ids"]
batch_size = len(input_ids)
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
start_index = 0
max_len_in_batch = -1
for i in range(batch_size):
seq_len = len(input_ids[i])
seq_lengths[i] = seq_len
seq_start_indexes[i] = start_index
start_index += seq_len
max_len_in_batch = seq_len if seq_len > max_len_in_batch else max_len_in_batch
if "max_total_token_num" in kwargs.keys():
max_total_token_num = kwargs["max_total_token_num"]
self.init_cache_manager(max_total_token_num)
if "max_new_tokens" in kwargs.keys():
max_output_len = kwargs["max_new_tokens"]
if batch_size * (max_len_in_batch + max_output_len) > self.max_total_token_num:
max_total_token_num = batch_size * (max_len_in_batch + max_output_len)
warnings.warn(f"reset max tokens to {max_total_token_num}")
self.init_cache_manager(max_total_token_num)
block_loc = torch.empty((batch_size, max_len_in_batch + max_output_len), dtype=torch.long, device="cuda")
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
batch_infer_state.seq_len = seq_lengths.to("cuda")
batch_infer_state.start_loc = seq_start_indexes.to("cuda")
batch_infer_state.block_loc = block_loc
batch_infer_state.decode_layer_id = 0
batch_infer_state.past_key_values_len = 0
batch_infer_state.is_context_stage = True
batch_infer_state.set_cache_manager(self.cache_manager)
batch_infer_state.cache_manager.free_all()
return batch_infer_state
@abstractmethod
@torch.inference_mode()
def quantize(
self,
examples: List[Dict[str, Union[List[int], torch.LongTensor]]],
):
if self.quantized:
raise EnvironmentError("can't execute quantize because the model is quantized.")
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
def generate(self, **kwargs):
"""shortcut for model.generate"""
batch_infer_state = self.init_batch_state(**kwargs)
if self.config.model_type == "llama":
setattr(self.model.model, "infer_state", batch_infer_state)
with torch.inference_mode():
return self.model.generate(**kwargs)
def prepare_inputs_for_generation(self, *args, **kwargs):
"""shortcut for model.prepare_inputs_for_generation"""
return self.model.prepare_inputs_for_generation(*args, **kwargs)
def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512):
for text in tqdm(dataset):
input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
model(input_ids)
def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512):
pbar = tqdm(dataset)
for text in pbar:
input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
model(input_ids)
mean_scale = np.mean([v["input"] for v in act_dict.values()])
pbar.set_description(f"Mean input scale: {mean_scale:.2f}")
def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
model.eval()
device = next(model.parameters()).device
act_scales = {}
def stat_tensor(name, tensor):
hidden_dim = tensor.shape[-1]
tensor = tensor.view(-1, hidden_dim).abs().detach()
coming_max = torch.max(tensor, dim=0)[0].float().cpu()
if name in act_scales:
act_scales[name] = torch.max(act_scales[name], coming_max)
else:
act_scales[name] = comming_max
def stat_input_hook(m, x, y, name):
if isinstance(x, tuple):
x = x[0]
stat_tensor(name, x)
hooks = []
for name, m in model.named_modules():
if isinstance(m, nn.Linear):
hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name)))
self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len)
for h in hooks:
h.remove()
return act_scales
@torch.no_grad()
def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
if not isinstance(fcs, list):
fcs = [fcs]
for fc in fcs:
assert isinstance(fc, nn.Linear)
assert ln.weight.numel() == fc.in_features == act_scales.numel()
device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
act_scales = act_scales.to(device=device, dtype=dtype)
weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
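# SmoothQuant migration: per-channel scale s_j = max|X_j|^alpha / max|W_j|^(1 - alpha);
# dividing it out of the preceding norm weight and multiplying it into the following
# linear weights keeps the product unchanged while flattening activation outliers.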
scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)
ln.weight.div_(scales)
if hasattr(ln, "bias"):
ln.bias.div_(scales)
for fc in fcs:
fc.weight.mul_(scales.view(1, -1))
@classmethod
def create_quantized_model(model):
raise NotImplementedError("Not implement create_quantized_model method")
def save_quantized(
self,
save_dir: str,
model_basename: str,
use_safetensors: bool = False,
safetensors_metadata: Optional[Dict[str, str]] = None,
):
"""save quantized model and configs to local disk"""
os.makedirs(save_dir, exist_ok=True)
if not self.quantized:
raise EnvironmentError("can only save quantized model, please execute .quantize first.")
self.model.to("cpu")
model_base_name = model_basename # or f"smooth-"
if use_safetensors:
model_save_name = model_base_name + ".safetensors"
state_dict = self.model.state_dict()
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
if safetensors_metadata is None:
safetensors_metadata = {}
elif not isinstance(safetensors_metadata, dict):
raise TypeError("safetensors_metadata must be a dictionary.")
else:
print(f"Received safetensors_metadata: {safetensors_metadata}")
new_safetensors_metadata = {}
converted_keys = False
for key, value in safetensors_metadata.items():
if not isinstance(key, str) or not isinstance(value, str):
converted_keys = True
try:
new_key = str(key)
new_value = str(value)
except Exception as e:
raise TypeError(
f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}"
)
if new_key in new_safetensors_metadata:
print(
f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
)
new_safetensors_metadata[new_key] = new_value
safetensors_metadata = new_safetensors_metadata
if converted_keys:
print(
f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
)
# Format is required to enable Accelerate to load the metadata
# otherwise it raises an OSError
safetensors_metadata["format"] = "pt"
safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
else:
model_save_name = model_base_name + ".bin"
torch.save(self.model.state_dict(), join(save_dir, model_save_name))
self.model.config.save_pretrained(save_dir)
def save_pretrained(
self,
save_dir: str,
use_safetensors: bool = False,
safetensors_metadata: Optional[Dict[str, str]] = None,
**kwargs,
):
"""alias of save_quantized"""
warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.")
# pass the basename through explicitly so use_safetensors is not misread as model_basename;
# the "model" default here is an editorial assumption
self.save_quantized(save_dir, kwargs.pop("model_basename", "model"), use_safetensors, safetensors_metadata)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
max_memory: Optional[dict] = None,
trust_remote_code: bool = False,
torch_dtype: torch.dtype = torch.float16,
**model_init_kwargs,
):
if not torch.cuda.is_available():
raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.")
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
# Parameters related to loading from Hugging Face Hub
cache_dir = model_init_kwargs.pop("cache_dir", None)
force_download = model_init_kwargs.pop("force_download", False)
resume_download = model_init_kwargs.pop("resume_download", False)
proxies = model_init_kwargs.pop("proxies", None)
local_files_only = model_init_kwargs.pop("local_files_only", False)
use_auth_token = model_init_kwargs.pop("use_auth_token", None)
revision = model_init_kwargs.pop("revision", None)
subfolder = model_init_kwargs.pop("subfolder", "")
model_init_kwargs.pop("_commit_hash", None)
cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"use_auth_token": use_auth_token,
"revision": revision,
"subfolder": subfolder,
}
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs)
if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")
# enforce some values despite user specified
model_init_kwargs["torch_dtype"] = torch_dtype
model_init_kwargs["trust_remote_code"] = trust_remote_code
if max_memory:
if "disk" in max_memory:
raise NotImplementedError("disk offload not support yet.")
with accelerate.init_empty_weights():
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
model.tie_weights()
max_memory = accelerate.utils.get_balanced_memory(
model,
max_memory=max_memory,
no_split_module_classes=[cls.layer_type],
dtype=model_init_kwargs["torch_dtype"],
low_zero=False,
)
model_init_kwargs["device_map"] = accelerate.infer_auto_device_map(
model,
max_memory=max_memory,
no_split_module_classes=[cls.layer_type],
dtype=model_init_kwargs["torch_dtype"],
)
model_init_kwargs["low_cpu_mem_usage"] = True
del model
else:
model_init_kwargs["device_map"] = None
model_init_kwargs["low_cpu_mem_usage"] = False
torch.cuda.empty_cache()
merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)
model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
if any([k in model_config for k in seq_len_keys]):
for key in seq_len_keys:
if key in model_config:
model.seqlen = model_config[key]
break
else:
warnings.warn("can't get model's sequence length from model config, will set to 4096.")
model.seqlen = 4096
model.eval()
return cls(model, False)
@classmethod
def from_quantized(
cls,
model_name_or_path: Optional[str],
model_basename: Optional[str] = None,
device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
max_memory: Optional[dict] = None,
device: Optional[Union[str, int]] = None,
low_cpu_mem_usage: bool = False,
torch_dtype: Optional[torch.dtype] = None,
use_safetensors: bool = False,
trust_remote_code: bool = False,
**kwargs,
):
"""load quantized model from local disk"""
# Parameters related to loading from Hugging Face Hub
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"use_auth_token": use_auth_token,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
# == step1: prepare configs and file names == #
config = AutoConfig.from_pretrained(
model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs
)
if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")
extensions = []
if use_safetensors:
extensions.append(".safetensors")
else:
extensions += [".bin", ".pt"]
model_name_or_path = str(model_name_or_path)
is_local = isdir(model_name_or_path)
resolved_archive_file = None
if is_local:
model_save_name = join(model_name_or_path, model_basename)
for ext in extensions:
if isfile(model_save_name + ext):
resolved_archive_file = model_save_name + ext
break
else: # remote
for ext in extensions:
resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs)
if resolved_archive_file is not None:
break
if resolved_archive_file is None: # Could not find a model file to use
raise FileNotFoundError(f"Could not find model in {model_name_or_path}")
model_save_name = resolved_archive_file
# == step2: convert model to quantized-model (replace Linear) == #
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
transformers.modeling_utils._init_weights = False
init_contexts = [no_init_weights()]
if low_cpu_mem_usage:
init_contexts.append(accelerate.init_empty_weights(include_buffers=True))
with ContextManagers(init_contexts):
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype
)
cls.create_quantized_model(model)
model.tie_weights()
# == step3: load checkpoint to quantized-model == #
accelerate.utils.modeling.load_checkpoint_in_model(
model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True
)
# == step4: set seqlen == #
model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
if any([k in model_config for k in seq_len_keys]):
for key in seq_len_keys:
if key in model_config:
model.seqlen = model_config[key]
break
else:
warnings.warn("can't get model's sequence length from model config, will set to 4096.")
model.seqlen = 4096
return cls(
model,
True,
)
def __getattr__(self, item):
try:
return super().__getattr__(item)
except AttributeError:
return getattr(self.model, item)
__all__ = ["BaseSmoothForCausalLM"]


@@ -0,0 +1,177 @@
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
import torch
from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
from torch_int.functional.quantization import quantize_per_tensor_absmax
try:
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
smoothquant_cuda = SmoothquantBuilder().load()
HAS_SMOOTHQUANT_CUDA = True
except ImportError:
HAS_SMOOTHQUANT_CUDA = False
raise ImportError("CUDA smoothquant linear is not installed")
class W8A8BFP32O32LinearSiLU(torch.nn.Module):
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer(
"weight",
torch.randint(
-127,
127,
(self.out_features, self.in_features),
dtype=torch.int8,
requires_grad=False,
),
)
self.register_buffer(
"bias",
torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False),
)
self.register_buffer("a", torch.tensor(alpha))
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.weight = self.weight.to(*args, **kwargs)
self.bias = self.bias.to(*args, **kwargs)
return self
@torch.no_grad()
def forward(self, x):
x_shape = x.shape
x = x.view(-1, x_shape[-1])
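# Fused int8 GEMM + SiLU: y = silu(alpha * (x @ W^T) + bias), where alpha = input_scale * weight_scale
# (set in from_float) dequantizes the int32 accumulator to fp32.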
y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0)
y = y.view(*x_shape[:-1], -1)
return y
@staticmethod
def from_float(module: torch.nn.Linear, input_scale):
int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features)
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
alpha = input_scale * weight_scale
int8_module.weight = int8_weight
if module.bias is not None:
int8_module.bias.data.copy_(module.bias.to(torch.float))
int8_module.a = alpha
return int8_module
class W8A8B8O8Linear(torch.nn.Module):
# For qkv_proj
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer(
"weight",
torch.randint(
-127,
127,
(self.out_features, self.in_features),
dtype=torch.int8,
requires_grad=False,
),
)
self.register_buffer(
"bias",
torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False),
)
self.register_buffer("a", torch.tensor(alpha))
self.register_buffer("b", torch.tensor(beta))
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.weight = self.weight.to(*args, **kwargs)
self.bias = self.bias.to(*args, **kwargs)
return self
@torch.no_grad()
def forward(self, x):
x_shape = x.shape
x = x.view(-1, x_shape[-1])
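# int8-in / int8-out GEMM: a = input_scale * weight_scale / output_scale rescales the accumulator
# and b = bias_scale / output_scale rescales the int8 bias (both set in from_float below).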
y = linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item())
y = y.view(*x_shape[:-1], -1)
return y
@staticmethod
def from_float(module: torch.nn.Linear, input_scale, output_scale):
int8_module = W8A8B8O8Linear(module.in_features, module.out_features)
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
alpha = input_scale * weight_scale / output_scale
int8_module.weight = int8_weight
int8_module.a = alpha
if module.bias is not None:
int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias)
int8_module.bias = int8_bias
beta = bias_scale / output_scale
int8_module.b = beta
return int8_module
class W8A8BFP32OFP32Linear(torch.nn.Module):
# For o_proj, up_proj and down_proj (fp32 bias, fp32 output)
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer(
"weight",
torch.randint(
-127,
127,
(self.out_features, self.in_features),
dtype=torch.int8,
requires_grad=False,
),
)
self.register_buffer(
"bias",
torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
)
self.register_buffer("a", torch.tensor(alpha))
def _apply(self, fn):
# prevent the bias from being converted to half
super()._apply(fn)
self.bias = self.bias.to(torch.float32)
return self
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.weight = self.weight.to(*args, **kwargs)
self.bias = self.bias.to(*args, **kwargs)
self.bias = self.bias.to(torch.float32)
return self
@torch.no_grad()
def forward(self, x):
x_shape = x.shape
x = x.view(-1, x_shape[-1])
y = linear_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1)
y = y.view(*x_shape[:-1], -1)
return y
@staticmethod
def from_float(module: torch.nn.Linear, input_scale):
int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features)
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
alpha = input_scale * weight_scale
int8_module.weight = int8_weight
int8_module.a = alpha
int8_module.input_scale = input_scale
int8_module.weight_scale = weight_scale
if module.bias is not None:
int8_module.bias = module.bias.to(torch.float32)
return int8_module


@@ -0,0 +1,846 @@
import math
import os
import types
from collections import defaultdict
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LLAMA_INPUTS_DOCSTRING,
LlamaAttention,
LlamaDecoderLayer,
LlamaMLP,
LlamaRotaryEmbedding,
repeat_kv,
rotate_half,
)
from transformers.utils import add_start_docstrings_to_model_forward
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import (
copy_kv_cache_to_dest,
int8_rotary_embedding_fwd,
smooth_llama_context_attn_fwd,
smooth_token_attention_fwd,
)
from .base_model import BaseSmoothForCausalLM
from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear
class LLamaSmoothquantAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {num_heads})."
)
self.qk_bmm = BMM_S8T_S8N_F32T(1.0)
self.pv_bmm = BMM_S8T_S8N_S8T(1.0)
self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size)
self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size)
self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size)
self.o_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size)
self.register_buffer("q_output_scale", torch.tensor([1.0]))
self.register_buffer("k_output_scale", torch.tensor([1.0]))
self.register_buffer("v_output_scale", torch.tensor([1.0]))
self.register_buffer("q_rotary_output_scale", torch.tensor([1.0]))
self.register_buffer("k_rotary_output_scale", torch.tensor([1.0]))
self.register_buffer("out_input_scale", torch.tensor([1.0]))
self.register_buffer("attn_input_scale", torch.tensor([1.0]))
self._init_rope()
self.num_key_value_heads = num_heads
def _init_rope(self):
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=2048,
base=10000.0,
)
@staticmethod
def pack(
module: LlamaAttention,
attn_input_scale: float,
q_output_scale: float,
k_output_scale: float,
v_output_scale: float,
q_rotary_output_scale: float,
k_rotary_output_scale: float,
out_input_scale: float,
):
int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads)
int8_module.attn_input_scale = torch.tensor([attn_input_scale])
int8_module.q_output_scale = torch.tensor([q_output_scale])
int8_module.k_output_scale = torch.tensor([k_output_scale])
int8_module.v_output_scale = torch.tensor([v_output_scale])
int8_module.q_rotary_output_scale = torch.tensor([q_rotary_output_scale])
int8_module.k_rotary_output_scale = torch.tensor([k_rotary_output_scale])
int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, attn_input_scale, q_output_scale)
int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale)
int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale)
int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale)
int8_module.out_input_scale = torch.tensor([out_input_scale])
return int8_module
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@torch.no_grad()
def forward(
self,
hidden_states: torch.Tensor,
rotary_emb: Tuple[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
cos = rotary_emb[0]
sin = rotary_emb[1]
int8_rotary_embedding_fwd(
query_states.view(-1, self.num_heads, self.head_dim),
cos,
sin,
self.q_output_scale.item(),
self.q_rotary_output_scale.item(),
)
int8_rotary_embedding_fwd(
key_states.view(-1, self.num_heads, self.head_dim),
cos,
sin,
self.k_output_scale.item(),
self.k_rotary_output_scale.item(),
)
# NOTE might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
if infer_state.decode_layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_len # seq_len
def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
return
query_states = query_states.view(-1, self.num_heads, self.head_dim)
key_states = key_states.view(-1, self.num_heads, self.head_dim)
value_states = value_states.view(-1, self.num_heads, self.head_dim)
if infer_state.is_context_stage:
# first token generation
# copy key and value calculated in current step to memory manager
_copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_states,
value_states,
infer_state.context_mem_index,
infer_state.cache_manager,
)
attn_output = torch.empty_like(query_states)
smooth_llama_context_attn_fwd(
query_states,
key_states,
value_states,
attn_output,
self.q_rotary_output_scale.item(),
self.k_rotary_output_scale.item(),
self.v_output_scale.item(),
self.out_input_scale.item(),
infer_state.start_loc,
infer_state.seq_len,
q_len,
)
else:
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_k.copy_(key_states)
cache_v.copy_(value_states)
else:
# if decode is not contiguous, use triton kernel to copy key and value cache
# k, v shape: [batch_size, num_heads, head_dim (embed_size_per_head)]
_copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_states,
value_states,
infer_state.decode_mem_index,
infer_state.cache_manager,
)
# (batch_size, seqlen, nheads, headdim)
attn_output = torch.empty_like(query_states)
smooth_token_attention_fwd(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
self.q_rotary_output_scale.item(),
self.k_rotary_output_scale.item(),
self.v_output_scale.item(),
self.out_input_scale.item(),
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
)
attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
return attn_output, None, None
class LlamaLayerNormQ(torch.nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.input_scale = 1.0
self.variance_epsilon = eps
self.register_buffer("weight", torch.ones(dim, dtype=torch.float32))
def forward(self, x):
ln_output_fp = torch.nn.functional.layer_norm(x, x.shape[-1:], self.weight, None, self.variance_epsilon)
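# from_float folds the output scale into self.weight, so rounding and clamping here
# directly yields the int8 activation expected by the following int8 linear layers.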
ln_output_int8 = ln_output_fp.round().clamp(-128, 127).to(torch.int8)
return ln_output_int8
@staticmethod
def from_float(module: torch.nn.LayerNorm, output_scale: float):
assert module.weight.shape[0] == module.weight.numel()
q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon)
q_module.weight = module.weight / output_scale
return q_module
class LlamaSmoothquantMLP(nn.Module):
def __init__(self, intermediate_size, hidden_size):
super().__init__()
self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size)
self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size)
self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size)
self.register_buffer("down_proj_input_scale", torch.tensor([1.0]))
@staticmethod
def pack(
mlp_module: LlamaMLP,
gate_proj_input_scale: float,
up_proj_input_scale: float,
down_proj_input_scale: float,
):
int8_module = LlamaSmoothquantMLP(
mlp_module.intermediate_size,
mlp_module.hidden_size,
)
int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale)
int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale)
int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale)
int8_module.down_proj_input_scale = torch.tensor([down_proj_input_scale])
return int8_module
def forward(
self,
hidden_states: torch.Tensor,
):
x_shape = hidden_states.shape
gate_out = self.gate_proj(hidden_states)
up_out = self.up_proj(hidden_states)
inter_out = gate_out * up_out
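# Requantize the fp32 silu(gate) * up product to int8 with the calibrated
# down_proj input scale before the int8 down projection.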
inter_out = inter_out.div_(self.down_proj_input_scale.item()).round().clamp(-128, 127).to(torch.int8)
down_out = self.down_proj(inter_out)
down_out = down_out.view(*x_shape[:-1], -1)
return down_out
class LlamaSmoothquantDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LLamaSmoothquantAttention(config.hidden_size, config.num_attention_heads)
self.mlp = LlamaSmoothquantMLP(config.intermediate_size, config.hidden_size)
self.input_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps)
@staticmethod
def pack(
module: LlamaDecoderLayer,
attn_input_scale: float,
q_output_scale: float,
k_output_scale: float,
v_output_scale: float,
q_rotary_output_scale: float,
k_rotary_output_scale: float,
out_input_scale: float,
gate_input_scale: float,
up_input_scale: float,
down_input_scale: float,
):
config = module.self_attn.config
int8_decoder_layer = LlamaSmoothquantDecoderLayer(config)
int8_decoder_layer.input_layernorm = LlamaLayerNormQ.from_float(module.input_layernorm, attn_input_scale)
int8_decoder_layer.self_attn = LLamaSmoothquantAttention.pack(
module.self_attn,
attn_input_scale,
q_output_scale,
k_output_scale,
v_output_scale,
q_rotary_output_scale,
k_rotary_output_scale,
out_input_scale,
)
int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float(
module.post_attention_layernorm, gate_input_scale
)
int8_decoder_layer.mlp = LlamaSmoothquantMLP.pack(
module.mlp,
gate_input_scale,
up_input_scale,
down_input_scale,
)
return int8_decoder_layer
def forward(
self,
hidden_states: torch.Tensor,
rotary_emb: Tuple[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
rotary_emb=rotary_emb,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
infer_state=infer_state,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, None, None
class LlamaApplyRotary(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed
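# Calibration-only forward: despite its name, this function replaces LlamaAttention.forward
# during calibration (see get_act_dict below) so that the rotary embedding is applied through
# the LlamaApplyRotary modules above, letting forward hooks record per-layer rotary output scales.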
def llama_decoder_layer_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states = self.q_apply_rotary(query_states, cos, sin, position_ids)
key_states = self.k_apply_rotary(key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def init_to_get_rotary(config, base=10000, use_elem=False):
"""
This function initializes the rotary positional embedding; it is compatible with all models and is called in ShardFormer
Args:
base : base used to compute the rotary inverse frequencies
use_elem : activated when using chatglm-based models
"""
config.head_dim_ = config.hidden_size // config.num_attention_heads
if not hasattr(config, "rope_scaling"):
rope_scaling_factor = 1.0
else:
rope_scaling_factor = config.rope_scaling.factor if config.rope_scaling is not None else 1.0
if hasattr(config, "max_sequence_length"):
max_seq_len = config.max_sequence_length
elif hasattr(config, "max_position_embeddings"):
max_seq_len = config.max_position_embeddings * rope_scaling_factor
else:
max_seq_len = 2048 * rope_scaling_factor
base = float(base)
# NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
try:
ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1))
assert ntk_alpha >= 1
if ntk_alpha > 1:
print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
max_seq_len *= ntk_alpha
base = base * (ntk_alpha ** (config.head_dim_ / (config.head_dim_ - 2))) # Base change formula
except (ValueError, AssertionError):
pass
n_elem = config.head_dim_
if use_elem:
n_elem //= 2
inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)
_cos_cached = torch.cos(freqs).to(torch.float)
_sin_cached = torch.sin(freqs).to(torch.float)
return _cos_cached, _sin_cached
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def llama_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
infer_state = self.infer_state
if past_key_values is not None:
# NOT READY FOR PRIME TIME
# dummy but works; revise later
past_key_values_length = infer_state.cache_manager.past_key_values_length
# past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
# NOTE: differentiate from the prefill stage
# block_loc requires a different value-assigning method for the two stages
if infer_state.is_context_stage:
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
infer_state.init_block_loc(
infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
)
else:
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
if alloc_mem is not None:
infer_state.decode_is_contiguous = True
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
print(
f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
)
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device)
padding_mask = None
else:
if 0 in attention_mask:
padding_mask = attention_mask
else:
padding_mask = None
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
raise NotImplementedError("not implement gradient_checkpointing and training options ")
if past_key_values_length == 0:
position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
else:
position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
infer_state.decode_layer_id = 0
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
layer_outputs = decoder_layer(
hidden_states,
rotary_emb=(position_cos, position_sin),
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
infer_state=infer_state,
)
hidden_states = layer_outputs[0]
infer_state.decode_layer_id += 1
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
infer_state.is_context_stage = False
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class SmoothLlamaForCausalLM(BaseSmoothForCausalLM):
layer_type = "LlamaDecoderLayer"
def __init__(self, model: PreTrainedModel, quantized: bool = False):
super().__init__(model, quantized)
def get_act_dict(
self,
tokenizer,
dataset,
num_samples=512,
seq_len=512,
):
llama_model = self.model
llama_model.eval()
device = next(llama_model.parameters()).device
# print("model:", llama_model)
act_dict = defaultdict(dict)
def stat_io_hook(m, x, y, name):
if isinstance(x, tuple):
x = x[0]
if name not in act_dict or "input" not in act_dict[name]:
act_dict[name]["input"] = x.detach().abs().max().item()
else:
act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item())
if isinstance(y, tuple):
y = y[0]
if name not in act_dict or "output" not in act_dict[name]:
act_dict[name]["output"] = y.detach().abs().max().item()
else:
act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item())
for name, m in llama_model.named_modules():
if isinstance(m, LlamaAttention):
setattr(m, "q_apply_rotary", LlamaApplyRotary())
setattr(m, "k_apply_rotary", LlamaApplyRotary())
m.forward = types.MethodType(llama_decoder_layer_forward, m)
hooks = []
for name, m in llama_model.named_modules():
if isinstance(m, LlamaApplyRotary):
hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name)))
if isinstance(m, torch.nn.Linear):
hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name)))
self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len)
for hook in hooks:
hook.remove()
return act_dict
def smooth_fn(self, scales, alpha=0.5):
model = self.model
for name, module in model.named_modules():
if isinstance(module, LlamaDecoderLayer):
attn_ln = module.input_layernorm
qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj]
qkv_input_scales = scales[name + ".self_attn.q_proj"]
self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)
def create_quantized_model(model):
llama_config = model.config
for i, layer in enumerate(model.model.layers):
model.model.layers[i] = LlamaSmoothquantDecoderLayer(llama_config)
model.model.forward = types.MethodType(llama_model_forward, model.model)
cos, sin = init_to_get_rotary(llama_config)
model.model.register_buffer("_cos_cached", cos)
model.model.register_buffer("_sin_cached", sin)
def quantized(
self,
tokenizer,
dataset,
num_samples=512,
seq_len=512,
alpha=0.5,
):
llama_model = self.model
llama_config = llama_model.config
act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len)
self.smooth_fn(act_scales, alpha)
act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len)
decoder_layer_scales = []
for idx in range(llama_config.num_hidden_layers):
scale_dict = {}
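# Each scale maps the largest observed floating-point activation onto the int8 range,
# i.e. scale = max|activation| / 127, using the statistics from get_act_dict.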
scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127
scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127
scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127
scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127
scale_dict["q_rotary_output_scale"] = (
act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127
)
scale_dict["k_rotary_output_scale"] = (
act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127
)
scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127
scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127
scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127
scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127
decoder_layer_scales.append(scale_dict)
for i, layer in enumerate(llama_model.model.layers):
orig_layer = layer
llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i])
llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model)
cos, sin = init_to_get_rotary(llama_config)
llama_model.model.register_buffer("_cos_cached", cos.to(self.model.device))
llama_model.model.register_buffer("_sin_cached", sin.to(self.model.device))


@@ -0,0 +1,8 @@
#include <torch/extension.h>
#include "linear.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32,
"Linear SiLU (INT8)");
}


@@ -0,0 +1,162 @@
// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu
#include "linear.h"
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/epilogue/thread/linear_combination_silu.h>
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <iostream>
#include <torch/torch.h>
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8
torch::Tensor weight, // INT8
torch::Tensor bias, // FP32
float alpha, // FP32
float beta // FP32
) {
auto M = input.size(0);
auto N = weight.size(0);
auto K = input.size(1);
using ElementOutput = float;
using ElementAccumulator = int32_t;
using ElementComputeEpilogue = float;
using ElementInputA = int8_t; // <- data type of elements in input matrix A
using ElementInputB = int8_t; // <- data type of elements in input matrix B
// The code section below describes matrix layout of input and output
// matrices. Row Major for Matrix A, Column Major for Matrix B and Row Major
// for Matrix C
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
#if CUDA_ARCH >= 800
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<
ElementOutput>::value, // <- this is the number of elements per
// vectorized memory access. For half
// precision, it's 8 elements. This
// becomes the vector width of math
// instructions in epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue // <- data type for alpha in linear combination
// function
>;
using Gemm = cutlass::gemm::device::Gemm<
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
EpilogueOp,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
#elif CUDA_ARCH >= 750
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<
ElementOutput>::value, // <- this is the number of elements per
// vectorized memory access. For half
// precision, it's 8 elements. This
// becomes the vector width of math
// instructions in epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue // <- data type for alpha in linear combination
// function
>;
using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
using Gemm = cutlass::gemm::device::Gemm<
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
DefaultGemmCfg::InstructionShape,
EpilogueOp>;
#elif CUDA_ARCH >= 700
#define USE_TORCH_SILU
using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
using Gemm = cutlass::gemm::device::Gemm<
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
DefaultGemmCfg::InstructionShape,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>;
#else
#error "Unsupported cuda arch"
#endif
auto input_size = cutlass::MatrixCoord(M, K);
auto weight_size = cutlass::MatrixCoord(K, N);
auto output_size = cutlass::MatrixCoord(M, N);
auto device = input.device();
// use the broadcasted bias as the output
auto out = bias.to(device).view({1, -1}).repeat({M, 1});
// constexpr int kSparse = Gemm::kSparse;
// How many elements of A are covered per ElementE
// constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
// The size of individual meta data
// constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
cutlass::gemm::GemmCoord problem_size(M, N, K);
cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
input.data_ptr<ElementInputA>(), LayoutInputA::packed(input_size));
cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
out.data_ptr<ElementOutput>(), LayoutOutput::packed(output_size));
typename Gemm::Arguments arguments{
problem_size, // <- problem size of matrix multiplication
input_ref, // <- reference to matrix A on device
weight_ref, // <- reference to matrix B on device
out_ref, // <- reference to matrix C on device
out_ref, // <- reference to matrix D on device
{alpha, beta}, 1};
Gemm gemm_op;
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check the problem size is supported or not
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
status = gemm_op();
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot run");
}
#ifdef USE_TORCH_SILU
#undef USE_TORCH_SILU
out = torch::silu(out);
#endif
return out;
}
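For orientation, here is a minimal Python sketch of invoking the fused INT8 GEMM + SiLU kernel once the extension has been built; it mirrors the unit test added later in this commit, and the shapes and scales are illustrative only.

import torch
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

# build (on first use) and load the extension exposing linear_silu_a8_w8_bfp32_ofp32
smoothquant_cuda = SmoothquantBuilder().load()

M, K, N = 16, 512, 256
x = torch.randint(-127, 127, (M, K), dtype=torch.int8, device="cuda")  # quantized activations
w = torch.randint(-127, 127, (N, K), dtype=torch.int8, device="cuda")  # quantized weight, row-major (N, K)
bias = torch.zeros(N, dtype=torch.float, device="cuda")                # FP32 bias

alpha = 1.0 / 127  # combined activation/weight dequantization scale (placeholder)
beta = 1.0         # bias scale
# returns FP32: SiLU(alpha * x @ w^T + beta * bias)
out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, w, bias, alpha, beta)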


@@ -0,0 +1,12 @@
#include <torch/torch.h>
#include <torch/types.h>
#include <cstdint>
#include <iostream>
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8
torch::Tensor weight, // INT8
torch::Tensor bias, // FP32
float alpha, // FP32
float beta // FP32
);


@@ -13,8 +13,10 @@ if HAS_TRITON:
from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm
from .gptq_triton import gptq_fused_linear_triton
from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd
from .rms_norm import rmsnorm_forward
from .rotary_embedding_kernel import rotary_embedding_fwd
from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd
from .softmax import softmax
from .token_attention_kernel import token_attention_fwd
@@ -29,4 +31,7 @@ if HAS_TRITON:
"rotary_embedding_fwd",
"token_attention_fwd",
"gptq_fused_linear_triton",
"int8_rotary_embedding_fwd",
"smooth_llama_context_attn_fwd",
"smooth_token_attention_fwd",
]


@@ -0,0 +1,117 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import torch
import triton
import triton.language as tl
@triton.jit
def _rotary_kernel(
q,
input_scale,
output_scale,
Cos,
Sin,
q_bs_stride,
q_h_stride,
q_d_stride,
cos_bs_stride,
cos_d_stride,
total_len,
HEAD_NUM: tl.constexpr,
BLOCK_HEAD: tl.constexpr,
BLOCK_SEQ: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
current_head_index = tl.program_id(0)
current_seq_index = tl.program_id(1)
dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
off_q0 = (
current_seq_range[:, None, None] * q_bs_stride
+ current_head_range[None, :, None] * q_h_stride
+ dim_range0[None, None, :] * q_d_stride
)
off_q1 = (
current_seq_range[:, None, None] * q_bs_stride
+ current_head_range[None, :, None] * q_h_stride
+ dim_range1[None, None, :] * q_d_stride
)
off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride
q0 = tl.load(
q + off_q0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0,
)
q1 = tl.load(
q + off_q1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0,
)
cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
q0 = q0.to(tl.float32) * input_scale
q1 = q1.to(tl.float32) * input_scale
out0 = (q0 * cos - q1 * sin) / output_scale
out1 = (q0 * sin + q1 * cos) / output_scale
out0 = out0.to(tl.int8)
out1 = out1.to(tl.int8)
tl.store(
q + off_q0,
out0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
)
tl.store(
q + off_q1,
out1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
)
return
@torch.no_grad()
def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale):
total_len = q.shape[0]
head_num = q.shape[1]
head_dim = q.shape[2]
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
BLOCK_HEAD = 4
BLOCK_SEQ = 32
grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
if head_dim >= 128:
num_warps = 8
else:
num_warps = 4
_rotary_kernel[grid](
q,
input_scale,
output_scale,
cos,
sin,
q.stride(0),
q.stride(1),
q.stride(2),
cos.stride(0),
cos.stride(1),
total_len,
HEAD_NUM=head_num,
BLOCK_HEAD=BLOCK_HEAD,
BLOCK_SEQ=BLOCK_SEQ,
HEAD_DIM=head_dim,
num_warps=num_warps,
num_stages=1,
)
return
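A small usage sketch, modeled on the rotary-embedding unit test added in this commit; the per-tensor scales are illustrative placeholders for calibrated values.

import torch
from colossalai.kernel.triton import int8_rotary_embedding_fwd

seq_len, head_num, head_dim = 4, 32, 128
x = torch.randn(seq_len, head_num, head_dim, dtype=torch.float, device="cuda")
cos = torch.randn(seq_len, head_dim // 2, dtype=torch.float, device="cuda")
sin = torch.randn(seq_len, head_dim // 2, dtype=torch.float, device="cuda")

input_scale = (torch.max(torch.abs(x)) / 127).item()  # per-tensor quantization scale
output_scale = input_scale                             # in practice calibrated on the rotated output
q_int8 = (x / input_scale).to(torch.int8)

# dequantizes with input_scale, applies the rotary embedding, requantizes with output_scale, in place
int8_rotary_embedding_fwd(q_int8, cos, sin, input_scale, output_scale)
q_rotated = q_int8.to(torch.float) * output_scale      # approximate FP values of the rotated query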


@@ -0,0 +1,652 @@
import math
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
"""
this function is modified from
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
"""
@triton.jit
def _context_flash_attention_kernel(
Q,
K,
V,
q_input_scale,
k_input_scale,
v_input_scale,
pv_output_scale,
sm_scale,
B_Start_Loc,
B_Seqlen,
TMP,
alibi_ptr,
Out,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_tmp_b,
stride_tmp_h,
stride_tmp_s,
        # suggested block sizes: 64, 128, 256, 512
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
batch_id = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# get batch info
cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
block_start_loc = BLOCK_M * start_m
load_p_ptrs = (
Q
+ (cur_batch_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
q = q.to(tl.float16) * q_input_scale.to(tl.float16)
k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
if alibi_ptr is not None:
alibi_m = tl.load(alibi_ptr + cur_head)
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
k = tl.load(
k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
other=0.0,
)
k = k.to(tl.float16) * k_input_scale.to(tl.float16)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
if alibi_ptr is not None:
alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
qk -= alibi_loc * alibi_m
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
tl.store(t_ptrs, acc_scale)
acc_scale = tl.load(t_ptrs)
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)
v = v.to(tl.float16) * v_input_scale.to(tl.float16)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8)
off_o = (
(cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
@torch.no_grad()
def smooth_llama_context_attn_fwd(
q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len
):
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk, "context attention requires query and key to have the same head dimension"
        assert Lk == Lv, "context attention requires key and value to have the same head dimension"
assert Lk in {16, 32, 64, 128}
BLOCK_N = 128
sm_scale = 1.0 / math.sqrt(Lk)
batch, head = b_seq_len.shape[0], q.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
_context_flash_attention_kernel[grid](
q,
k,
v,
q_input_scale,
k_input_scale,
v_input_scale,
pv_output_scale,
sm_scale,
b_start_loc,
b_seq_len,
tmp,
None,
o,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
tmp.stride(0),
tmp.stride(1),
tmp.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
@triton.jit
def _token_attn_1_kernel(
Q,
K,
q_input_scale,
k_input_scale,
sm_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
attn_out,
kv_cache_loc_b_stride,
kv_cache_loc_s_stride,
q_batch_stride,
q_head_stride,
q_head_dim_stride,
k_batch_stride,
k_head_stride,
k_head_dim_stride,
attn_head_stride,
attn_batch_stride,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
start_n = tl.program_id(2)
offs_d = tl.arange(0, HEAD_DIM)
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
current_batch_end_index = max_kv_cache_len
off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
        block_start_index = start_n * BLOCK_N
        block_mask = tl.where(block_start_index < current_batch_seq_len, 1, 0)
for start_mark in range(0, block_mask, 1):
q = tl.load(Q + off_q + start_mark)
q = q.to(tl.float16) * q_input_scale.to(tl.float16)
offs_n_new = current_batch_start_index + offs_n
k_loc = tl.load(
kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
mask=offs_n_new < current_batch_end_index,
other=0,
)
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
k = k.to(tl.float16) * k_input_scale.to(tl.float16)
att_value = tl.sum(q[None, :] * k, 1)
att_value *= sm_scale
off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
return
@triton.jit
def _token_attn_1_alibi_kernel(
Q,
K,
q_input_scale,
k_input_scale,
sm_scale,
alibi,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
attn_out,
kv_cache_loc_b_stride,
kv_cache_loc_s_stride,
q_batch_stride,
q_head_stride,
q_head_dim_stride,
k_batch_stride,
k_head_stride,
k_head_dim_stride,
attn_head_stride,
attn_batch_stride,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
start_n = tl.program_id(2)
offs_d = tl.arange(0, HEAD_DIM)
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
current_batch_end_index = max_kv_cache_len
off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
        block_start_index = start_n * BLOCK_N
        block_mask = tl.where(block_start_index < current_batch_seq_len, 1, 0)
for start_mark in range(0, block_mask, 1):
alibi_m = tl.load(alibi + current_head)
q = tl.load(Q + off_q + start_mark)
q = q.to(tl.float16) * q_input_scale.to(tl.float16)
offs_n_new = current_batch_start_index + offs_n
k_loc = tl.load(
kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
mask=offs_n_new < current_batch_end_index,
other=0,
)
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
k = k.to(tl.float16) * k_input_scale.to(tl.float16)
att_value = tl.sum(q[None, :] * k, 1)
att_value *= sm_scale
att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n)
off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
return
@torch.no_grad()
def token_attn_fwd_1(
q,
k,
attn_out,
q_input_scale,
k_input_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
alibi=None,
):
BLOCK = 32
# shape constraints
q_head_dim, k_head_dim = q.shape[-1], k.shape[-1]
assert q_head_dim == k_head_dim
assert k_head_dim in {16, 32, 64, 128}
sm_scale = 1.0 / (k_head_dim**0.5)
batch, head_num = kv_cache_loc.shape[0], q.shape[1]
grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK))
num_warps = 4 if k_head_dim <= 64 else 8
num_warps = 2
if alibi is not None:
_token_attn_1_alibi_kernel[grid](
q,
k,
q_input_scale,
k_input_scale,
sm_scale,
alibi,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
attn_out,
kv_cache_loc.stride(0),
kv_cache_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
attn_out.stride(0),
attn_out.stride(1),
HEAD_DIM=k_head_dim,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
else:
_token_attn_1_kernel[grid](
q,
k,
q_input_scale,
k_input_scale,
sm_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
attn_out,
kv_cache_loc.stride(0),
kv_cache_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
attn_out.stride(0),
attn_out.stride(1),
HEAD_DIM=k_head_dim,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
@triton.jit
def _token_attn_softmax_fwd(
softmax_logics,
kv_cache_start_loc,
kv_cache_seqlen,
softmax_prob_out,
logics_head_dim_stride,
logics_batch_stride,
prob_head_dim_stride,
prob_batch_stride,
BLOCK_SIZE: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
col_offsets = tl.arange(0, BLOCK_SIZE)
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
row = tl.load(
softmax_logics
+ current_head * logics_head_dim_stride
+ (current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
mask=col_offsets < current_batch_seq_len,
other=-float("inf"),
).to(tl.float32)
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
tl.store(
softmax_prob_out
+ current_head * prob_head_dim_stride
+ (current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
softmax_output,
mask=col_offsets < current_batch_seq_len,
)
return
@torch.no_grad()
def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len):
BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len)
batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0]
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
_token_attn_softmax_fwd[(batch, head_num)](
softmax_logics,
kv_cache_start_loc,
kv_cache_seqlen,
softmax_prob_out,
softmax_logics.stride(0),
softmax_logics.stride(1),
softmax_prob_out.stride(0),
softmax_prob_out.stride(1),
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)
return
@triton.jit
def _token_attn_2_kernel(
Prob,
V,
attn_out,
v_input_scale,
pv_output_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
kv_cache_loc_b_stride,
kv_cache_loc_s_stride,
prob_head_dim_stride,
prob_batch_stride,
v_batch_stride,
v_head_stride,
v_head_dim_stride,
attn_out_batch_stride,
attn_out_head_stride,
attn_out_head_dim_stride,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, HEAD_DIM)
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride
p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride
v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
for start_n in range(0, current_batch_seq_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
p_value = tl.load(
Prob + p_offs + start_n * kv_cache_loc_s_stride,
mask=(start_n + offs_n) < current_batch_seq_len,
other=0.0,
)
v_loc = tl.load(
kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
mask=(start_n + offs_n) < current_batch_seq_len,
other=0.0,
)
v_value = tl.load(
V + v_offs + v_loc[:, None] * v_batch_stride,
mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
other=0.0,
)
v_value = v_value.to(tl.float16) * v_input_scale.to(tl.float16)
acc += tl.sum(p_value[:, None] * v_value, 0)
acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8)
off_o = (
current_batch * attn_out_batch_stride
+ current_head * attn_out_head_stride
+ offs_d * attn_out_head_dim_stride
)
out_ptrs = attn_out + off_o
tl.store(out_ptrs, acc)
return
@torch.no_grad()
def token_attn_fwd_2(
prob,
v,
attn_out,
v_input_scale,
pv_output_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
):
if triton.__version__ >= "2.1.0":
BLOCK = 128
else:
BLOCK = 64
batch, head = kv_cache_loc.shape[0], v.shape[1]
grid = (batch, head)
num_warps = 4
dim = v.shape[-1]
_token_attn_2_kernel[grid](
prob,
v,
attn_out,
v_input_scale,
pv_output_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
kv_cache_loc.stride(0),
kv_cache_loc.stride(1),
prob.stride(0),
prob.stride(1),
v.stride(0),
v.stride(1),
v.stride(2),
attn_out.stride(0),
attn_out.stride(1),
attn_out.stride(2),
HEAD_DIM=dim,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
@torch.no_grad()
def smooth_token_attention_fwd(
q,
k,
v,
attn_out,
q_input_scale,
k_input_scale,
v_input_scale,
pv_output_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
alibi=None,
):
head_num = k.shape[1]
batch_size = kv_cache_seq_len.shape[0]
calcu_shape1 = (batch_size, head_num, k.shape[2])
total_token_num = k.shape[0]
att_m_tensor = torch.empty((head_num, total_token_num), dtype=torch.float32, device="cuda")
token_attn_fwd_1(
q.view(calcu_shape1),
k,
att_m_tensor,
q_input_scale,
k_input_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
alibi=alibi,
)
prob = torch.empty_like(att_m_tensor)
token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
att_m_tensor = None
token_attn_fwd_2(
prob,
v,
attn_out.view(calcu_shape1),
v_input_scale,
pv_output_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
)
prob = None
return
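A shape-level sketch of the prefill entry point smooth_llama_context_attn_fwd; the per-tensor scales below are placeholders for values produced by SmoothQuant calibration, and the (total_tokens, head_num, head_dim) layout follows the kernel's stride arguments.

import torch
from colossalai.kernel.triton import smooth_llama_context_attn_fwd

batch, seq_len, head_num, head_dim = 2, 32, 8, 64
total_tokens = batch * seq_len
q = torch.randint(-127, 127, (total_tokens, head_num, head_dim), dtype=torch.int8, device="cuda")
k = torch.randint(-127, 127, (total_tokens, head_num, head_dim), dtype=torch.int8, device="cuda")
v = torch.randint(-127, 127, (total_tokens, head_num, head_dim), dtype=torch.int8, device="cuda")
o = torch.empty_like(q)  # INT8 attention output, quantized with pv_output_scale

# placeholder per-tensor scales; real values come from SmoothQuant calibration
q_scale, k_scale, v_scale, pv_scale = 0.02, 0.02, 0.02, 0.02

b_start_loc = torch.arange(0, total_tokens, seq_len, dtype=torch.int32, device="cuda")  # start offset of each sequence
b_seq_len = torch.full((batch,), seq_len, dtype=torch.int32, device="cuda")             # length of each sequence

smooth_llama_context_attn_fwd(
    q, k, v, o, q_scale, k_scale, v_scale, pv_scale, b_start_loc, b_seq_len, seq_len
)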


@@ -0,0 +1,69 @@
import argparse
import os
import torch
from datasets import load_dataset
from transformers import LlamaTokenizer
from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
def build_model_and_tokenizer(model_name):
tokenizer = LlamaTokenizer.from_pretrained(model_name, model_max_length=512)
kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"}
model = SmoothLlamaForCausalLM.from_pretrained(model_name, **kwargs)
model = model.to(torch.float32)
return model, tokenizer
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, help="model name")
parser.add_argument(
"--output-path",
type=str,
help="where to save the checkpoint",
)
parser.add_argument(
"--dataset-path",
type=str,
help="location of the calibration dataset",
)
parser.add_argument("--num-samples", type=int, default=512)
parser.add_argument("--seq-len", type=int, default=512)
args = parser.parse_args()
return args
@torch.no_grad()
def main():
args = parse_args()
model_path = args.model_name
dataset_path = args.dataset_path
output_path = args.output_path
    num_samples = args.num_samples
    seq_len = args.seq_len
model, tokenizer = build_model_and_tokenizer(model_path)
    if not os.path.exists(dataset_path):
        raise FileNotFoundError(f"Cannot find the calibration dataset at {dataset_path}")
dataset = load_dataset("json", data_files=dataset_path, split="train")
model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len)
model = model.cuda()
model.save_quantized(output_path, model_basename="llama-7b")
model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b")
model = model.cuda()
generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True)
input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda")
out = model.generate(**input_tokens, **generate_kwargs)
text = tokenizer.batch_decode(out)
print("out is:", text)
if __name__ == "__main__":
main()

op_builder/smoothquant.py

@@ -0,0 +1,52 @@
import warnings
import torch
from .builder import Builder
from .utils import append_nvcc_threads
class SmoothquantBuilder(Builder):
NAME = "cu_smoothquant"
PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant"
def __init__(self):
super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH)
def include_dirs(self):
ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()]
return ret
def sources_files(self):
ret = [
self.csrc_abs_path(fname)
for fname in [
"smoothquant/binding.cpp",
"smoothquant/linear.cu",
]
]
return ret
def cxx_flags(self):
return ["-O3"] + self.version_dependent_macros
def nvcc_flags(self):
compute_capability = torch.cuda.get_device_capability()
cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10
extra_cuda_flags = [
"-v",
f"-DCUDA_ARCH={cuda_arch}",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-DTHRUST_IGNORE_CUB_VERSION_CHECK",
]
ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
return append_nvcc_threads(ret)
    def builder(self):
        try:
            return super().builder()
        except Exception:
            warnings.warn("failed to build the smoothquant CUDA library")


@@ -0,0 +1,136 @@
import pytest
import torch
from packaging import version
try:
from colossalai.kernel.triton import int8_rotary_embedding_fwd
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
try:
from colossalai.inference.quant.smoothquant.models import LLamaSmoothquantAttention
HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
print("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
import math
import torch
from torch.nn import functional as F
def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
"""
adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
"""
xq = xq.view(bs, seqlen, num_head, head_dim)
xk = xk.view(bs, seqlen, num_head, head_dim)
xv = xv.view(bs, seqlen, num_head, head_dim)
mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
mask[mask == 0.0] = -100000000.0
mask = mask.repeat(bs, num_head, 1, 1)
keys = xk
values = xv
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq)
output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
return output
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_TORCH_INT,
reason="triton requires cuda version to be higher than 11.4 or not install torch_int",
)
def test_llama_context_attention():
head_num = 2
seq_len = 32
head_dim = 64
dtype = torch.float
hidden_size = head_num * head_dim
smooth_attn = LLamaSmoothquantAttention(head_num * head_dim, head_num)
smooth_attn.q_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
smooth_attn.out_proj.weight[:, 1:hidden_size] = torch.zeros(hidden_size - 1, device="cuda").to(torch.int8)
qkv_weight_scale = 1.0
ones = torch.ones(hidden_size, hidden_size, dtype=torch.float, device="cuda")
smooth_attn = smooth_attn.to("cuda")
input = torch.randint(-20, 20, (1, seq_len, head_num * head_dim), dtype=torch.int8, device="cuda")
input_scale = 1 / 20.0
output = torch.matmul(input.to(torch.float) * input_scale, ones)
qkv_max_out = torch.max(torch.abs(output)) / 127
smooth_attn.q_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
smooth_attn.k_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
smooth_attn.v_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
q = smooth_attn.q_proj(input)
k = smooth_attn.k_proj(input)
v = smooth_attn.v_proj(input)
cos_shape = (seq_len, head_dim // 2)
cos = torch.ones(cos_shape, dtype=dtype, device="cuda")
sin = torch.zeros(cos_shape, dtype=dtype, device="cuda")
in_scale = torch.tensor([qkv_max_out], device="cuda")
out_scale = torch.tensor([qkv_max_out], device="cuda")
int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item())
int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item())
q = q.to(torch.float) * out_scale
k = k.to(torch.float) * out_scale
v = v.to(torch.float) * out_scale
torch_out = torch_context_attention(q.clone(), k.clone(), v.clone(), 1, seq_len, head_num, head_dim)
attn_out_max = torch.max(torch.abs(torch_out)) / 127
output = torch.matmul(torch_out.view(-1, seq_len, head_num * head_dim), ones)
smooth_attn.q_output_scale = torch.tensor(qkv_max_out)
smooth_attn.k_output_scale = torch.tensor(qkv_max_out)
smooth_attn.v_output_scale = torch.tensor(qkv_max_out)
smooth_attn.q_rotary_output_scale = torch.tensor(qkv_max_out)
smooth_attn.k_rotary_output_scale = torch.tensor(qkv_max_out)
smooth_attn.attn_output_scale = torch.tensor(attn_out_max)
smooth_attn.out_proj.a = torch.tensor([attn_out_max])
torch_out = (
(torch_out / smooth_attn.attn_output_scale)
.round()
.clamp(-128, 127)
.to(torch.int8)
.view(-1, seq_len, head_num * head_dim)
)
torch_out = smooth_attn.out_proj(torch_out)
torch_out = torch_out.to(torch.float)
smooth_attn = smooth_attn.to("cuda")
smooth_out, _, _ = smooth_attn(input, (cos, sin))
smooth_out = smooth_out.to(torch.float)
assert torch.allclose(
torch_out.cpu(), smooth_out.cpu(), rtol=1e-1, atol=1e-1
), "outputs from triton and torch are not matched"
if __name__ == "__main__":
test_llama_context_attention()


@@ -0,0 +1,84 @@
import warnings
import pytest
import torch
from packaging import version
try:
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
smoothquant_cuda = SmoothquantBuilder().load()
HAS_SMOOTHQUANT_CUDA = True
except Exception:
warnings.warn("CUDA smoothquant linear is not installed")
HAS_SMOOTHQUANT_CUDA = False
try:
from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP
HAS_TORCH_INT = True
except Exception:
HAS_TORCH_INT = False
warnings.warn("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
def torch_llama_mlp(gate_proj, up_proj, down_proj, x):
gate_out = torch.mm(x, gate_proj)
silu = torch.nn.SiLU()
gate_out = silu(gate_out)
up_out = torch.mm(x, up_proj)
o_out = gate_out * up_out
max_up = torch.max(torch.abs(o_out))
min_up = torch.min(torch.abs(o_out))
torch_out = torch.mm(o_out, down_proj)
return (torch_out, max_up, min_up)
@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_SMOOTHQUANT_CUDA or not HAS_TORCH_INT,
reason="smoothquant linear not installed properly or not install torch_int",
)
def test_llama_mlp():
hidden_size = 256
intermediate_size = 512
smooth_mlp = LlamaSmoothquantMLP(intermediate_size, hidden_size)
smooth_mlp.gate_proj.weight = torch.ones((intermediate_size, hidden_size), dtype=torch.int8, device="cuda")
smooth_mlp.up_proj.weight = torch.randint(
-10, 10, (intermediate_size, hidden_size), dtype=torch.int8, device="cuda"
)
smooth_mlp.down_proj.weight = torch.randint(
-10, 10, (hidden_size, intermediate_size), dtype=torch.int8, device="cuda"
)
x = torch.ones((1, 256), dtype=torch.int8, device="cuda")
torch_out, max_inter, min_inter = torch_llama_mlp(
smooth_mlp.gate_proj.weight.transpose(0, 1).to(torch.float) / hidden_size,
smooth_mlp.up_proj.weight.transpose(0, 1).to(torch.float) / 127,
smooth_mlp.down_proj.weight.transpose(0, 1).to(torch.float) / 127,
x.to(torch.float),
)
smooth_mlp.down_proj_input_scale = torch.tensor(max_inter.item() / 127)
smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size)
smooth_mlp.up_proj.a = torch.tensor(1 / 127)
smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127))
smooth_out = smooth_mlp(x)
assert torch.allclose(torch_out, smooth_out, rtol=1e-02, atol=1e-01)
if __name__ == "__main__":
test_llama_mlp()


@@ -0,0 +1,39 @@
import warnings
import pytest
import torch
try:
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
smoothquant_cuda = SmoothquantBuilder().load()
HAS_SMOOTHQUANT_CUDA = True
except Exception:
warnings.warn("CUDA smoothquant linear is not installed")
HAS_SMOOTHQUANT_CUDA = False
@pytest.mark.skipif(
not HAS_SMOOTHQUANT_CUDA,
reason="smoothquant linear not installed properly",
)
def test_linear():
a = torch.randint(-127, 127, (128, 512), dtype=torch.int8, device="cuda")
b = torch.randint(-127, 127, (512, 256), dtype=torch.int8, device="cuda")
c = torch.rand(256, dtype=torch.float, device="cuda")
alpha = 1 / 127
beta = 1.0
torch_out = torch.mm(a.to(torch.float) * alpha, b.to(torch.float)) + c
silu = torch.nn.SiLU()
torch_out = silu(torch_out)
b = b.transpose(0, 1).contiguous()
cuda_out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(a, b, c, alpha, beta)
assert torch.allclose(torch_out, cuda_out, rtol=1e-02, atol=1e-02)
if __name__ == "__main__":
test_linear()


@@ -0,0 +1,59 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import pytest
import torch
from packaging import version
try:
from colossalai.kernel.triton import int8_rotary_embedding_fwd
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
def torch_rotary_emb(x, cos, sin):
seq_len, h, dim = x.shape
x0 = x[:, :, 0 : dim // 2]
x1 = x[:, :, dim // 2 : dim]
cos = cos.view((seq_len, 1, dim // 2))
sin = sin.view((seq_len, 1, dim // 2))
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
return torch.cat((o0, o1), dim=-1)
@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_rotary_emb():
SEQ_LEN = 1
HEAD_NUM = 32
HEAD_DIM = 128
dtype = torch.float
# create data
x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
cos_shape = (SEQ_LEN, HEAD_DIM // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
# forward pass
y_torch = torch_rotary_emb(x, cos, sin)
input_scale = torch.max(torch.abs(x)) / 127
output_scale = torch.max(torch.abs(y_torch)) / 127
x = x / input_scale
x = x.to(torch.int8)
int8_rotary_embedding_fwd(x, cos, sin, input_scale.item(), output_scale.item())
y_triton = x.to(torch.float) * output_scale
assert torch.allclose(y_triton, y_torch, atol=2e-1, rtol=1e-2, equal_nan=True)
if __name__ == "__main__":
test_rotary_emb()