mirror of https://github.com/hpcaitech/ColossalAI
[inference] Add smoothquant for llama (#4904)
* [inference] add int8 rotary embedding kernel for smoothquant (#4843)
* [inference] add smoothquant llama attention (#4850): add smoothquant llama attention, remove useless code, fix import error, rename file name
* [inference] add silu linear fusion for smoothquant llama mlp (#4853): add silu linear, update skip condition, catch smoothquant cuda lib exception, process exception for tests
* [inference] add llama mlp for smoothquant (#4854): add llama mlp for smoothquant, fix down out scale, remove duplicate lines, add llama mlp check, delete useless code
* [inference] add smoothquant llama (#4861): add smoothquant llama, fix attention accuracy, fix accuracy, add kv cache and save pretrained, refactor example, delete smooth, refactor code
* [inference] add smooth function and delete useless code for smoothquant (#4895): add smooth function and delete useless code, update datasets, remove duplicate import, delete useless file
* refactor codes (#4902): refactor code, add license, add torch-int and smoothquant license

pull/4918/head
parent a0684e7bd6
commit 611a5a80ca

LICENSE (+50 lines)
@@ -477,3 +477,53 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

---------------- LICENSE FOR torch-int ----------------

MIT License

Copyright (c) 2022 Guangxuan Xiao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

---------------- LICENSE FOR smoothquant ----------------

MIT License

Copyright (c) 2022 MIT HAN Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@@ -0,0 +1,12 @@
try:
    import torch_int

    HAS_TORCH_INT = True
except ImportError:
    HAS_TORCH_INT = False
    raise ImportError(
        "torch_int is not installed. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
    )

if HAS_TORCH_INT:
    from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
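Note: the guarded import above makes the smoothquant llama modules depend on the optional torch-int package. A minimal sketch of how a caller can probe for the feature without hard-failing; the import path below is an assumption based on this PR's file layout, not something the hunk itself confirms.

# Illustrative only: guard the optional smoothquant modules at the call site.
try:
    # assumed module path for the package added in this PR
    from colossalai.inference.quant.smoothquant.models import LLamaSmoothquantAttention  # noqa: F401

    SMOOTHQUANT_AVAILABLE = True
except ImportError:
    # torch-int (https://github.com/Guangxuan-Xiao/torch-int) is not installed
    SMOOTHQUANT_AVAILABLE = False

if not SMOOTHQUANT_AVAILABLE:
    print("smoothquant llama modules unavailable; falling back to non-quantized inference")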

@@ -0,0 +1,482 @@
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ
# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py

import os
import warnings
from abc import abstractmethod
from functools import partial
from os.path import isdir, isfile, join
from typing import Dict, List, Optional, Union

import accelerate
import numpy as np
import torch
import torch.nn as nn
import transformers
from safetensors.torch import save_file as safe_save
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_utils import no_init_weights
from transformers.utils.generic import ContextManagers
from transformers.utils.hub import PushToHubMixin, cached_file

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager

SUPPORTED_MODELS = ["llama"]


class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
    layer_type: str = None

    def __init__(self, model: PreTrainedModel, quantized: bool = False):
        super().__init__()

        self.model = model
        self.model_type = self.model.config.model_type
        self._quantized = quantized
        self.config = self.model.config
        self.cache_manager = None
        self.max_total_token_num = 0

    @property
    def quantized(self):
        return self._quantized

    def init_cache_manager(self, max_total_token_num=2048):
        if self.config.model_type == "llama":
            head_num = self.config.num_key_value_heads
            layer_num = self.config.num_hidden_layers
            head_dim = self.config.hidden_size // head_num

            self.cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
            self.max_total_token_num = max_total_token_num

    def init_batch_state(self, max_output_len=256, **kwargs):
        input_ids = kwargs["input_ids"]
        batch_size = len(input_ids)

        seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        start_index = 0
        max_len_in_batch = -1

        for i in range(batch_size):
            seq_len = len(input_ids[i])
            seq_lengths[i] = seq_len
            seq_start_indexes[i] = start_index
            start_index += seq_len
            max_len_in_batch = seq_len if seq_len > max_len_in_batch else max_len_in_batch

        if "max_total_token_num" in kwargs.keys():
            max_total_token_num = kwargs["max_total_token_num"]
            self.init_cache_manager(max_total_token_num)

        if "max_new_tokens" in kwargs.keys():
            max_output_len = kwargs["max_new_tokens"]

        if batch_size * (max_len_in_batch + max_output_len) > self.max_total_token_num:
            max_total_token_num = batch_size * (max_len_in_batch + max_output_len)
            warnings.warn(f"reset max tokens to {max_total_token_num}")
            self.init_cache_manager(max_total_token_num)

        block_loc = torch.empty((batch_size, max_len_in_batch + max_output_len), dtype=torch.long, device="cuda")
        batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
        batch_infer_state.seq_len = seq_lengths.to("cuda")
        batch_infer_state.start_loc = seq_start_indexes.to("cuda")
        batch_infer_state.block_loc = block_loc
        batch_infer_state.decode_layer_id = 0
        batch_infer_state.past_key_values_len = 0
        batch_infer_state.is_context_stage = True
        batch_infer_state.set_cache_manager(self.cache_manager)
        batch_infer_state.cache_manager.free_all()
        return batch_infer_state
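Note: the bookkeeping that `init_batch_state` does for a packed (padding-free) batch is easy to trace by hand. The toy example below mirrors the loop above on CPU tensors only, without any cache manager, so the values are illustrative rather than an exercise of the real code path.

import torch

# Two prompts of length 3 and 5 packed back to back.
input_ids = [[1, 2, 3], [4, 5, 6, 7, 8]]
batch_size = len(input_ids)

seq_lengths = torch.zeros(batch_size, dtype=torch.int32)
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32)
start_index = 0
max_len_in_batch = -1
for i, ids in enumerate(input_ids):
    seq_lengths[i] = len(ids)
    seq_start_indexes[i] = start_index
    start_index += len(ids)
    max_len_in_batch = max(max_len_in_batch, len(ids))

print(seq_lengths)        # tensor([3, 5], dtype=torch.int32)
print(seq_start_indexes)  # tensor([0, 3], dtype=torch.int32)
print(max_len_in_batch)   # 5 -> with max_output_len=256 the cache must hold 2 * (5 + 256) token slots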
    @abstractmethod
    @torch.inference_mode()
    def quantize(
        self,
        examples: List[Dict[str, Union[List[int], torch.LongTensor]]],
    ):
        if self.quantized:
            raise EnvironmentError("can't execute quantize because the model is quantized.")

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def generate(self, **kwargs):
        """shortcut for model.generate"""

        batch_infer_state = self.init_batch_state(**kwargs)
        if self.config.model_type == "llama":
            setattr(self.model.model, "infer_state", batch_infer_state)

        with torch.inference_mode():
            return self.model.generate(**kwargs)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """shortcut for model.prepare_inputs_for_generation"""
        return self.model.prepare_inputs_for_generation(*args, **kwargs)

    def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512):
        for text in tqdm(dataset):
            input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
            model(input_ids)

    def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512):
        pbar = tqdm(dataset)
        for text in pbar:
            input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
            model(input_ids)
            mean_scale = np.mean([v["input"] for v in act_dict.values()])
            pbar.set_description(f"Mean input scale: {mean_scale:.2f}")

    def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
        model.eval()
        device = next(model.parameters()).device
        act_scales = {}

        def stat_tensor(name, tensor):
            hidden_dim = tensor.shape[-1]
            tensor = tensor.view(-1, hidden_dim).abs().detach()
            coming_max = torch.max(tensor, dim=0)[0].float().cpu()
            if name in act_scales:
                act_scales[name] = torch.max(act_scales[name], coming_max)
            else:
                act_scales[name] = coming_max

        def stat_input_hook(m, x, y, name):
            if isinstance(x, tuple):
                x = x[0]
            stat_tensor(name, x)

        hooks = []
        for name, m in model.named_modules():
            if isinstance(m, nn.Linear):
                hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name)))

        self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len)

        for h in hooks:
            h.remove()

        return act_scales
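Note: the calibration pass above only relies on standard PyTorch forward hooks, so its behaviour can be exercised on a toy module with no quantization kernels at all. A minimal sketch (toy shapes, random inputs standing in for a calibration dataset):

import torch
import torch.nn as nn
from functools import partial

act_scales = {}

def stat_input_hook(m, x, y, name):
    # x is the tuple of positional inputs to the module
    x = x[0] if isinstance(x, tuple) else x
    cur_max = x.detach().abs().view(-1, x.shape[-1]).max(dim=0)[0].float().cpu()
    act_scales[name] = torch.max(act_scales[name], cur_max) if name in act_scales else cur_max

toy = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
hooks = [
    m.register_forward_hook(partial(stat_input_hook, name=n))
    for n, m in toy.named_modules()
    if isinstance(m, nn.Linear)
]

for _ in range(4):                      # stand-in for iterating over the calibration dataset
    toy(torch.randn(2, 8))

for h in hooks:
    h.remove()

print({k: v.shape for k, v in act_scales.items()})  # per-channel input maxima, e.g. {'0': (8,), '2': (16,)}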
    @torch.no_grad()
    def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
        if not isinstance(fcs, list):
            fcs = [fcs]
        for fc in fcs:
            assert isinstance(fc, nn.Linear)
            assert ln.weight.numel() == fc.in_features == act_scales.numel()

        device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
        act_scales = act_scales.to(device=device, dtype=dtype)
        weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
        weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)

        scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)

        ln.weight.div_(scales)
        if hasattr(ln, "bias"):
            ln.bias.div_(scales)

        for fc in fcs:
            fc.weight.mul_(scales.view(1, -1))
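Note: the smoothing step is mathematically a no-op on the float model: dividing the preceding norm weight by `scales` and multiplying the following linear weight by the same `scales` leaves the composed output unchanged, while shifting activation outliers into the weights so both become easier to quantize. A small self-check with plain tensors (alpha = 0.5, the norm modelled as an elementwise scale, no quantization involved):

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
ln_weight = torch.rand(8) + 0.5          # stand-in for an RMS/LayerNorm scale vector
fc_weight = torch.randn(16, 8)           # out_features x in_features

act_scales = x.abs().max(dim=0)[0]                           # per-channel activation maxima
weight_scales = fc_weight.abs().max(dim=0)[0].clamp(min=1e-5)
scales = (act_scales.pow(0.5) / weight_scales.pow(0.5)).clamp(min=1e-5)

before = (x * ln_weight) @ fc_weight.t()
# smooth_ln_fcs divides the norm weight and folds the factor into the linear weight:
after = (x * (ln_weight / scales)) @ (fc_weight * scales.view(1, -1)).t()
print(torch.allclose(before, after, atol=1e-4))  # True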
    @classmethod
    def create_quantized_model(model):
        raise NotImplementedError("create_quantized_model method is not implemented")

    def save_quantized(
        self,
        save_dir: str,
        model_basename: str,
        use_safetensors: bool = False,
        safetensors_metadata: Optional[Dict[str, str]] = None,
    ):
        """save quantized model and configs to local disk"""
        os.makedirs(save_dir, exist_ok=True)

        if not self.quantized:
            raise EnvironmentError("can only save quantized model, please execute .quantize first.")

        self.model.to("cpu")

        model_base_name = model_basename  # or f"smooth-"
        if use_safetensors:
            model_save_name = model_base_name + ".safetensors"
            state_dict = self.model.state_dict()
            state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
            if safetensors_metadata is None:
                safetensors_metadata = {}
            elif not isinstance(safetensors_metadata, dict):
                raise TypeError("safetensors_metadata must be a dictionary.")
            else:
                print(f"Received safetensors_metadata: {safetensors_metadata}")
                new_safetensors_metadata = {}
                converted_keys = False
                for key, value in safetensors_metadata.items():
                    if not isinstance(key, str) or not isinstance(value, str):
                        converted_keys = True
                        try:
                            new_key = str(key)
                            new_value = str(value)
                        except Exception as e:
                            raise TypeError(
                                f"safetensors_metadata: both keys and values must be strings and an error occurred when trying to convert them: {e}"
                            )
                        if new_key in new_safetensors_metadata:
                            print(
                                f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
                            )
                        new_safetensors_metadata[new_key] = new_value
                safetensors_metadata = new_safetensors_metadata
                if converted_keys:
                    print(
                        f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
                    )

            # Format is required to enable Accelerate to load the metadata
            # otherwise it raises an OSError
            safetensors_metadata["format"] = "pt"

            safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
        else:
            model_save_name = model_base_name + ".bin"
            torch.save(self.model.state_dict(), join(save_dir, model_save_name))

        self.model.config.save_pretrained(save_dir)

    def save_pretrained(
        self,
        save_dir: str,
        use_safetensors: bool = False,
        safetensors_metadata: Optional[Dict[str, str]] = None,
        **kwargs,
    ):
        """alias of save_quantized"""
        warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.")
        self.save_quantized(save_dir, use_safetensors, safetensors_metadata)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        max_memory: Optional[dict] = None,
        trust_remote_code: bool = False,
        torch_dtype: torch.dtype = torch.float16,
        **model_init_kwargs,
    ):
        if not torch.cuda.is_available():
            raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.")

        def skip(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip

        # Parameters related to loading from Hugging Face Hub
        cache_dir = model_init_kwargs.pop("cache_dir", None)
        force_download = model_init_kwargs.pop("force_download", False)
        resume_download = model_init_kwargs.pop("resume_download", False)
        proxies = model_init_kwargs.pop("proxies", None)
        local_files_only = model_init_kwargs.pop("local_files_only", False)
        use_auth_token = model_init_kwargs.pop("use_auth_token", None)
        revision = model_init_kwargs.pop("revision", None)
        subfolder = model_init_kwargs.pop("subfolder", "")
        model_init_kwargs.pop("_commit_hash", None)

        cached_file_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "resume_download": resume_download,
            "local_files_only": local_files_only,
            "use_auth_token": use_auth_token,
            "revision": revision,
            "subfolder": subfolder,
        }

        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs)
        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")

        # enforce some values despite user specified
        model_init_kwargs["torch_dtype"] = torch_dtype
        model_init_kwargs["trust_remote_code"] = trust_remote_code
        if max_memory:
            if "disk" in max_memory:
                raise NotImplementedError("disk offload not support yet.")
            with accelerate.init_empty_weights():
                model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
                model.tie_weights()

            max_memory = accelerate.utils.get_balanced_memory(
                model,
                max_memory=max_memory,
                no_split_module_classes=[cls.layer_type],
                dtype=model_init_kwargs["torch_dtype"],
                low_zero=False,
            )
            model_init_kwargs["device_map"] = accelerate.infer_auto_device_map(
                model,
                max_memory=max_memory,
                no_split_module_classes=[cls.layer_type],
                dtype=model_init_kwargs["torch_dtype"],
            )
            model_init_kwargs["low_cpu_mem_usage"] = True

            del model
        else:
            model_init_kwargs["device_map"] = None
            model_init_kwargs["low_cpu_mem_usage"] = False

        torch.cuda.empty_cache()

        merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)

        model_config = model.config.to_dict()
        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
        if any([k in model_config for k in seq_len_keys]):
            for key in seq_len_keys:
                if key in model_config:
                    model.seqlen = model_config[key]
                    break
        else:
            warnings.warn("can't get model's sequence length from model config, will set to 4096.")
            model.seqlen = 4096
        model.eval()

        return cls(model, False)

    @classmethod
    def from_quantized(
        cls,
        model_name_or_path: Optional[str],
        model_basename: Optional[str] = None,
        device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
        max_memory: Optional[dict] = None,
        device: Optional[Union[str, int]] = None,
        low_cpu_mem_usage: bool = False,
        torch_dtype: Optional[torch.dtype] = None,
        use_safetensors: bool = False,
        trust_remote_code: bool = False,
        **kwargs,
    ):
        """load quantized model from local disk"""

        # Parameters related to loading from Hugging Face Hub
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", "")
        commit_hash = kwargs.pop("_commit_hash", None)

        cached_file_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "resume_download": resume_download,
            "local_files_only": local_files_only,
            "use_auth_token": use_auth_token,
            "revision": revision,
            "subfolder": subfolder,
            "_raise_exceptions_for_missing_entries": False,
            "_commit_hash": commit_hash,
        }

        # == step1: prepare configs and file names == #
        config = AutoConfig.from_pretrained(
            model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs
        )

        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")

        extensions = []
        if use_safetensors:
            extensions.append(".safetensors")
        else:
            extensions += [".bin", ".pt"]

        model_name_or_path = str(model_name_or_path)
        is_local = isdir(model_name_or_path)

        resolved_archive_file = None
        if is_local:
            model_save_name = join(model_name_or_path, model_basename)
            for ext in extensions:
                if isfile(model_save_name + ext):
                    resolved_archive_file = model_save_name + ext
                    break
        else:  # remote
            for ext in extensions:
                resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs)
                if resolved_archive_file is not None:
                    break

        if resolved_archive_file is None:  # Could not find a model file to use
            raise FileNotFoundError(f"Could not find model in {model_name_or_path}")

        model_save_name = resolved_archive_file

        # == step2: convert model to quantized-model (replace Linear) == #
        def skip(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip

        transformers.modeling_utils._init_weights = False

        init_contexts = [no_init_weights()]
        if low_cpu_mem_usage:
            init_contexts.append(accelerate.init_empty_weights(include_buffers=True))

        with ContextManagers(init_contexts):
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype
            )
            cls.create_quantized_model(model)
            model.tie_weights()

        # == step3: load checkpoint to quantized-model == #
        accelerate.utils.modeling.load_checkpoint_in_model(
            model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True
        )

        # == step4: set seqlen == #
        model_config = model.config.to_dict()
        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
        if any([k in model_config for k in seq_len_keys]):
            for key in seq_len_keys:
                if key in model_config:
                    model.seqlen = model_config[key]
                    break
        else:
            warnings.warn("can't get model's sequence length from model config, will set to 4096.")
            model.seqlen = 4096

        return cls(
            model,
            True,
        )

    def __getattr__(self, item):
        try:
            return super().__getattr__(item)
        except:
            return getattr(self.model, item)


__all__ = ["BaseSmoothForCausalLM"]
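Note: for orientation, a rough usage sketch of the API defined above. `SmoothLlamaForCausalLM` stands in for the concrete llama subclass added elsewhere in this PR; the class name, import path, checkpoint paths and calibration text are all placeholders/assumptions, not part of this hunk.

# Hypothetical driver script; names marked as assumed are not confirmed by this diff.
import torch
from transformers import LlamaTokenizer

from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM  # assumed path/name

model_path = "/path/to/llama"  # placeholder
tokenizer = LlamaTokenizer.from_pretrained(model_path)

# 1) load fp16 weights, calibrate + quantize, then persist the int8 checkpoint
model = SmoothLlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
examples = [{"input_ids": tokenizer("Some calibration text.", return_tensors="pt").input_ids}]  # caller-prepared
model.quantize(examples)
model.save_quantized("./llama-sq", model_basename="llama-7b-smoothquant")

# 2) later: reload the quantized checkpoint and run generation through the int8 kernels
model = SmoothLlamaForCausalLM.from_quantized("./llama-sq", model_basename="llama-7b-smoothquant").cuda()
input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids.cuda()
out = model.generate(input_ids=input_ids, max_new_tokens=32)
print(tokenizer.decode(out[0]))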

@@ -0,0 +1,177 @@
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py

import torch
from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
from torch_int.functional.quantization import quantize_per_tensor_absmax

try:
    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

    smoothquant_cuda = SmoothquantBuilder().load()
    HAS_SMOOTHQUANT_CUDA = True
except ImportError:
    HAS_SMOOTHQUANT_CUDA = False
    raise ImportError("CUDA smoothquant linear is not installed")


class W8A8BFP32O32LinearSiLU(torch.nn.Module):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer(
            "weight",
            torch.randint(
                -127,
                127,
                (self.out_features, self.in_features),
                dtype=torch.int8,
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "bias",
            torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False),
        )
        self.register_buffer("a", torch.tensor(alpha))

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        self.bias = self.bias.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        x_shape = x.shape
        x = x.view(-1, x_shape[-1])
        y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0)
        y = y.view(*x_shape[:-1], -1)
        return y

    @staticmethod
    def from_float(module: torch.nn.Linear, input_scale):
        int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features)
        int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
        alpha = input_scale * weight_scale
        int8_module.weight = int8_weight
        if module.bias is not None:
            int8_module.bias.data.copy_(module.bias.to(torch.float))
        int8_module.a = alpha
        return int8_module


class W8A8B8O8Linear(torch.nn.Module):
    # For qkv_proj
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer(
            "weight",
            torch.randint(
                -127,
                127,
                (self.out_features, self.in_features),
                dtype=torch.int8,
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "bias",
            torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False),
        )
        self.register_buffer("a", torch.tensor(alpha))
        self.register_buffer("b", torch.tensor(beta))

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        self.bias = self.bias.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        x_shape = x.shape
        x = x.view(-1, x_shape[-1])
        y = linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item())
        y = y.view(*x_shape[:-1], -1)
        return y

    @staticmethod
    def from_float(module: torch.nn.Linear, input_scale, output_scale):
        int8_module = W8A8B8O8Linear(module.in_features, module.out_features)
        int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
        alpha = input_scale * weight_scale / output_scale
        int8_module.weight = int8_weight
        int8_module.a = alpha

        if module.bias is not None:
            int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias)
            int8_module.bias = int8_bias
            beta = bias_scale / output_scale
            int8_module.b = beta

        return int8_module
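Note: the `alpha` and `beta` computed in `from_float` are simply the ratios that keep the int8 GEMM consistent with the float one: the accumulator dequantizes as `acc * s_x * s_w`, and requantizing onto the int8 output grid divides by `s_y`, hence `alpha = s_x * s_w / s_y`. The CPU-only check below reproduces that bookkeeping with plain tensors and a reference quantizer (no torch_int CUDA kernels), so it is illustrative rather than bit-exact.

import torch

def quantize_per_tensor_absmax_ref(t):
    # reference per-tensor symmetric quantization: scale = max|t| / 127
    scale = t.abs().max() / 127
    return (t / scale).round().clamp(-128, 127).to(torch.int8), scale

torch.manual_seed(0)
x_fp = torch.randn(4, 8)
w_fp = torch.randn(16, 8)

x_int8, s_x = quantize_per_tensor_absmax_ref(x_fp)
w_int8, s_w = quantize_per_tensor_absmax_ref(w_fp)
s_y = 0.05                                   # output scale; chosen from calibration in the real flow

alpha = s_x * s_w / s_y                      # same formula as W8A8B8O8Linear.from_float
acc_int32 = x_int8.to(torch.int32) @ w_int8.to(torch.int32).t()
y_int8 = (acc_int32.float() * alpha).round().clamp(-128, 127).to(torch.int8)

y_ref = x_fp @ w_fp.t()
print((y_int8.float() * s_y - y_ref).abs().max())  # small quantization error, not exact equality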
class W8A8BFP32OFP32Linear(torch.nn.Module):
    # For fc2 and out_proj
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer(
            "weight",
            torch.randint(
                -127,
                127,
                (self.out_features, self.in_features),
                dtype=torch.int8,
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "bias",
            torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
        )
        self.register_buffer("a", torch.tensor(alpha))

    def _apply(self, fn):
        # prevent the bias from being converted to half
        super()._apply(fn)
        self.bias = self.bias.to(torch.float32)
        return self

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        self.bias = self.bias.to(*args, **kwargs)
        self.bias = self.bias.to(torch.float32)
        return self

    @torch.no_grad()
    def forward(self, x):
        x_shape = x.shape
        x = x.view(-1, x_shape[-1])
        y = linear_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1)
        y = y.view(*x_shape[:-1], -1)
        return y

    @staticmethod
    def from_float(module: torch.nn.Linear, input_scale):
        int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features)
        int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
        alpha = input_scale * weight_scale
        int8_module.weight = int8_weight
        int8_module.a = alpha
        int8_module.input_scale = input_scale
        int8_module.weight_scale = weight_scale

        if module.bias is not None:
            int8_module.bias = module.bias.to(torch.float32)

        return int8_module

@@ -0,0 +1,846 @@
import math
import os
import types
from collections import defaultdict
from functools import partial
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LLAMA_INPUTS_DOCSTRING,
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaRotaryEmbedding,
    repeat_kv,
    rotate_half,
)
from transformers.utils import add_start_docstrings_to_model_forward

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import (
    copy_kv_cache_to_dest,
    int8_rotary_embedding_fwd,
    smooth_llama_context_attn_fwd,
    smooth_token_attention_fwd,
)

from .base_model import BaseSmoothForCausalLM
from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear


class LLamaSmoothquantAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {num_heads})."
            )

        self.qk_bmm = BMM_S8T_S8N_F32T(1.0)
        self.pv_bmm = BMM_S8T_S8N_S8T(1.0)

        self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size)
        self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size)
        self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size)
        self.o_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size)

        self.register_buffer("q_output_scale", torch.tensor([1.0]))
        self.register_buffer("k_output_scale", torch.tensor([1.0]))
        self.register_buffer("v_output_scale", torch.tensor([1.0]))
        self.register_buffer("q_rotary_output_scale", torch.tensor([1.0]))
        self.register_buffer("k_rotary_output_scale", torch.tensor([1.0]))
        self.register_buffer("out_input_scale", torch.tensor([1.0]))
        self.register_buffer("attn_input_scale", torch.tensor([1.0]))

        self._init_rope()
        self.num_key_value_heads = num_heads

    def _init_rope(self):
        self.rotary_emb = LlamaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=2048,
            base=10000.0,
        )

    @staticmethod
    def pack(
        module: LlamaAttention,
        attn_input_scale: float,
        q_output_scale: float,
        k_output_scale: float,
        v_output_scale: float,
        q_rotary_output_scale: float,
        k_rotary_output_scale: float,
        out_input_scale: float,
    ):
        int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads)

        int8_module.attn_input_scale = torch.tensor([attn_input_scale])

        int8_module.q_output_scale = torch.tensor([q_output_scale])
        int8_module.k_output_scale = torch.tensor([k_output_scale])
        int8_module.v_output_scale = torch.tensor([v_output_scale])

        int8_module.q_rotary_output_scale = torch.tensor([q_rotary_output_scale])
        int8_module.k_rotary_output_scale = torch.tensor([k_rotary_output_scale])

        int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, attn_input_scale, q_output_scale)
        int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale)
        int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale)
        int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale)

        int8_module.out_input_scale = torch.tensor([out_input_scale])

        return int8_module

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    @torch.no_grad()
    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_emb: Tuple[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        padding_mask: Optional[torch.LongTensor] = None,
        infer_state: Optional[BatchInferState] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        cos = rotary_emb[0]
        sin = rotary_emb[1]

        int8_rotary_embedding_fwd(
            query_states.view(-1, self.num_heads, self.head_dim),
            cos,
            sin,
            self.q_output_scale.item(),
            self.q_rotary_output_scale.item(),
        )
        int8_rotary_embedding_fwd(
            key_states.view(-1, self.num_heads, self.head_dim),
            cos,
            sin,
            self.k_output_scale.item(),
            self.k_rotary_output_scale.item(),
        )

        # NOTE might want to revise
        # need some way to record the length of past key values cache
        # since we won't return past_key_value_cache right now
        if infer_state.decode_layer_id == 0:  # once per model.forward
            infer_state.cache_manager.past_key_values_length += q_len  # seq_len

        def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
            copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
            copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
            return

        query_states = query_states.view(-1, self.num_heads, self.head_dim)
        key_states = key_states.view(-1, self.num_heads, self.head_dim)
        value_states = value_states.view(-1, self.num_heads, self.head_dim)

        if infer_state.is_context_stage:
            # first token generation

            # copy key and value calculated in current step to memory manager
            _copy_kv_to_mem_cache(
                infer_state.decode_layer_id,
                key_states,
                value_states,
                infer_state.context_mem_index,
                infer_state.cache_manager,
            )

            attn_output = torch.empty_like(query_states)

            smooth_llama_context_attn_fwd(
                query_states,
                key_states,
                value_states,
                attn_output,
                self.q_rotary_output_scale.item(),
                self.k_rotary_output_scale.item(),
                self.v_output_scale.item(),
                self.out_input_scale.item(),
                infer_state.start_loc,
                infer_state.seq_len,
                q_len,
            )

        else:
            if infer_state.decode_is_contiguous:
                # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
                cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
                    infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
                ]
                cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
                    infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
                ]
                cache_k.copy_(key_states)
                cache_v.copy_(value_states)
            else:
                # if decode is not contiguous, use triton kernel to copy key and value cache
                # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
                _copy_kv_to_mem_cache(
                    infer_state.decode_layer_id,
                    key_states,
                    value_states,
                    infer_state.decode_mem_index,
                    infer_state.cache_manager,
                )

            # (batch_size, seqlen, nheads, headdim)
            attn_output = torch.empty_like(query_states)

            smooth_token_attention_fwd(
                query_states,
                infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
                infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
                attn_output,
                self.q_rotary_output_scale.item(),
                self.k_rotary_output_scale.item(),
                self.v_output_scale.item(),
                self.out_input_scale.item(),
                infer_state.block_loc,
                infer_state.start_loc,
                infer_state.seq_len,
                infer_state.cache_manager.past_key_values_length,
            )

        attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, None


class LlamaLayerNormQ(torch.nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.input_scale = 1.0
        self.variance_epsilon = eps
        self.register_buffer("weight", torch.ones(dim, dtype=torch.float32))

    def forward(self, x):
        ln_output_fp = torch.nn.functional.layer_norm(x, x.shape[-1:], self.weight, None, self.variance_epsilon)
        ln_output_int8 = ln_output_fp.round().clamp(-128, 127).to(torch.int8)
        return ln_output_int8

    @staticmethod
    def from_float(module: torch.nn.LayerNorm, output_scale: float):
        assert module.weight.shape[0] == module.weight.numel()
        q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon)
        q_module.weight = module.weight / output_scale
        return q_module
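Note: because `from_float` divides the norm weight by `output_scale`, the normalized output already sits on the int8 grid of the following projection and `forward` only needs a round-and-clamp. A quick illustration with plain tensors; it reuses the same `F.layer_norm` call as the class above, and the toy shapes and scale value are assumptions for the sake of the example.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, output_scale = 8, 0.02
x = torch.randn(2, 3, dim)
fp_weight = torch.rand(dim) + 0.5

# float reference, then explicit quantization onto the next layer's input grid
ref = F.layer_norm(x, (dim,), fp_weight, None, 1e-5)
ref_int8 = (ref / output_scale).round().clamp(-128, 127).to(torch.int8)

# LlamaLayerNormQ folds 1/output_scale into the weight, so forward() rounds directly
folded_weight = fp_weight / output_scale
q = F.layer_norm(x, (dim,), folded_weight, None, 1e-5).round().clamp(-128, 127).to(torch.int8)

print(torch.equal(ref_int8, q))  # True up to float rounding ties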
class LlamaSmoothquantMLP(nn.Module):
    def __init__(self, intermediate_size, hidden_size):
        super().__init__()
        self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size)
        self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size)
        self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size)
        self.register_buffer("down_proj_input_scale", torch.tensor([1.0]))

    @staticmethod
    def pack(
        mlp_module: LlamaMLP,
        gate_proj_input_scale: float,
        up_proj_input_scale: float,
        down_proj_input_scale: float,
    ):
        int8_module = LlamaSmoothquantMLP(
            mlp_module.intermediate_size,
            mlp_module.hidden_size,
        )

        int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale)
        int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale)
        int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale)
        int8_module.down_proj_input_scale = torch.tensor([down_proj_input_scale])
        return int8_module

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        x_shape = hidden_states.shape
        gate_out = self.gate_proj(hidden_states)
        up_out = self.up_proj(hidden_states)
        inter_out = gate_out * up_out
        inter_out = inter_out.div_(self.down_proj_input_scale.item()).round().clamp(-128, 127).to(torch.int8)
        down_out = self.down_proj(inter_out)
        down_out = down_out.view(*x_shape[:-1], -1)
        return down_out
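Note: the only extra step compared with the float LLaMA MLP is the explicit requantization between the elementwise `silu(gate) * up` product (fp32 outputs of the two linears above) and the int8 `down_proj`. Dividing by `down_proj_input_scale` plus round/clamp is exactly per-tensor symmetric quantization; a toy version with random tensors and a scale picked the way a calibration pass would:

import torch

torch.manual_seed(0)
gate_out = torch.nn.functional.silu(torch.randn(2, 4, 16))  # fp32, as produced by the fused SiLU linear
up_out = torch.randn(2, 4, 16)
inter_fp = gate_out * up_out

down_proj_input_scale = inter_fp.abs().max() / 127           # illustrative choice of the calibrated scale
inter_int8 = inter_fp.div(down_proj_input_scale).round().clamp(-128, 127).to(torch.int8)

# dequantizing recovers the fp32 product up to quantization error
print((inter_int8.float() * down_proj_input_scale - inter_fp).abs().max())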
class LlamaSmoothquantDecoderLayer(nn.Module):
|
||||||
|
def __init__(self, config: LlamaConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.self_attn = LLamaSmoothquantAttention(config.hidden_size, config.num_attention_heads)
|
||||||
|
|
||||||
|
self.mlp = LlamaSmoothquantMLP(config.intermediate_size, config.hidden_size)
|
||||||
|
self.input_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
self.post_attention_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def pack(
|
||||||
|
module: LlamaDecoderLayer,
|
||||||
|
attn_input_scale: float,
|
||||||
|
q_output_scale: float,
|
||||||
|
k_output_scale: float,
|
||||||
|
v_output_scale: float,
|
||||||
|
q_rotary_output_scale: float,
|
||||||
|
k_rotary_output_scale: float,
|
||||||
|
out_input_scale: float,
|
||||||
|
gate_input_scale: float,
|
||||||
|
up_input_scale: float,
|
||||||
|
down_input_scale: float,
|
||||||
|
):
|
||||||
|
config = module.self_attn.config
|
||||||
|
int8_decoder_layer = LlamaSmoothquantDecoderLayer(config)
|
||||||
|
|
||||||
|
int8_decoder_layer.input_layernorm = LlamaLayerNormQ.from_float(module.input_layernorm, attn_input_scale)
|
||||||
|
int8_decoder_layer.self_attn = LLamaSmoothquantAttention.pack(
|
||||||
|
module.self_attn,
|
||||||
|
attn_input_scale,
|
||||||
|
q_output_scale,
|
||||||
|
k_output_scale,
|
||||||
|
v_output_scale,
|
||||||
|
q_rotary_output_scale,
|
||||||
|
k_rotary_output_scale,
|
||||||
|
out_input_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float(
|
||||||
|
module.post_attention_layernorm, gate_input_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
int8_decoder_layer.mlp = LlamaSmoothquantMLP.pack(
|
||||||
|
module.mlp,
|
||||||
|
gate_input_scale,
|
||||||
|
up_input_scale,
|
||||||
|
down_input_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
return int8_decoder_layer
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
rotary_emb: Tuple[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None,
|
||||||
|
infer_state: Optional[BatchInferState] = None,
|
||||||
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||||
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
|
(see `past_key_values`).
|
||||||
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||||
|
"""
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
rotary_emb=rotary_emb,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
infer_state=infer_state,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states, None, None
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaApplyRotary(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x, cos, sin, position_ids):
|
||||||
|
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||||
|
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||||
|
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||||
|
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||||
|
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||||
|
x_embed = (x * cos) + (rotate_half(x) * sin)
|
||||||
|
|
||||||
|
return x_embed
|
||||||
|
|
||||||
|
|
||||||
|
def llama_decoder_layer_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
if self.config.pretraining_tp > 1:
|
||||||
|
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||||
|
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
|
||||||
|
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
|
||||||
|
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||||
|
query_states = torch.cat(query_states, dim=-1)
|
||||||
|
|
||||||
|
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||||
|
key_states = torch.cat(key_states, dim=-1)
|
||||||
|
|
||||||
|
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||||
|
value_states = torch.cat(value_states, dim=-1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states = self.q_apply_rotary(query_states, cos, sin, position_ids)
|
||||||
|
key_states = self.k_apply_rotary(key_states, cos, sin, position_ids)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
if self.config.pretraining_tp > 1:
|
||||||
|
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||||
|
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||||
|
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
||||||
|
else:
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def init_to_get_rotary(config, base=10000, use_elem=False):
    """
    This function initializes the rotary positional embedding. It is compatible with all models and is called in ShardFormer.
    Args:
        base: base used to compute the inverse frequencies
        use_elem: activated when using chatglm-based models
    """
    config.head_dim_ = config.hidden_size // config.num_attention_heads
    if not hasattr(config, "rope_scaling"):
        rope_scaling_factor = 1.0
    else:
        rope_scaling_factor = config.rope_scaling.factor if config.rope_scaling is not None else 1.0

    if hasattr(config, "max_sequence_length"):
        max_seq_len = config.max_sequence_length
    elif hasattr(config, "max_position_embeddings"):
        max_seq_len = config.max_position_embeddings * rope_scaling_factor
    else:
        max_seq_len = 2048 * rope_scaling_factor
    base = float(base)

    # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    try:
        ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1))
        assert ntk_alpha >= 1
        if ntk_alpha > 1:
            print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
            max_seq_len *= ntk_alpha
            base = base * (ntk_alpha ** (config.head_dim_ / (config.head_dim_ - 2)))  # Base change formula
    except (ValueError, AssertionError):
        pass

    n_elem = config.head_dim_
    if use_elem:
        n_elem //= 2

    inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
    freqs = torch.outer(t, inv_freq)

    _cos_cached = torch.cos(freqs).to(torch.float)
    _sin_cached = torch.sin(freqs).to(torch.float)
    return _cos_cached, _sin_cached

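# --- Usage sketch (illustrative only, not part of the patch) ----------------
# A minimal example of calling init_to_get_rotary; the LlamaConfig values are
# assumptions picked for illustration. The returned tables are what the model
# later registers as the `_cos_cached` / `_sin_cached` buffers.
#
#     from transformers import LlamaConfig
#
#     cfg = LlamaConfig(hidden_size=4096, num_attention_heads=32, max_position_embeddings=4096)
#     cos, sin = init_to_get_rotary(cfg, base=10000)
#     # both tables have shape (num_positions, head_dim // 2)
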
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def llama_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    seq_length_with_past = seq_length
    past_key_values_length = 0

    infer_state = self.infer_state

    if past_key_values is not None:
        # NOT READY FOR PRIME TIME
        # dummy but works; revise it later
        past_key_values_length = infer_state.cache_manager.past_key_values_length
        # past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length

    # NOTE: differentiate from the prefill (context) stage
    # block_loc requires a different value-assigning method for the two stages
    if infer_state.is_context_stage:
        infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
        infer_state.init_block_loc(
            infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
        )
    else:
        alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
        if alloc_mem is not None:
            infer_state.decode_is_contiguous = True
            infer_state.decode_mem_index = alloc_mem[0]
            infer_state.decode_mem_start = alloc_mem[1]
            infer_state.decode_mem_end = alloc_mem[2]
            infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
        else:
            print(" *** Encountered non-contiguous allocation")
            print(
                f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
            )
            infer_state.decode_is_contiguous = False
            alloc_mem = infer_state.cache_manager.alloc(batch_size)
            infer_state.decode_mem_index = alloc_mem
            infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # embed positions
    if attention_mask is None:
        attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device)
        padding_mask = None
    else:
        if 0 in attention_mask:
            padding_mask = attention_mask
        else:
            padding_mask = None

    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )

    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        raise NotImplementedError("gradient_checkpointing and training options are not implemented")

    if past_key_values_length == 0:
        position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
            position_ids.view(-1).shape[0], -1
        )
        position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
            position_ids.view(-1).shape[0], -1
        )
    else:
        position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
        position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None
    infer_state.decode_layer_id = 0
    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        layer_outputs = decoder_layer(
            hidden_states,
            rotary_emb=(position_cos, position_sin),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            padding_mask=padding_mask,
            infer_state=infer_state,
        )

        hidden_states = layer_outputs[0]
        infer_state.decode_layer_id += 1

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    infer_state.is_context_stage = False
    infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
    infer_state.seq_len += 1

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


class SmoothLlamaForCausalLM(BaseSmoothForCausalLM):
    layer_type = "LlamaDecoderLayer"

    def __init__(self, model: PreTrainedModel, quantized: bool = False):
        super().__init__(model, quantized)

    def get_act_dict(
        self,
        tokenizer,
        dataset,
        num_samples=512,
        seq_len=512,
    ):
        llama_model = self.model

        llama_model.eval()
        device = next(llama_model.parameters()).device
        # print("model:", llama_model)
        act_dict = defaultdict(dict)

        def stat_io_hook(m, x, y, name):
            if isinstance(x, tuple):
                x = x[0]
            if name not in act_dict or "input" not in act_dict[name]:
                act_dict[name]["input"] = x.detach().abs().max().item()
            else:
                act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item())
            if isinstance(y, tuple):
                y = y[0]
            if name not in act_dict or "output" not in act_dict[name]:
                act_dict[name]["output"] = y.detach().abs().max().item()
            else:
                act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item())

        for name, m in llama_model.named_modules():
            if isinstance(m, LlamaAttention):
                setattr(m, "q_apply_rotary", LlamaApplyRotary())
                setattr(m, "k_apply_rotary", LlamaApplyRotary())
                m.forward = types.MethodType(llama_decoder_layer_forward, m)

        hooks = []
        for name, m in llama_model.named_modules():
            if isinstance(m, LlamaApplyRotary):
                hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name)))
            if isinstance(m, torch.nn.Linear):
                hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name)))

        self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len)

        for hook in hooks:
            hook.remove()
        return act_dict

    def smooth_fn(self, scales, alpha=0.5):
        model = self.model
        for name, module in model.named_modules():
            if isinstance(module, LlamaDecoderLayer):
                attn_ln = module.input_layernorm
                qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj]
                qkv_input_scales = scales[name + ".self_attn.q_proj"]
                self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)

    def create_quantized_model(model):
        llama_config = model.config
        for i, layer in enumerate(model.model.layers):
            model.model.layers[i] = LlamaSmoothquantDecoderLayer(llama_config)

        model.model.forward = types.MethodType(llama_model_forward, model.model)
        cos, sin = init_to_get_rotary(llama_config)
        model.model.register_buffer("_cos_cached", cos)
        model.model.register_buffer("_sin_cached", sin)

    def quantized(
        self,
        tokenizer,
        dataset,
        num_samples=512,
        seq_len=512,
        alpha=0.5,
    ):
        llama_model = self.model
        llama_config = llama_model.config

        act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len)

        self.smooth_fn(act_scales, alpha)

        act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len)
        decoder_layer_scales = []

        for idx in range(llama_config.num_hidden_layers):
            scale_dict = {}
            scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127
            scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127
            scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127
            scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127

            scale_dict["q_rotary_output_scale"] = (
                act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127
            )
            scale_dict["k_rotary_output_scale"] = (
                act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127
            )

            scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127

            scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127
            scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127
            scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127

            decoder_layer_scales.append(scale_dict)

        for i, layer in enumerate(llama_model.model.layers):
            orig_layer = layer
            llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i])

        llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model)

        cos, sin = init_to_get_rotary(llama_config)
        llama_model.model.register_buffer("_cos_cached", cos.to(self.model.device))
        llama_model.model.register_buffer("_sin_cached", sin.to(self.model.device))
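Side note on the `/ 127` terms in `quantized()` above: they convert the calibrated absolute-max activation values into per-tensor INT8 scales. A small, self-contained PyTorch illustration of the quantize/dequantize round trip these scales imply (not part of the patch):

import torch

x = torch.randn(4, 8) * 3.0              # fp activations observed during calibration
scale = x.abs().max() / 127              # same rule as attn_input_scale etc. above
x_int8 = torch.clamp((x / scale).round(), -127, 127).to(torch.int8)
x_hat = x_int8.float() * scale           # dequantized approximation of x
print((x - x_hat).abs().max())           # quantization error stays within about scale / 2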
@@ -0,0 +1,8 @@
#include <torch/extension.h>

#include "linear.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32,
        "Linear SiLU (INT8)");
}
@@ -0,0 +1,162 @@
// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu

#include "linear.h"
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>

#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/epilogue/thread/linear_combination_silu.h>
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <iostream>
#include <torch/torch.h>

torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input,  // INT8
                                            torch::Tensor weight, // INT8
                                            torch::Tensor bias,   // FP32
                                            float alpha,          // FP32
                                            float beta            // FP32
) {
  auto M = input.size(0);
  auto N = weight.size(0);
  auto K = input.size(1);

  using ElementOutput = float;
  using ElementAccumulator = int32_t;
  using ElementComputeEpilogue = float;
  using ElementInputA = int8_t; // <- data type of elements in input matrix A
  using ElementInputB = int8_t; // <- data type of elements in input matrix B

  // The code section below describes the matrix layout of input and output
  // matrices: Row Major for matrix A, Column Major for matrix B and Row Major
  // for matrix C
  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputB = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;

#if CUDA_ARCH >= 800
  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
      ElementOutput, // <- data type of output matrix
      128 / cutlass::sizeof_bits<
                ElementOutput>::value, // <- this is the number of elements per
                                       // vectorized memory access. For half
                                       // precision, it's 8 elements. This
                                       // becomes the vector width of math
                                       // instructions in epilogue too
      ElementAccumulator,              // <- data type of accumulator
      ElementComputeEpilogue // <- data type for alpha in linear combination
                             // function
      >;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
#elif CUDA_ARCH >= 750
  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
      ElementOutput, // <- data type of output matrix
      128 / cutlass::sizeof_bits<
                ElementOutput>::value, // <- this is the number of elements per
                                       // vectorized memory access. For half
                                       // precision, it's 8 elements. This
                                       // becomes the vector width of math
                                       // instructions in epilogue too
      ElementAccumulator,              // <- data type of accumulator
      ElementComputeEpilogue // <- data type for alpha in linear combination
                             // function
      >;

  using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
      ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
      DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
      DefaultGemmCfg::InstructionShape,
      EpilogueOp>;
#elif CUDA_ARCH >= 700
#define USE_TORCH_SILU
  using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
      cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
      ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
      DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
      DefaultGemmCfg::InstructionShape,
      cutlass::epilogue::thread::LinearCombination<
          ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>;
#else
#error "Unsupported cuda arch"
#endif

  auto input_size = cutlass::MatrixCoord(M, K);
  auto weight_size = cutlass::MatrixCoord(K, N);
  auto output_size = cutlass::MatrixCoord(M, N);

  auto device = input.device();
  // use the broadcasted bias as the output
  auto out = bias.to(device).view({1, -1}).repeat({M, 1});

  // constexpr int kSparse = Gemm::kSparse;
  // How many elements of A are covered per ElementE
  // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
  // The size of individual meta data
  // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
  cutlass::gemm::GemmCoord problem_size(M, N, K);

  cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
      input.data_ptr<ElementInputA>(), LayoutInputA::packed(input_size));
  cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
      weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
  cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
      out.data_ptr<ElementOutput>(), LayoutOutput::packed(output_size));

  typename Gemm::Arguments arguments{
      problem_size, // <- problem size of matrix multiplication
      input_ref,    // <- reference to matrix A on device
      weight_ref,   // <- reference to matrix B on device
      out_ref,      // <- reference to matrix C on device
      out_ref,      // <- reference to matrix D on device
      {alpha, beta}, 1};
  Gemm gemm_op;

  // Using the arguments, query for extra workspace required for matrix
  // multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check the problem size is supported or not
  cutlass::Status status = gemm_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm_op.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize");
  }

  status = gemm_op();
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot run");
  }
#ifdef USE_TORCH_SILU
#undef USE_TORCH_SILU
  out = torch::silu(out);
#endif
  return out;
}
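For orientation, a rough sketch of driving the fused kernel above from Python once the extension is compiled. The `SmoothquantBuilder` is the op builder added later in this patch; the import path and the `load()` call are assumptions about the surrounding build tooling, and the shapes are purely illustrative:

import torch
from op_builder import SmoothquantBuilder  # hypothetical import path

smoothquant_cuda = SmoothquantBuilder().load()  # JIT-compiles binding.cpp + linear.cu

M, K, N = 4, 4096, 11008
x = torch.randint(-10, 10, (M, K), dtype=torch.int8, device="cuda")  # INT8 activations, row-major
w = torch.randint(-10, 10, (N, K), dtype=torch.int8, device="cuda")  # INT8 weights, one output channel per row
b = torch.zeros(N, dtype=torch.float32, device="cuda")               # FP32 bias, broadcast into the output buffer
alpha, beta = 1e-3, 1.0  # alpha rescales the INT32 accumulator, beta scales the bias
y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, w, b, alpha, beta)
print(y.shape)  # (M, N), FP32, with SiLU fused into the epilogue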
@@ -0,0 +1,12 @@
#include <torch/torch.h>
#include <torch/types.h>

#include <cstdint>
#include <iostream>

torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input,  // INT8
                                            torch::Tensor weight, // INT8
                                            torch::Tensor bias,   // FP32
                                            float alpha,          // FP32
                                            float beta            // FP32
);
@@ -13,8 +13,10 @@ if HAS_TRITON:
    from .copy_kv_cache_dest import copy_kv_cache_to_dest
    from .fused_layernorm import layer_norm
    from .gptq_triton import gptq_fused_linear_triton
    from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd
    from .rms_norm import rmsnorm_forward
    from .rotary_embedding_kernel import rotary_embedding_fwd
    from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd
    from .softmax import softmax
    from .token_attention_kernel import token_attention_fwd

@@ -29,4 +31,7 @@ if HAS_TRITON:
        "rotary_embedding_fwd",
        "token_attention_fwd",
        "gptq_fused_linear_triton",
        "int8_rotary_embedding_fwd",
        "smooth_llama_context_attn_fwd",
        "smooth_token_attention_fwd",
    ]
@@ -0,0 +1,117 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import torch
import triton
import triton.language as tl


@triton.jit
def _rotary_kernel(
    q,
    input_scale,
    output_scale,
    Cos,
    Sin,
    q_bs_stride,
    q_h_stride,
    q_d_stride,
    cos_bs_stride,
    cos_d_stride,
    total_len,
    HEAD_NUM: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
    BLOCK_SEQ: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    current_head_index = tl.program_id(0)
    current_seq_index = tl.program_id(1)

    dim_range0 = tl.arange(0, HEAD_DIM // 2)
    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)

    current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
    current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)

    off_q0 = (
        current_seq_range[:, None, None] * q_bs_stride
        + current_head_range[None, :, None] * q_h_stride
        + dim_range0[None, None, :] * q_d_stride
    )
    off_q1 = (
        current_seq_range[:, None, None] * q_bs_stride
        + current_head_range[None, :, None] * q_h_stride
        + dim_range1[None, None, :] * q_d_stride
    )

    off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride

    q0 = tl.load(
        q + off_q0,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
        other=0.0,
    )
    q1 = tl.load(
        q + off_q1,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
        other=0.0,
    )

    cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
    sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)

    q0 = q0.to(tl.float32) * input_scale
    q1 = q1.to(tl.float32) * input_scale

    out0 = (q0 * cos - q1 * sin) / output_scale
    out1 = (q0 * sin + q1 * cos) / output_scale

    out0 = out0.to(tl.int8)
    out1 = out1.to(tl.int8)

    tl.store(
        q + off_q0,
        out0,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
    )
    tl.store(
        q + off_q1,
        out1,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
    )

    return


@torch.no_grad()
def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale):
    total_len = q.shape[0]
    head_num = q.shape[1]
    head_dim = q.shape[2]
    assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
    BLOCK_HEAD = 4
    BLOCK_SEQ = 32
    grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
    if head_dim >= 128:
        num_warps = 8
    else:
        num_warps = 4

    _rotary_kernel[grid](
        q,
        input_scale,
        output_scale,
        cos,
        sin,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        cos.stride(0),
        cos.stride(1),
        total_len,
        HEAD_NUM=head_num,
        BLOCK_HEAD=BLOCK_HEAD,
        BLOCK_SEQ=BLOCK_SEQ,
        HEAD_DIM=head_dim,
        num_warps=num_warps,
        num_stages=1,
    )
    return
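As a sanity reference, the in-place update performed by `_rotary_kernel` can be written in plain PyTorch as below. This is an illustration inferred from the kernel body, not shipped code; the Triton kernel truncates on the final int8 cast, so results may differ by one LSB:

import torch

def int8_rotary_reference(q_int8, cos, sin, input_scale, output_scale):
    # q_int8: (total_len, head_num, head_dim) int8; cos/sin: (total_len, head_dim // 2) float
    half = q_int8.shape[-1] // 2
    q0 = q_int8[..., :half].float() * input_scale
    q1 = q_int8[..., half:].float() * input_scale
    cos_, sin_ = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over the head dimension
    out0 = (q0 * cos_ - q1 * sin_) / output_scale
    out1 = (q0 * sin_ + q1 * cos_) / output_scale
    return torch.cat([out0, out1], dim=-1).round().clamp(-128, 127).to(torch.int8)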
@@ -0,0 +1,652 @@
import math

import torch

try:
    import triton
    import triton.language as tl

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
    """
    this function is modified from
    https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
    """

    @triton.jit
    def _context_flash_attention_kernel(
        Q,
        K,
        V,
        q_input_scale,
        k_input_scale,
        v_input_scale,
        pv_output_scale,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        TMP,
        alibi_ptr,
        Out,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_tmp_b,
        stride_tmp_h,
        stride_tmp_s,
        # suggested values: 64, 128, 256, 512
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        batch_id = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)

        # get batch info
        cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
        cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
        block_start_loc = BLOCK_M * start_m

        load_p_ptrs = (
            Q
            + (cur_batch_start_index + offs_m[:, None]) * stride_qbs
            + cur_head * stride_qh
            + offs_d[None, :] * stride_qd
        )
        q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
        q = q.to(tl.float16) * q_input_scale.to(tl.float16)

        k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
        v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
        t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s

        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        if alibi_ptr is not None:
            alibi_m = tl.load(alibi_ptr + cur_head)

        block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            k = tl.load(
                k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
                mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
                other=0.0,
            )
            k = k.to(tl.float16) * k_input_scale.to(tl.float16)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale

            if alibi_ptr is not None:
                alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
                qk -= alibi_loc * alibi_m

            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))

            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            tl.store(t_ptrs, acc_scale)
            acc_scale = tl.load(t_ptrs)
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(
                v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
                mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
                other=0.0,
            )

            v = v.to(tl.float16) * v_input_scale.to(tl.float16)
            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new
        acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8)
        off_o = (
            (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
        )
        out_ptrs = Out + off_o
        tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
        return

    @torch.no_grad()
    def smooth_llama_context_attn_fwd(
        q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len
    ):
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk, "context process only supports equal query, key, value length"
        assert Lk == Lv, "context process only supports equal query, key, value length"
        assert Lk in {16, 32, 64, 128}
        BLOCK_N = 128
        sm_scale = 1.0 / math.sqrt(Lk)
        batch, head = b_seq_len.shape[0], q.shape[1]
        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))

        tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8

        _context_flash_attention_kernel[grid](
            q,
            k,
            v,
            q_input_scale,
            k_input_scale,
            v_input_scale,
            pv_output_scale,
            sm_scale,
            b_start_loc,
            b_seq_len,
            tmp,
            None,
            o,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            tmp.stride(0),
            tmp.stride(1),
            tmp.stride(2),
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_N=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
        return

    @triton.jit
    def _token_attn_1_kernel(
        Q,
        K,
        q_input_scale,
        k_input_scale,
        sm_scale,
        kv_cache_loc,
        kv_cache_start_loc,
        kv_cache_seqlen,
        max_kv_cache_len,
        attn_out,
        kv_cache_loc_b_stride,
        kv_cache_loc_s_stride,
        q_batch_stride,
        q_head_stride,
        q_head_dim_stride,
        k_batch_stride,
        k_head_stride,
        k_head_dim_stride,
        attn_head_stride,
        attn_batch_stride,
        HEAD_DIM: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        current_batch = tl.program_id(0)
        current_head = tl.program_id(1)
        start_n = tl.program_id(2)

        offs_d = tl.arange(0, HEAD_DIM)
        current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
        current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)

        current_batch_start_index = max_kv_cache_len - current_batch_seq_len
        current_batch_end_index = max_kv_cache_len

        off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride

        offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

        block_start_index = start_n * BLOCK_N
        block_mask = tl.where(block_start_index < current_batch_seq_len, 1, 0)

        for start_mark in range(0, block_mask, 1):
            q = tl.load(Q + off_q + start_mark)
            q = q.to(tl.float16) * q_input_scale.to(tl.float16)
            offs_n_new = current_batch_start_index + offs_n
            k_loc = tl.load(
                kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
                mask=offs_n_new < current_batch_end_index,
                other=0,
            )
            off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
            k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
            k = k.to(tl.float16) * k_input_scale.to(tl.float16)
            att_value = tl.sum(q[None, :] * k, 1)
            att_value *= sm_scale
            off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
            tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
        return

    @triton.jit
    def _token_attn_1_alibi_kernel(
        Q,
        K,
        q_input_scale,
        k_input_scale,
        sm_scale,
        alibi,
        kv_cache_loc,
        kv_cache_start_loc,
        kv_cache_seqlen,
        max_kv_cache_len,
        attn_out,
        kv_cache_loc_b_stride,
        kv_cache_loc_s_stride,
        q_batch_stride,
        q_head_stride,
        q_head_dim_stride,
        k_batch_stride,
        k_head_stride,
        k_head_dim_stride,
        attn_head_stride,
        attn_batch_stride,
        HEAD_DIM: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        current_batch = tl.program_id(0)
        current_head = tl.program_id(1)
        start_n = tl.program_id(2)

        offs_d = tl.arange(0, HEAD_DIM)
        current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
        current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)

        current_batch_start_index = max_kv_cache_len - current_batch_seq_len
        current_batch_end_index = max_kv_cache_len

        off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride

        offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

        block_start_index = start_n * BLOCK_N
        block_mask = tl.where(block_start_index < current_batch_seq_len, 1, 0)

        for start_mark in range(0, block_mask, 1):
            alibi_m = tl.load(alibi + current_head)
            q = tl.load(Q + off_q + start_mark)
            q = q.to(tl.float16) * q_input_scale.to(tl.float16)

            offs_n_new = current_batch_start_index + offs_n
            k_loc = tl.load(
                kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
                mask=offs_n_new < current_batch_end_index,
                other=0,
            )
            off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
            k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
            k = k.to(tl.float16) * k_input_scale.to(tl.float16)
            att_value = tl.sum(q[None, :] * k, 1)
            att_value *= sm_scale
            att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n)
            off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
            tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
        return

    @torch.no_grad()
    def token_attn_fwd_1(
        q,
        k,
        attn_out,
        q_input_scale,
        k_input_scale,
        kv_cache_loc,
        kv_cache_start_loc,
        kv_cache_seqlen,
        max_kv_cache_len,
        alibi=None,
    ):
        BLOCK = 32
        # shape constraints
        q_head_dim, k_head_dim = q.shape[-1], k.shape[-1]
        assert q_head_dim == k_head_dim
        assert k_head_dim in {16, 32, 64, 128}
        sm_scale = 1.0 / (k_head_dim**0.5)

        batch, head_num = kv_cache_loc.shape[0], q.shape[1]

        grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK))

        num_warps = 4 if k_head_dim <= 64 else 8
        num_warps = 2

        if alibi is not None:
            _token_attn_1_alibi_kernel[grid](
                q,
                k,
                q_input_scale,
                k_input_scale,
                sm_scale,
                alibi,
                kv_cache_loc,
                kv_cache_start_loc,
                kv_cache_seqlen,
                max_kv_cache_len,
                attn_out,
                kv_cache_loc.stride(0),
                kv_cache_loc.stride(1),
                q.stride(0),
                q.stride(1),
                q.stride(2),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                attn_out.stride(0),
                attn_out.stride(1),
                HEAD_DIM=k_head_dim,
                BLOCK_N=BLOCK,
                num_warps=num_warps,
                num_stages=1,
            )
        else:
            _token_attn_1_kernel[grid](
                q,
                k,
                q_input_scale,
                k_input_scale,
                sm_scale,
                kv_cache_loc,
                kv_cache_start_loc,
                kv_cache_seqlen,
                max_kv_cache_len,
                attn_out,
                kv_cache_loc.stride(0),
                kv_cache_loc.stride(1),
                q.stride(0),
                q.stride(1),
                q.stride(2),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                attn_out.stride(0),
                attn_out.stride(1),
                HEAD_DIM=k_head_dim,
                BLOCK_N=BLOCK,
                num_warps=num_warps,
                num_stages=1,
            )
        return

    @triton.jit
    def _token_attn_softmax_fwd(
        softmax_logics,
        kv_cache_start_loc,
        kv_cache_seqlen,
        softmax_prob_out,
        logics_head_dim_stride,
        logics_batch_stride,
        prob_head_dim_stride,
        prob_batch_stride,
        BLOCK_SIZE: tl.constexpr,
    ):
        current_batch = tl.program_id(0)
        current_head = tl.program_id(1)

        col_offsets = tl.arange(0, BLOCK_SIZE)
        current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
        current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)

        row = tl.load(
            softmax_logics
            + current_head * logics_head_dim_stride
            + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
            mask=col_offsets < current_batch_seq_len,
            other=-float("inf"),
        ).to(tl.float32)

        row_minus_max = row - tl.max(row, axis=0)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator

        tl.store(
            softmax_prob_out
            + current_head * prob_head_dim_stride
            + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
            softmax_output,
            mask=col_offsets < current_batch_seq_len,
        )
        return

    @torch.no_grad()
    def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len):
        BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len)
        batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0]

        num_warps = 4
        if BLOCK_SIZE >= 2048:
            num_warps = 8
        if BLOCK_SIZE >= 4096:
            num_warps = 16

        _token_attn_softmax_fwd[(batch, head_num)](
            softmax_logics,
            kv_cache_start_loc,
            kv_cache_seqlen,
            softmax_prob_out,
            softmax_logics.stride(0),
            softmax_logics.stride(1),
            softmax_prob_out.stride(0),
            softmax_prob_out.stride(1),
            num_warps=num_warps,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return

    @triton.jit
    def _token_attn_2_kernel(
        Prob,
        V,
        attn_out,
        v_input_scale,
        pv_output_scale,
        kv_cache_loc,
        kv_cache_start_loc,
        kv_cache_seqlen,
        max_kv_cache_len,
        kv_cache_loc_b_stride,
        kv_cache_loc_s_stride,
        prob_head_dim_stride,
        prob_batch_stride,
        v_batch_stride,
        v_head_stride,
        v_head_dim_stride,
        attn_out_batch_stride,
        attn_out_head_stride,
        attn_out_head_dim_stride,
        HEAD_DIM: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        current_batch = tl.program_id(0)
        current_head = tl.program_id(1)

        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, HEAD_DIM)
        current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
        current_batch_start_index = max_kv_cache_len - current_batch_seq_len
        current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)

        v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride
        p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride
        v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride

        acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
        for start_n in range(0, current_batch_seq_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            p_value = tl.load(
                Prob + p_offs + start_n * kv_cache_loc_s_stride,
                mask=(start_n + offs_n) < current_batch_seq_len,
                other=0.0,
            )
            v_loc = tl.load(
                kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
                mask=(start_n + offs_n) < current_batch_seq_len,
                other=0.0,
            )
            v_value = tl.load(
                V + v_offs + v_loc[:, None] * v_batch_stride,
                mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
                other=0.0,
            )
            v_value = v_value.to(tl.float16) * v_input_scale.to(tl.float16)
            acc += tl.sum(p_value[:, None] * v_value, 0)

        acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8)
        off_o = (
            current_batch * attn_out_batch_stride
            + current_head * attn_out_head_stride
            + offs_d * attn_out_head_dim_stride
        )
        out_ptrs = attn_out + off_o
        tl.store(out_ptrs, acc)
        return

    @torch.no_grad()
    def token_attn_fwd_2(
        prob,
        v,
        attn_out,
        v_input_scale,
        pv_output_scale,
        kv_cache_loc,
        kv_cache_start_loc,
        kv_cache_seqlen,
        max_kv_cache_len,
    ):
        if triton.__version__ >= "2.1.0":
            BLOCK = 128
        else:
            BLOCK = 64
        batch, head = kv_cache_loc.shape[0], v.shape[1]
        grid = (batch, head)
        num_warps = 4
        dim = v.shape[-1]

        _token_attn_2_kernel[grid](
            prob,
            v,
            attn_out,
            v_input_scale,
            pv_output_scale,
            kv_cache_loc,
            kv_cache_start_loc,
            kv_cache_seqlen,
            max_kv_cache_len,
            kv_cache_loc.stride(0),
            kv_cache_loc.stride(1),
            prob.stride(0),
            prob.stride(1),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            attn_out.stride(0),
            attn_out.stride(1),
            attn_out.stride(2),
            HEAD_DIM=dim,
            BLOCK_N=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
        return

    @torch.no_grad()
    def smooth_token_attention_fwd(
        q,
        k,
        v,
        attn_out,
        q_input_scale,
        k_input_scale,
        v_input_scale,
        pv_output_scale,
        kv_cache_loc,
        kv_cache_start_loc,
        kv_cache_seq_len,
        max_len_in_batch,
        alibi=None,
    ):
        head_num = k.shape[1]
        batch_size = kv_cache_seq_len.shape[0]
        calcu_shape1 = (batch_size, head_num, k.shape[2])
        total_token_num = k.shape[0]

        att_m_tensor = torch.empty((head_num, total_token_num), dtype=torch.float32, device="cuda")

        token_attn_fwd_1(
            q.view(calcu_shape1),
            k,
            att_m_tensor,
            q_input_scale,
            k_input_scale,
            kv_cache_loc,
            kv_cache_start_loc,
            kv_cache_seq_len,
            max_len_in_batch,
            alibi=alibi,
        )

        prob = torch.empty_like(att_m_tensor)

        token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
        att_m_tensor = None
        token_attn_fwd_2(
            prob,
            v,
            attn_out.view(calcu_shape1),
            v_input_scale,
            pv_output_scale,
            kv_cache_loc,
            kv_cache_start_loc,
            kv_cache_seq_len,
            max_len_in_batch,
        )

        prob = None

        return
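For intuition, the three token-attention kernels above implement, per decoding step: (1) dequantized q·k logits, (2) a softmax over the cached positions, (3) a dequantized weighted sum over v that is requantized with `pv_output_scale`. A plain PyTorch sketch of that math for one contiguous sequence, ignoring the paged `kv_cache_loc` indirection (an illustration inferred from the kernels, not shipped code):

import torch

def token_attention_reference(q_int8, k_int8, v_int8, q_scale, k_scale, v_scale, pv_scale):
    # q_int8: (head_num, head_dim) for the current token; k_int8 / v_int8: (seq_len, head_num, head_dim)
    q = q_int8.float() * q_scale
    k = k_int8.float() * k_scale
    v = v_int8.float() * v_scale
    head_dim = q.shape[-1]
    logits = torch.einsum("hd,shd->hs", q, k) / head_dim**0.5  # step 1: per-head q.k over the cache
    prob = torch.softmax(logits, dim=-1)                       # step 2: softmax over cached positions
    out = torch.einsum("hs,shd->hd", prob, v)                  # step 3: probability-weighted sum of values
    return (out / pv_scale).round().clamp(-128, 127).to(torch.int8)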
@@ -0,0 +1,69 @@
import argparse
import os

import torch
from datasets import load_dataset
from transformers import LlamaTokenizer

from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM


def build_model_and_tokenizer(model_name):
    tokenizer = LlamaTokenizer.from_pretrained(model_name, model_max_length=512)
    kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"}
    model = SmoothLlamaForCausalLM.from_pretrained(model_name, **kwargs)
    model = model.to(torch.float32)
    return model, tokenizer


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, help="model name")
    parser.add_argument(
        "--output-path",
        type=str,
        help="where to save the checkpoint",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        help="location of the calibration dataset",
    )
    parser.add_argument("--num-samples", type=int, default=512)
    parser.add_argument("--seq-len", type=int, default=512)
    args = parser.parse_args()
    return args


@torch.no_grad()
def main():
    args = parse_args()
    model_path = args.model_name
    dataset_path = args.dataset_path
    output_path = args.output_path
    num_samples = 10
    seq_len = 512

    model, tokenizer = build_model_and_tokenizer(model_path)
    if not os.path.exists(dataset_path):
        raise FileNotFoundError(f"Cannot find the dataset at {args.dataset_path}")
    dataset = load_dataset("json", data_files=dataset_path, split="train")

    model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len)
    model = model.cuda()

    model.save_quantized(output_path, model_basename="llama-7b")

    model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b")
    model = model.cuda()

    generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True)
    input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda")
    out = model.generate(**input_tokens, **generate_kwargs)
    text = tokenizer.batch_decode(out)
    print("out is:", text)


if __name__ == "__main__":
    main()
@@ -0,0 +1,52 @@
import warnings

import torch

from .builder import Builder
from .utils import append_nvcc_threads


class SmoothquantBuilder(Builder):
    NAME = "cu_smoothquant"
    PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant"

    def __init__(self):
        super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH)

    def include_dirs(self):
        ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()]
        return ret

    def sources_files(self):
        ret = [
            self.csrc_abs_path(fname)
            for fname in [
                "smoothquant/binding.cpp",
                "smoothquant/linear.cu",
            ]
        ]
        return ret

    def cxx_flags(self):
        return ["-O3"] + self.version_dependent_macros

    def nvcc_flags(self):
        compute_capability = torch.cuda.get_device_capability()
        cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10

        extra_cuda_flags = [
            "-v",
            f"-DCUDA_ARCH={cuda_arch}",
            "-std=c++17",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_HALF2_OPERATORS__",
            "-DTHRUST_IGNORE_CUB_VERSION_CHECK",
        ]

        ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
        return append_nvcc_threads(ret)

    def builder(self):
        try:
            return super().builder()
        except Exception:
            warnings.warn("building the smoothquant lib was not successful")
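A minimal sketch of using the builder to JIT-compile and load the kernels; it assumes the usual ColossalAI `Builder.load()` workflow, and the exact import path may differ in your tree:

# Hypothetical usage; adjust the import to wherever SmoothquantBuilder is exposed.
from colossalai.kernel.op_builder import SmoothquantBuilder

smoothquant_cuda = SmoothquantBuilder().load()  # builds binding.cpp + linear.cu on first use
# Afterwards the fused op is available as:
#   smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(input_int8, weight_int8, bias_fp32, alpha, beta)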
@@ -0,0 +1,136 @@
import math

import pytest
import torch
from packaging import version
from torch.nn import functional as F

try:
    from colossalai.kernel.triton import int8_rotary_embedding_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

try:
    from colossalai.inference.quant.smoothquant.models import LLamaSmoothquantAttention

    HAS_TORCH_INT = True
except ImportError:
    HAS_TORCH_INT = False
    print("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")


TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
    """
    Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
    """
    xq = xq.view(bs, seqlen, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    xv = xv.view(bs, seqlen, num_head, head_dim)
    mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
    mask[mask == 0.0] = -100000000.0
    mask = mask.repeat(bs, num_head, 1, 1)
    keys = xk
    values = xv
    xq = xq.transpose(1, 2)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
    scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq)
    output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)

    return output


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_TORCH_INT,
    reason="requires triton with CUDA version higher than 11.4 and torch_int",
)
def test_llama_context_attention():
    head_num = 2
    seq_len = 32
    head_dim = 64
    dtype = torch.float
    hidden_size = head_num * head_dim

    smooth_attn = LLamaSmoothquantAttention(head_num * head_dim, head_num)

    smooth_attn.q_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
    smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
    smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
    smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
    smooth_attn.out_proj.weight[:, 1:hidden_size] = torch.zeros(hidden_size - 1, device="cuda").to(torch.int8)

    qkv_weight_scale = 1.0

    ones = torch.ones(hidden_size, hidden_size, dtype=torch.float, device="cuda")

    smooth_attn = smooth_attn.to("cuda")

    input = torch.randint(-20, 20, (1, seq_len, head_num * head_dim), dtype=torch.int8, device="cuda")
    input_scale = 1 / 20.0

    output = torch.matmul(input.to(torch.float) * input_scale, ones)
    qkv_max_out = torch.max(torch.abs(output)) / 127
    smooth_attn.q_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
    smooth_attn.k_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
    smooth_attn.v_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)

    q = smooth_attn.q_proj(input)
    k = smooth_attn.k_proj(input)
    v = smooth_attn.v_proj(input)

    cos_shape = (seq_len, head_dim // 2)
    cos = torch.ones(cos_shape, dtype=dtype, device="cuda")
    sin = torch.zeros(cos_shape, dtype=dtype, device="cuda")
    in_scale = torch.tensor([qkv_max_out], device="cuda")
    out_scale = torch.tensor([qkv_max_out], device="cuda")
    int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item())
    int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item())

    q = q.to(torch.float) * out_scale
    k = k.to(torch.float) * out_scale
    v = v.to(torch.float) * out_scale
    torch_out = torch_context_attention(q.clone(), k.clone(), v.clone(), 1, seq_len, head_num, head_dim)
    attn_out_max = torch.max(torch.abs(torch_out)) / 127

    output = torch.matmul(torch_out.view(-1, seq_len, head_num * head_dim), ones)
    smooth_attn.q_output_scale = torch.tensor(qkv_max_out)
    smooth_attn.k_output_scale = torch.tensor(qkv_max_out)

    smooth_attn.v_output_scale = torch.tensor(qkv_max_out)
    smooth_attn.q_rotary_output_scale = torch.tensor(qkv_max_out)
    smooth_attn.k_rotary_output_scale = torch.tensor(qkv_max_out)

    smooth_attn.attn_output_scale = torch.tensor(attn_out_max)
    smooth_attn.out_proj.a = torch.tensor([attn_out_max])

    torch_out = (
        (torch_out / smooth_attn.attn_output_scale)
        .round()
        .clamp(-128, 127)
        .to(torch.int8)
        .view(-1, seq_len, head_num * head_dim)
    )

    torch_out = smooth_attn.out_proj(torch_out)
    torch_out = torch_out.to(torch.float)

    smooth_attn = smooth_attn.to("cuda")
    smooth_out, _, _ = smooth_attn(input, (cos, sin))
    smooth_out = smooth_out.to(torch.float)

    assert torch.allclose(
        torch_out.cpu(), smooth_out.cpu(), rtol=1e-1, atol=1e-1
    ), "outputs from the quantized attention and torch do not match"


if __name__ == "__main__":
    test_llama_context_attention()
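A note on the scale bookkeeping in the test above: every scale is derived as max(|x|) / 127, i.e. symmetric per-tensor int8 quantization, and results are dequantized by multiplying the int8 values back by that scale. A minimal sketch of that convention; the helper names are illustrative, not part of the module:

# Illustrative sketch of the symmetric per-tensor int8 convention used above.
import torch

def quantize_per_tensor(x: torch.Tensor):
    scale = x.abs().max() / 127  # largest magnitude maps to 127
    q = (x / scale).round().clamp(-128, 127).to(torch.int8)
    return q, scale

def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float) * scale

x = torch.randn(4, 8)
q, scale = quantize_per_tensor(x)
print((dequantize(q, scale) - x).abs().max())  # error bounded by roughly scale / 2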
@@ -0,0 +1,84 @@
import warnings

import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

    smoothquant_cuda = SmoothquantBuilder().load()
    HAS_SMOOTHQUANT_CUDA = True
except Exception:
    warnings.warn("CUDA smoothquant linear is not installed")
    HAS_SMOOTHQUANT_CUDA = False


try:
    from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP

    HAS_TORCH_INT = True
except ImportError:
    HAS_TORCH_INT = False
    warnings.warn("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")


CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_llama_mlp(gate_proj, up_proj, down_proj, x):
    gate_out = torch.mm(x, gate_proj)
    silu = torch.nn.SiLU()
    gate_out = silu(gate_out)
    up_out = torch.mm(x, up_proj)

    o_out = gate_out * up_out

    max_up = torch.max(torch.abs(o_out))
    min_up = torch.min(torch.abs(o_out))

    torch_out = torch.mm(o_out, down_proj)

    return (torch_out, max_up, min_up)


@pytest.mark.skipif(
    not CUDA_SUPPORT or not HAS_SMOOTHQUANT_CUDA or not HAS_TORCH_INT,
    reason="requires the smoothquant CUDA extension and torch_int",
)
def test_llama_mlp():
    hidden_size = 256
    intermediate_size = 512

    smooth_mlp = LlamaSmoothquantMLP(intermediate_size, hidden_size)

    smooth_mlp.gate_proj.weight = torch.ones((intermediate_size, hidden_size), dtype=torch.int8, device="cuda")

    smooth_mlp.up_proj.weight = torch.randint(
        -10, 10, (intermediate_size, hidden_size), dtype=torch.int8, device="cuda"
    )
    smooth_mlp.down_proj.weight = torch.randint(
        -10, 10, (hidden_size, intermediate_size), dtype=torch.int8, device="cuda"
    )

    x = torch.ones((1, 256), dtype=torch.int8, device="cuda")

    torch_out, max_inter, min_inter = torch_llama_mlp(
        smooth_mlp.gate_proj.weight.transpose(0, 1).to(torch.float) / hidden_size,
        smooth_mlp.up_proj.weight.transpose(0, 1).to(torch.float) / 127,
        smooth_mlp.down_proj.weight.transpose(0, 1).to(torch.float) / 127,
        x.to(torch.float),
    )

    smooth_mlp.down_proj_input_scale = torch.tensor(max_inter.item() / 127)
    smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size)
    smooth_mlp.up_proj.a = torch.tensor(1 / 127)
    smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127))

    smooth_out = smooth_mlp(x)

    assert torch.allclose(torch_out, smooth_out, rtol=1e-02, atol=1e-01)


if __name__ == "__main__":
    test_llama_mlp()
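The scale assignments above (gate_proj.a = 1/hidden_size, up_proj.a = 1/127, down_proj.a = 1/127 * max_inter/127) compose an activation scale with a weight scale, which is the standard rescaling rule for int8 GEMM: the fp32 result is the integer accumulator times the product of the two scales. A small self-contained check of that rule, with illustrative shapes and values only:

# Illustrative check of the scale composition used by the smoothquant layers above.
import torch

x_fp = torch.randn(4, 8)
w_fp = torch.randn(8, 16)

x_scale = x_fp.abs().max() / 127
w_scale = w_fp.abs().max() / 127

x_q = (x_fp / x_scale).round().clamp(-128, 127).to(torch.int8)
w_q = (w_fp / w_scale).round().clamp(-128, 127).to(torch.int8)

# integer accumulation (int64 here for the CPU sketch), then one rescale by x_scale * w_scale
acc = x_q.to(torch.long) @ w_q.to(torch.long)
y = acc.to(torch.float) * (x_scale * w_scale)

print((y - x_fp @ w_fp).abs().max())  # small quantization error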
@@ -0,0 +1,39 @@
import warnings

import pytest
import torch

try:
    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

    smoothquant_cuda = SmoothquantBuilder().load()
    HAS_SMOOTHQUANT_CUDA = True
except Exception:
    warnings.warn("CUDA smoothquant linear is not installed")
    HAS_SMOOTHQUANT_CUDA = False


@pytest.mark.skipif(
    not HAS_SMOOTHQUANT_CUDA,
    reason="the smoothquant CUDA extension is not installed",
)
def test_linear():
    a = torch.randint(-127, 127, (128, 512), dtype=torch.int8, device="cuda")
    b = torch.randint(-127, 127, (512, 256), dtype=torch.int8, device="cuda")
    c = torch.rand(256, dtype=torch.float, device="cuda")

    alpha = 1 / 127
    beta = 1.0
    torch_out = torch.mm(a.to(torch.float) * alpha, b.to(torch.float)) + c

    silu = torch.nn.SiLU()
    torch_out = silu(torch_out)

    b = b.transpose(0, 1).contiguous()
    cuda_out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(a, b, c, alpha, beta)

    assert torch.allclose(torch_out, cuda_out, rtol=1e-02, atol=1e-02)


if __name__ == "__main__":
    test_linear()
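In plain torch terms, the kernel exercised above computes SiLU(alpha * (a @ w^T) + beta * bias) with the int8 weight stored as (out_features, in_features). The test only exercises beta = 1.0, so treating beta as a bias scale is an assumption here; a reference sketch:

# Plain-torch reference for the fused int8 linear + SiLU checked above
# (illustrative; beta is assumed to scale the fp32 bias).
import torch
import torch.nn.functional as F

def linear_silu_a8_w8_bfp32_ofp32_ref(a, w, bias, alpha, beta):
    acc = a.to(torch.float) @ w.to(torch.float).t()  # int8 x int8 accumulation
    return F.silu(alpha * acc + beta * bias)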
@@ -0,0 +1,59 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm


import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.triton import int8_rotary_embedding_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_rotary_emb(x, cos, sin):
    seq_len, h, dim = x.shape
    x0 = x[:, :, 0 : dim // 2]
    x1 = x[:, :, dim // 2 : dim]
    cos = cos.view((seq_len, 1, dim // 2))
    sin = sin.view((seq_len, 1, dim // 2))
    o0 = x0 * cos - x1 * sin
    o1 = x0 * sin + x1 * cos
    return torch.cat((o0, o1), dim=-1)


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_rotary_emb():
    SEQ_LEN = 1
    HEAD_NUM = 32
    HEAD_DIM = 128
    dtype = torch.float
    # create data
    x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
    cos_shape = (SEQ_LEN, HEAD_DIM // 2)
    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    # forward pass
    y_torch = torch_rotary_emb(x, cos, sin)

    input_scale = torch.max(torch.abs(x)) / 127
    output_scale = torch.max(torch.abs(y_torch)) / 127

    x = x / input_scale
    x = x.to(torch.int8)

    int8_rotary_embedding_fwd(x, cos, sin, input_scale.item(), output_scale.item())
    y_triton = x.to(torch.float) * output_scale
    assert torch.allclose(y_triton, y_torch, atol=2e-1, rtol=1e-2, equal_nan=True)


if __name__ == "__main__":
    test_rotary_emb()
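The reference the int8 rotary kernel is checked against can be written end to end as: dequantize with the input scale, apply the half-split rotation (o0 = x0*cos - x1*sin, o1 = x0*sin + x1*cos), requantize with the output scale. A sketch of that composition; the helper name is illustrative and the shapes follow the test (x_q: (seq_len, head_num, head_dim), cos/sin: (seq_len, head_dim // 2)):

# Illustrative end-to-end reference for the in-place int8 rotary kernel tested above.
import torch

def int8_rotary_ref(x_q, cos, sin, in_scale, out_scale):
    x = x_q.to(torch.float) * in_scale                  # dequantize
    d = x.shape[-1]
    x0, x1 = x[..., : d // 2], x[..., d // 2 :]
    o0 = x0 * cos.unsqueeze(1) - x1 * sin.unsqueeze(1)  # rotate first/second half
    o1 = x0 * sin.unsqueeze(1) + x1 * cos.unsqueeze(1)
    o = torch.cat((o0, o1), dim=-1)
    return (o / out_scale).round().clamp(-128, 127).to(torch.int8)  # requantize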