mirror of https://github.com/hpcaitech/ColossalAI
[inference] Add smmoothquant for llama (#4904)
* [inference] add int8 rotary embedding kernel for smoothquant (#4843)
* [inference] add smoothquant llama attention (#4850)
* add smoothquant llama attention
* remove useless code
* remove useless code
* fix import error
* rename file name
* [inference] add silu linear fusion for smoothquant llama mlp (#4853)
* add silu linear
* update skip condition
* catch smoothquant cuda lib exception
* process exception for tests
* [inference] add llama mlp for smoothquant (#4854)
* add llama mlp for smoothquant
* fix down out scale
* remove duplicate lines
* add llama mlp check
* delete useless code
* [inference] add smoothquant llama (#4861)
* add smoothquant llama
* fix attention accuracy
* fix accuracy
* add kv cache and save pretrained
* refactor example
* delete smooth
* refactor code
* [inference] add smooth function and delete useless code for smoothquant (#4895)
* add smooth function and delete useless code
* update datasets
* remove duplicate import
* delete useless file
* refactor codes (#4902)
* refactor code
* add license
* add torch-int and smoothquant license
parent a0684e7bd6
commit 611a5a80ca

LICENSE | 50
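For orientation, a minimal usage sketch of the API this commit introduces (SmoothLlamaForCausalLM with save_quantized / from_quantized / generate, defined in the diffs below). The import path, checkpoint path, and basename are illustrative assumptions, not taken from this commit; the calibration/quantization step is elided because its per-model implementation appears further down.

from transformers import LlamaTokenizer

# Assumed package location of the new smoothquant models; adjust to the actual module path.
from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM

model_path = "path/to/llama"  # placeholder float16 checkpoint
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = SmoothLlamaForCausalLM.from_pretrained(model_path)

# ... calibrate activation scales, smooth, and quantize here (see SmoothLlamaForCausalLM below) ...

model.save_quantized("./llama-sq", model_basename="llama-sq")  # basename is illustrative

q_model = SmoothLlamaForCausalLM.from_quantized("./llama-sq", model_basename="llama-sq")
q_model = q_model.cuda()
q_model.init_cache_manager(max_total_token_num=2048)

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids.cuda()
output = q_model.generate(input_ids=input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0]))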
@@ -477,3 +477,53 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


---------------- LICENSE FOR torch-int ----------------

MIT License

Copyright (c) 2022 Guangxuan Xiao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


---------------- LICENSE FOR smoothquant ----------------

MIT License

Copyright (c) 2022 MIT HAN Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,12 @@
try:
    import torch_int

    HAS_TORCH_INT = True
except ImportError:
    HAS_TORCH_INT = False
    raise ImportError(
        "torch_int is not installed. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
    )

if HAS_TORCH_INT:
    from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
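One observation on the guard above: the except branch re-raises immediately after setting HAS_TORCH_INT = False, so the flag can never be False when the `if HAS_TORCH_INT:` check runs. If torch-int is meant to stay a hard requirement the flag is redundant; if it is meant to be optional, a non-raising variant (a sketch only, not part of this commit) would look like:

try:
    import torch_int  # noqa: F401

    HAS_TORCH_INT = True
except ImportError:
    HAS_TORCH_INT = False
    print("torch_int is not installed. Install it from https://github.com/Guangxuan-Xiao/torch-int")

if HAS_TORCH_INT:
    from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP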
@@ -0,0 +1,482 @@
|
|||
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ
|
||||
# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
|
||||
# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
from os.path import isdir, isfile, join
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import accelerate
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from safetensors.torch import save_file as safe_save
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
|
||||
from transformers.modeling_utils import no_init_weights
|
||||
from transformers.utils.generic import ContextManagers
|
||||
from transformers.utils.hub import PushToHubMixin, cached_file
|
||||
|
||||
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
|
||||
from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager
|
||||
|
||||
SUPPORTED_MODELS = ["llama"]
|
||||
|
||||
|
||||
class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
|
||||
layer_type: str = None
|
||||
|
||||
def __init__(self, model: PreTrainedModel, quantized: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.model = model
|
||||
self.model_type = self.model.config.model_type
|
||||
self._quantized = quantized
|
||||
self.config = self.model.config
|
||||
self.cache_manager = None
|
||||
self.max_total_token_num = 0
|
||||
|
||||
@property
|
||||
def quantized(self):
|
||||
return self._quantized
|
||||
|
||||
def init_cache_manager(self, max_total_token_num=2048):
|
||||
if self.config.model_type == "llama":
|
||||
head_num = self.config.num_key_value_heads
|
||||
layer_num = self.config.num_hidden_layers
|
||||
head_dim = self.config.hidden_size // head_num
|
||||
|
||||
self.cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
|
||||
self.max_total_token_num = max_total_token_num
|
||||
|
||||
def init_batch_state(self, max_output_len=256, **kwargs):
|
||||
input_ids = kwargs["input_ids"]
|
||||
batch_size = len(input_ids)
|
||||
|
||||
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
|
||||
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
|
||||
start_index = 0
|
||||
max_len_in_batch = -1
|
||||
|
||||
for i in range(batch_size):
|
||||
seq_len = len(input_ids[i])
|
||||
seq_lengths[i] = seq_len
|
||||
seq_start_indexes[i] = start_index
|
||||
start_index += seq_len
|
||||
max_len_in_batch = seq_len if seq_len > max_len_in_batch else max_len_in_batch
|
||||
|
||||
if "max_total_token_num" in kwargs.keys():
|
||||
max_total_token_num = kwargs["max_total_token_num"]
|
||||
self.init_cache_manager(max_total_token_num)
|
||||
|
||||
if "max_new_tokens" in kwargs.keys():
|
||||
max_output_len = kwargs["max_new_tokens"]
|
||||
|
||||
if batch_size * (max_len_in_batch + max_output_len) > self.max_total_token_num:
|
||||
max_total_token_num = batch_size * (max_len_in_batch + max_output_len)
|
||||
warnings.warn(f"reset max tokens to {max_total_token_num}")
|
||||
self.init_cache_manager(max_total_token_num)
|
||||
|
||||
block_loc = torch.empty((batch_size, max_len_in_batch + max_output_len), dtype=torch.long, device="cuda")
|
||||
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
|
||||
batch_infer_state.seq_len = seq_lengths.to("cuda")
|
||||
batch_infer_state.start_loc = seq_start_indexes.to("cuda")
|
||||
batch_infer_state.block_loc = block_loc
|
||||
batch_infer_state.decode_layer_id = 0
|
||||
batch_infer_state.past_key_values_len = 0
|
||||
batch_infer_state.is_context_stage = True
|
||||
batch_infer_state.set_cache_manager(self.cache_manager)
|
||||
batch_infer_state.cache_manager.free_all()
|
||||
return batch_infer_state
|
||||
|
||||
@abstractmethod
|
||||
@torch.inference_mode()
|
||||
def quantize(
|
||||
self,
|
||||
examples: List[Dict[str, Union[List[int], torch.LongTensor]]],
|
||||
):
|
||||
if self.quantized:
|
||||
raise EnvironmentError("can't execute quantize because the model is quantized.")
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.model(*args, **kwargs)
|
||||
|
||||
def generate(self, **kwargs):
|
||||
"""shortcut for model.generate"""
|
||||
|
||||
batch_infer_state = self.init_batch_state(**kwargs)
|
||||
if self.config.model_type == "llama":
|
||||
setattr(self.model.model, "infer_state", batch_infer_state)
|
||||
|
||||
with torch.inference_mode():
|
||||
return self.model.generate(**kwargs)
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, **kwargs):
|
||||
"""shortcut for model.prepare_inputs_for_generation"""
|
||||
return self.model.prepare_inputs_for_generation(*args, **kwargs)
|
||||
|
||||
def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512):
|
||||
for text in tqdm(dataset):
|
||||
input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
|
||||
model(input_ids)
|
||||
|
||||
def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512):
|
||||
pbar = tqdm(dataset)
|
||||
for text in pbar:
|
||||
input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
|
||||
model(input_ids)
|
||||
mean_scale = np.mean([v["input"] for v in act_dict.values()])
|
||||
pbar.set_description(f"Mean input scale: {mean_scale:.2f}")
|
||||
|
||||
def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
|
||||
model.eval()
|
||||
device = next(model.parameters()).device
|
||||
act_scales = {}
|
||||
|
||||
def stat_tensor(name, tensor):
|
||||
hidden_dim = tensor.shape[-1]
|
||||
tensor = tensor.view(-1, hidden_dim).abs().detach()
|
||||
coming_max = torch.max(tensor, dim=0)[0].float().cpu()
|
||||
if name in act_scales:
|
||||
act_scales[name] = torch.max(act_scales[name], coming_max)
|
||||
else:
|
||||
act_scales[name] = coming_max
|
||||
|
||||
def stat_input_hook(m, x, y, name):
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
stat_tensor(name, x)
|
||||
|
||||
hooks = []
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name)))
|
||||
|
||||
self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len)
|
||||
|
||||
for h in hooks:
|
||||
h.remove()
|
||||
|
||||
return act_scales
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
|
||||
if not isinstance(fcs, list):
|
||||
fcs = [fcs]
|
||||
for fc in fcs:
|
||||
assert isinstance(fc, nn.Linear)
|
||||
assert ln.weight.numel() == fc.in_features == act_scales.numel()
|
||||
|
||||
device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
|
||||
act_scales = act_scales.to(device=device, dtype=dtype)
|
||||
weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
|
||||
weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
|
||||
|
||||
scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)
|
||||
|
||||
ln.weight.div_(scales)
|
||||
if hasattr(ln, "bias"):
|
||||
ln.bias.div_(scales)
|
||||
|
||||
for fc in fcs:
|
||||
fc.weight.mul_(scales.view(1, -1))
|
||||
|
||||
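The scale computed in smooth_ln_fcs above is the SmoothQuant migration factor s_j = max|X_j|^alpha / max|W_j|^(1 - alpha): the LayerNorm weight (and bias) is divided by s, and the input channels of the following linear layers are multiplied by s, so the layer output is unchanged while activation outliers shrink. A tiny numeric sketch with made-up per-channel maxima:

import torch

act_scales = torch.tensor([8.0, 0.5])     # made-up per-channel activation maxima |X_j|
weight_scales = torch.tensor([0.5, 2.0])  # made-up per-channel weight maxima |W_j|
alpha = 0.5

scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5)
# scales == tensor([4.0, 0.5]): the outlier channel 0 is damped by 4x, channel 1 is amplified by 2x.
# Because ln.weight /= scales and fc.weight *= scales (per input column),
# (x / scales) @ (W * scales).T equals x @ W.T, but the rescaled activations are
# much friendlier to per-tensor int8 quantization.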
@classmethod
|
||||
def create_quantized_model(cls, model):
|
||||
raise NotImplementedError("Not implement create_quantized_model method")
|
||||
|
||||
def save_quantized(
|
||||
self,
|
||||
save_dir: str,
|
||||
model_basename: str,
|
||||
use_safetensors: bool = False,
|
||||
safetensors_metadata: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""save quantized model and configs to local disk"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
if not self.quantized:
|
||||
raise EnvironmentError("can only save quantized model, please execute .quantize first.")
|
||||
|
||||
self.model.to("cpu")
|
||||
|
||||
model_base_name = model_basename # or f"smooth-"
|
||||
if use_safetensors:
|
||||
model_save_name = model_base_name + ".safetensors"
|
||||
state_dict = self.model.state_dict()
|
||||
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
|
||||
if safetensors_metadata is None:
|
||||
safetensors_metadata = {}
|
||||
elif not isinstance(safetensors_metadata, dict):
|
||||
raise TypeError("safetensors_metadata must be a dictionary.")
|
||||
else:
|
||||
print(f"Received safetensors_metadata: {safetensors_metadata}")
|
||||
new_safetensors_metadata = {}
|
||||
converted_keys = False
|
||||
for key, value in safetensors_metadata.items():
|
||||
if not isinstance(key, str) or not isinstance(value, str):
|
||||
converted_keys = True
|
||||
try:
|
||||
new_key = str(key)
|
||||
new_value = str(value)
|
||||
except Exception as e:
|
||||
raise TypeError(
|
||||
f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}"
|
||||
)
|
||||
if new_key in new_safetensors_metadata:
|
||||
print(
|
||||
f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
|
||||
)
|
||||
new_safetensors_metadata[new_key] = new_value
|
||||
safetensors_metadata = new_safetensors_metadata
|
||||
if converted_keys:
|
||||
print(
|
||||
f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
|
||||
)
|
||||
|
||||
# Format is required to enable Accelerate to load the metadata
|
||||
# otherwise it raises an OSError
|
||||
safetensors_metadata["format"] = "pt"
|
||||
|
||||
safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
|
||||
else:
|
||||
model_save_name = model_base_name + ".bin"
|
||||
torch.save(self.model.state_dict(), join(save_dir, model_save_name))
|
||||
|
||||
self.model.config.save_pretrained(save_dir)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_dir: str,
|
||||
use_safetensors: bool = False,
|
||||
safetensors_metadata: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""alias of save_quantized"""
|
||||
warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.")
|
||||
self.save_quantized(save_dir, "model", use_safetensors, safetensors_metadata)  # "model" is an assumed default basename
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
max_memory: Optional[dict] = None,
|
||||
trust_remote_code: bool = False,
|
||||
torch_dtype: torch.dtype = torch.float16,
|
||||
**model_init_kwargs,
|
||||
):
|
||||
if not torch.cuda.is_available():
|
||||
raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.")
|
||||
|
||||
def skip(*args, **kwargs):
|
||||
pass
|
||||
|
||||
torch.nn.init.kaiming_uniform_ = skip
|
||||
torch.nn.init.uniform_ = skip
|
||||
torch.nn.init.normal_ = skip
|
||||
|
||||
# Parameters related to loading from Hugging Face Hub
|
||||
cache_dir = model_init_kwargs.pop("cache_dir", None)
|
||||
force_download = model_init_kwargs.pop("force_download", False)
|
||||
resume_download = model_init_kwargs.pop("resume_download", False)
|
||||
proxies = model_init_kwargs.pop("proxies", None)
|
||||
local_files_only = model_init_kwargs.pop("local_files_only", False)
|
||||
use_auth_token = model_init_kwargs.pop("use_auth_token", None)
|
||||
revision = model_init_kwargs.pop("revision", None)
|
||||
subfolder = model_init_kwargs.pop("subfolder", "")
|
||||
model_init_kwargs.pop("_commit_hash", None)
|
||||
|
||||
cached_file_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"resume_download": resume_download,
|
||||
"local_files_only": local_files_only,
|
||||
"use_auth_token": use_auth_token,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
}
|
||||
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs)
|
||||
if config.model_type not in SUPPORTED_MODELS:
|
||||
raise TypeError(f"{config.model_type} isn't supported yet.")
|
||||
|
||||
# enforce some values despite user specified
|
||||
model_init_kwargs["torch_dtype"] = torch_dtype
|
||||
model_init_kwargs["trust_remote_code"] = trust_remote_code
|
||||
if max_memory:
|
||||
if "disk" in max_memory:
|
||||
raise NotImplementedError("disk offload not support yet.")
|
||||
with accelerate.init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
|
||||
model.tie_weights()
|
||||
|
||||
max_memory = accelerate.utils.get_balanced_memory(
|
||||
model,
|
||||
max_memory=max_memory,
|
||||
no_split_module_classes=[cls.layer_type],
|
||||
dtype=model_init_kwargs["torch_dtype"],
|
||||
low_zero=False,
|
||||
)
|
||||
model_init_kwargs["device_map"] = accelerate.infer_auto_device_map(
|
||||
model,
|
||||
max_memory=max_memory,
|
||||
no_split_module_classes=[cls.layer_type],
|
||||
dtype=model_init_kwargs["torch_dtype"],
|
||||
)
|
||||
model_init_kwargs["low_cpu_mem_usage"] = True
|
||||
|
||||
del model
|
||||
else:
|
||||
model_init_kwargs["device_map"] = None
|
||||
model_init_kwargs["low_cpu_mem_usage"] = False
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
|
||||
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)
|
||||
|
||||
model_config = model.config.to_dict()
|
||||
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
|
||||
if any([k in model_config for k in seq_len_keys]):
|
||||
for key in seq_len_keys:
|
||||
if key in model_config:
|
||||
model.seqlen = model_config[key]
|
||||
break
|
||||
else:
|
||||
warnings.warn("can't get model's sequence length from model config, will set to 4096.")
|
||||
model.seqlen = 4096
|
||||
model.eval()
|
||||
|
||||
return cls(model, False)
|
||||
|
||||
@classmethod
|
||||
def from_quantized(
|
||||
cls,
|
||||
model_name_or_path: Optional[str],
|
||||
model_basename: Optional[str] = None,
|
||||
device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
|
||||
max_memory: Optional[dict] = None,
|
||||
device: Optional[Union[str, int]] = None,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
torch_dtype: Optional[torch.dtype] = None,
|
||||
use_safetensors: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""load quantized model from local disk"""
|
||||
|
||||
# Parameters related to loading from Hugging Face Hub
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
|
||||
cached_file_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"resume_download": resume_download,
|
||||
"local_files_only": local_files_only,
|
||||
"use_auth_token": use_auth_token,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
"_raise_exceptions_for_missing_entries": False,
|
||||
"_commit_hash": commit_hash,
|
||||
}
|
||||
|
||||
# == step1: prepare configs and file names == #
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs
|
||||
)
|
||||
|
||||
if config.model_type not in SUPPORTED_MODELS:
|
||||
raise TypeError(f"{config.model_type} isn't supported yet.")
|
||||
|
||||
extensions = []
|
||||
if use_safetensors:
|
||||
extensions.append(".safetensors")
|
||||
else:
|
||||
extensions += [".bin", ".pt"]
|
||||
|
||||
model_name_or_path = str(model_name_or_path)
|
||||
is_local = isdir(model_name_or_path)
|
||||
|
||||
resolved_archive_file = None
|
||||
if is_local:
|
||||
model_save_name = join(model_name_or_path, model_basename)
|
||||
for ext in extensions:
|
||||
if isfile(model_save_name + ext):
|
||||
resolved_archive_file = model_save_name + ext
|
||||
break
|
||||
else: # remote
|
||||
for ext in extensions:
|
||||
resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs)
|
||||
if resolved_archive_file is not None:
|
||||
break
|
||||
|
||||
if resolved_archive_file is None: # Could not find a model file to use
|
||||
raise FileNotFoundError(f"Could not find model in {model_name_or_path}")
|
||||
|
||||
model_save_name = resolved_archive_file
|
||||
|
||||
# == step2: convert model to quantized-model (replace Linear) == #
|
||||
def skip(*args, **kwargs):
|
||||
pass
|
||||
|
||||
torch.nn.init.kaiming_uniform_ = skip
|
||||
torch.nn.init.uniform_ = skip
|
||||
torch.nn.init.normal_ = skip
|
||||
|
||||
transformers.modeling_utils._init_weights = False
|
||||
|
||||
init_contexts = [no_init_weights()]
|
||||
if low_cpu_mem_usage:
|
||||
init_contexts.append(accelerate.init_empty_weights(include_buffers=True))
|
||||
|
||||
with ContextManagers(init_contexts):
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype
|
||||
)
|
||||
cls.create_quantized_model(model)
|
||||
model.tie_weights()
|
||||
|
||||
# == step3: load checkpoint to quantized-model == #
|
||||
accelerate.utils.modeling.load_checkpoint_in_model(
|
||||
model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True
|
||||
)
|
||||
|
||||
# == step4: set seqlen == #
|
||||
model_config = model.config.to_dict()
|
||||
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
|
||||
if any([k in model_config for k in seq_len_keys]):
|
||||
for key in seq_len_keys:
|
||||
if key in model_config:
|
||||
model.seqlen = model_config[key]
|
||||
break
|
||||
else:
|
||||
warnings.warn("can't get model's sequence length from model config, will set to 4096.")
|
||||
model.seqlen = 4096
|
||||
|
||||
return cls(
|
||||
model,
|
||||
True,
|
||||
)
|
||||
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return super().__getattr__(item)
|
||||
except AttributeError:
|
||||
return getattr(self.model, item)
|
||||
|
||||
|
||||
__all__ = ["BaseSmoothForCausalLM"]
|
|
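A short calibration sketch for the hook-based helpers above: get_act_scales registers a forward hook on every nn.Linear and records the per-input-channel activation maxima that smooth_ln_fcs later consumes. The tokenizer, checkpoint path, and calibration texts below are placeholders:

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("path/to/llama")          # placeholder path
calibration_texts = [
    "Colossal-AI provides tools for large-scale model inference.",   # placeholder samples
    "SmoothQuant migrates activation outliers into the weights.",
]

smooth_lm = SmoothLlamaForCausalLM.from_pretrained("path/to/llama")  # subclass defined in llama.py below
act_scales = smooth_lm.get_act_scales(smooth_lm.model, tokenizer, calibration_texts)
# act_scales maps every nn.Linear name to its per-input-channel max |activation|,
# which smooth_ln_fcs then uses to rescale LayerNorm / Linear pairs.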
@@ -0,0 +1,177 @@
|
|||
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
|
||||
|
||||
import torch
|
||||
from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
|
||||
from torch_int.functional.quantization import quantize_per_tensor_absmax
|
||||
|
||||
try:
|
||||
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
|
||||
|
||||
smoothquant_cuda = SmoothquantBuilder().load()
|
||||
HAS_SMOOTHQUANT_CUDA = True
|
||||
except ImportError:
|
||||
HAS_SMOOTHQUANT_CUDA = False
|
||||
raise ImportError("CUDA smoothquant linear is not installed")
|
||||
|
||||
|
||||
class W8A8BFP32O32LinearSiLU(torch.nn.Module):
|
||||
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
self.register_buffer(
|
||||
"weight",
|
||||
torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(self.out_features, self.in_features),
|
||||
dtype=torch.int8,
|
||||
requires_grad=False,
|
||||
),
|
||||
)
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False),
|
||||
)
|
||||
self.register_buffer("a", torch.tensor(alpha))
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
super().to(*args, **kwargs)
|
||||
self.weight = self.weight.to(*args, **kwargs)
|
||||
self.bias = self.bias.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x):
|
||||
x_shape = x.shape
|
||||
x = x.view(-1, x_shape[-1])
|
||||
y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0)
|
||||
y = y.view(*x_shape[:-1], -1)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def from_float(module: torch.nn.Linear, input_scale):
|
||||
int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features)
|
||||
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
|
||||
alpha = input_scale * weight_scale
|
||||
int8_module.weight = int8_weight
|
||||
if module.bias is not None:
|
||||
int8_module.bias.data.copy_(module.bias.to(torch.float))
|
||||
int8_module.a = alpha
|
||||
return int8_module
|
||||
|
||||
|
||||
class W8A8B8O8Linear(torch.nn.Module):
|
||||
# For qkv_proj
|
||||
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
self.register_buffer(
|
||||
"weight",
|
||||
torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(self.out_features, self.in_features),
|
||||
dtype=torch.int8,
|
||||
requires_grad=False,
|
||||
),
|
||||
)
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False),
|
||||
)
|
||||
self.register_buffer("a", torch.tensor(alpha))
|
||||
self.register_buffer("b", torch.tensor(beta))
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
super().to(*args, **kwargs)
|
||||
self.weight = self.weight.to(*args, **kwargs)
|
||||
self.bias = self.bias.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x):
|
||||
x_shape = x.shape
|
||||
x = x.view(-1, x_shape[-1])
|
||||
y = linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item())
|
||||
y = y.view(*x_shape[:-1], -1)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def from_float(module: torch.nn.Linear, input_scale, output_scale):
|
||||
int8_module = W8A8B8O8Linear(module.in_features, module.out_features)
|
||||
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
|
||||
alpha = input_scale * weight_scale / output_scale
|
||||
int8_module.weight = int8_weight
|
||||
int8_module.a = alpha
|
||||
|
||||
if module.bias is not None:
|
||||
int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias)
|
||||
int8_module.bias = int8_bias
|
||||
beta = bias_scale / output_scale
|
||||
int8_module.b = beta
|
||||
|
||||
return int8_module
|
||||
|
||||
|
||||
class W8A8BFP32OFP32Linear(torch.nn.Module):
|
||||
# For fc2 and out_proj
|
||||
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
self.register_buffer(
|
||||
"weight",
|
||||
torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(self.out_features, self.in_features),
|
||||
dtype=torch.int8,
|
||||
requires_grad=False,
|
||||
),
|
||||
)
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
|
||||
)
|
||||
self.register_buffer("a", torch.tensor(alpha))
|
||||
|
||||
def _apply(self, fn):
|
||||
# prevent the bias from being converted to half
|
||||
super()._apply(fn)
|
||||
self.bias = self.bias.to(torch.float32)
|
||||
return self
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
super().to(*args, **kwargs)
|
||||
self.weight = self.weight.to(*args, **kwargs)
|
||||
self.bias = self.bias.to(*args, **kwargs)
|
||||
self.bias = self.bias.to(torch.float32)
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x):
|
||||
x_shape = x.shape
|
||||
x = x.view(-1, x_shape[-1])
|
||||
y = linear_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1)
|
||||
y = y.view(*x_shape[:-1], -1)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def from_float(module: torch.nn.Linear, input_scale):
|
||||
int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features)
|
||||
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
|
||||
alpha = input_scale * weight_scale
|
||||
int8_module.weight = int8_weight
|
||||
int8_module.a = alpha
|
||||
int8_module.input_scale = input_scale
|
||||
int8_module.weight_scale = weight_scale
|
||||
|
||||
if module.bias is not None:
|
||||
int8_module.bias = module.bias.to(torch.float32)
|
||||
|
||||
return int8_module
|
|
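To show how these wrappers replace float layers at pack time: from_float quantizes the weight per tensor (absmax) and folds the calibration scales into the alpha/beta multipliers consumed by the int8 CUTLASS kernels. A small sketch with made-up calibration scales (assumes a working CUDA build of torch-int):

import torch

fc = torch.nn.Linear(4096, 4096).half().cuda()
input_scale, output_scale = 0.02, 0.05                  # made-up calibration values

int8_fc = W8A8B8O8Linear.from_float(fc, input_scale, output_scale).cuda()

x_int8 = torch.randint(-127, 127, (1, 16, 4096), dtype=torch.int8, device="cuda")
y_int8 = int8_fc(x_int8)                                # int8 in, int8 out, shape (1, 16, 4096)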
@@ -0,0 +1,846 @@
|
|||
import math
|
||||
import os
|
||||
import types
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LLAMA_INPUTS_DOCSTRING,
|
||||
LlamaAttention,
|
||||
LlamaDecoderLayer,
|
||||
LlamaMLP,
|
||||
LlamaRotaryEmbedding,
|
||||
repeat_kv,
|
||||
rotate_half,
|
||||
)
|
||||
from transformers.utils import add_start_docstrings_to_model_forward
|
||||
|
||||
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
|
||||
from colossalai.kernel.triton import (
|
||||
copy_kv_cache_to_dest,
|
||||
int8_rotary_embedding_fwd,
|
||||
smooth_llama_context_attn_fwd,
|
||||
smooth_token_attention_fwd,
|
||||
)
|
||||
|
||||
from .base_model import BaseSmoothForCausalLM
|
||||
from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear
|
||||
|
||||
|
||||
class LLamaSmoothquantAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = hidden_size // num_heads
|
||||
|
||||
if (self.head_dim * num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {num_heads})."
|
||||
)
|
||||
|
||||
self.qk_bmm = BMM_S8T_S8N_F32T(1.0)
|
||||
self.pv_bmm = BMM_S8T_S8N_S8T(1.0)
|
||||
|
||||
self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size)
|
||||
self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size)
|
||||
self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size)
|
||||
self.o_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size)
|
||||
|
||||
self.register_buffer("q_output_scale", torch.tensor([1.0]))
|
||||
self.register_buffer("k_output_scale", torch.tensor([1.0]))
|
||||
self.register_buffer("v_output_scale", torch.tensor([1.0]))
|
||||
self.register_buffer("q_rotary_output_scale", torch.tensor([1.0]))
|
||||
self.register_buffer("k_rotary_output_scale", torch.tensor([1.0]))
|
||||
self.register_buffer("out_input_scale", torch.tensor([1.0]))
|
||||
self.register_buffer("attn_input_scale", torch.tensor([1.0]))
|
||||
|
||||
self._init_rope()
|
||||
self.num_key_value_heads = num_heads
|
||||
|
||||
def _init_rope(self):
|
||||
self.rotary_emb = LlamaRotaryEmbedding(
|
||||
self.head_dim,
|
||||
max_position_embeddings=2048,
|
||||
base=10000.0,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def pack(
|
||||
module: LlamaAttention,
|
||||
attn_input_scale: float,
|
||||
q_output_scale: float,
|
||||
k_output_scale: float,
|
||||
v_output_scale: float,
|
||||
q_rotary_output_scale: float,
|
||||
k_rotary_output_scale: float,
|
||||
out_input_scale: float,
|
||||
):
|
||||
int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads)
|
||||
|
||||
int8_module.attn_input_scale = torch.tensor([attn_input_scale])
|
||||
|
||||
int8_module.q_output_scale = torch.tensor([q_output_scale])
|
||||
int8_module.k_output_scale = torch.tensor([k_output_scale])
|
||||
int8_module.v_output_scale = torch.tensor([v_output_scale])
|
||||
|
||||
int8_module.q_rotary_output_scale = torch.tensor([q_rotary_output_scale])
|
||||
int8_module.k_rotary_output_scale = torch.tensor([k_rotary_output_scale])
|
||||
|
||||
int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, attn_input_scale, q_output_scale)
|
||||
int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale)
|
||||
int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale)
|
||||
int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale)
|
||||
|
||||
int8_module.out_input_scale = torch.tensor([out_input_scale])
|
||||
|
||||
return int8_module
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
rotary_emb: Tuple[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
infer_state: Optional[BatchInferState] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
cos = rotary_emb[0]
|
||||
sin = rotary_emb[1]
|
||||
|
||||
int8_rotary_embedding_fwd(
|
||||
query_states.view(-1, self.num_heads, self.head_dim),
|
||||
cos,
|
||||
sin,
|
||||
self.q_output_scale.item(),
|
||||
self.q_rotary_output_scale.item(),
|
||||
)
|
||||
int8_rotary_embedding_fwd(
|
||||
key_states.view(-1, self.num_heads, self.head_dim),
|
||||
cos,
|
||||
sin,
|
||||
self.k_output_scale.item(),
|
||||
self.k_rotary_output_scale.item(),
|
||||
)
|
||||
|
||||
# NOTE might want to revise
|
||||
# need some way to record the length of past key values cache
|
||||
# since we won't return past_key_value_cache right now
|
||||
if infer_state.decode_layer_id == 0: # once per model.forward
|
||||
infer_state.cache_manager.past_key_values_length += q_len # seq_len
|
||||
|
||||
def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
|
||||
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
|
||||
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
|
||||
return
|
||||
|
||||
query_states = query_states.view(-1, self.num_heads, self.head_dim)
|
||||
key_states = key_states.view(-1, self.num_heads, self.head_dim)
|
||||
value_states = value_states.view(-1, self.num_heads, self.head_dim)
|
||||
|
||||
if infer_state.is_context_stage:
|
||||
# first token generation
|
||||
|
||||
# copy key and value calculated in current step to memory manager
|
||||
_copy_kv_to_mem_cache(
|
||||
infer_state.decode_layer_id,
|
||||
key_states,
|
||||
value_states,
|
||||
infer_state.context_mem_index,
|
||||
infer_state.cache_manager,
|
||||
)
|
||||
|
||||
attn_output = torch.empty_like(query_states)
|
||||
|
||||
smooth_llama_context_attn_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_output,
|
||||
self.q_rotary_output_scale.item(),
|
||||
self.k_rotary_output_scale.item(),
|
||||
self.v_output_scale.item(),
|
||||
self.out_input_scale.item(),
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
q_len,
|
||||
)
|
||||
|
||||
else:
|
||||
if infer_state.decode_is_contiguous:
|
||||
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
|
||||
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
|
||||
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
|
||||
]
|
||||
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
|
||||
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
|
||||
]
|
||||
cache_k.copy_(key_states)
|
||||
cache_v.copy_(value_states)
|
||||
else:
|
||||
# if decode is not contiguous, use triton kernel to copy key and value cache
|
||||
# k, v shape: [batch_size, num_heads, head_dim] (head_dim == embed_size_per_head)
|
||||
_copy_kv_to_mem_cache(
|
||||
infer_state.decode_layer_id,
|
||||
key_states,
|
||||
value_states,
|
||||
infer_state.decode_mem_index,
|
||||
infer_state.cache_manager,
|
||||
)
|
||||
|
||||
# (batch_size, seqlen, nheads, headdim)
|
||||
attn_output = torch.empty_like(query_states)
|
||||
|
||||
smooth_token_attention_fwd(
|
||||
query_states,
|
||||
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
|
||||
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
|
||||
attn_output,
|
||||
self.q_rotary_output_scale.item(),
|
||||
self.k_rotary_output_scale.item(),
|
||||
self.v_output_scale.item(),
|
||||
self.out_input_scale.item(),
|
||||
infer_state.block_loc,
|
||||
infer_state.start_loc,
|
||||
infer_state.seq_len,
|
||||
infer_state.cache_manager.past_key_values_length,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, None
|
||||
|
||||
|
||||
class LlamaLayerNormQ(torch.nn.Module):
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.input_scale = 1.0
|
||||
self.variance_epsilon = eps
|
||||
self.register_buffer("weight", torch.ones(dim, dtype=torch.float32))
|
||||
|
||||
def forward(self, x):
|
||||
ln_output_fp = torch.nn.functional.layer_norm(x, x.shape[-1:], self.weight, None, self.variance_epsilon)
|
||||
ln_output_int8 = ln_output_fp.round().clamp(-128, 127).to(torch.int8)
|
||||
return ln_output_int8
|
||||
|
||||
@staticmethod
|
||||
def from_float(module: torch.nn.LayerNorm, output_scale: float):
|
||||
assert module.weight.shape[0] == module.weight.numel()
|
||||
q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon)
|
||||
q_module.weight = module.weight / output_scale
|
||||
return q_module
|
||||
|
||||
|
||||
class LlamaSmoothquantMLP(nn.Module):
|
||||
def __init__(self, intermediate_size, hidden_size):
|
||||
super().__init__()
|
||||
self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size)
|
||||
self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size)
|
||||
self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size)
|
||||
self.register_buffer("down_proj_input_scale", torch.tensor([1.0]))
|
||||
|
||||
@staticmethod
|
||||
def pack(
|
||||
mlp_module: LlamaMLP,
|
||||
gate_proj_input_scale: float,
|
||||
up_proj_input_scale: float,
|
||||
down_proj_input_scale: float,
|
||||
):
|
||||
int8_module = LlamaSmoothquantMLP(
|
||||
mlp_module.intermediate_size,
|
||||
mlp_module.hidden_size,
|
||||
)
|
||||
|
||||
int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale)
|
||||
int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale)
|
||||
int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale)
|
||||
int8_module.down_proj_input_scale = torch.tensor([down_proj_input_scale])
|
||||
return int8_module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
):
|
||||
x_shape = hidden_states.shape
|
||||
gate_out = self.gate_proj(hidden_states)
|
||||
up_out = self.up_proj(hidden_states)
|
||||
inter_out = gate_out * up_out
|
||||
inter_out = inter_out.div_(self.down_proj_input_scale.item()).round().clamp(-128, 127).to(torch.int8)
|
||||
down_out = self.down_proj(inter_out)
|
||||
down_out = down_out.view(*x_shape[:-1], -1)
|
||||
return down_out
|
||||
|
||||
|
||||
class LlamaSmoothquantDecoderLayer(nn.Module):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = LLamaSmoothquantAttention(config.hidden_size, config.num_attention_heads)
|
||||
|
||||
self.mlp = LlamaSmoothquantMLP(config.intermediate_size, config.hidden_size)
|
||||
self.input_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.post_attention_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@staticmethod
|
||||
def pack(
|
||||
module: LlamaDecoderLayer,
|
||||
attn_input_scale: float,
|
||||
q_output_scale: float,
|
||||
k_output_scale: float,
|
||||
v_output_scale: float,
|
||||
q_rotary_output_scale: float,
|
||||
k_rotary_output_scale: float,
|
||||
out_input_scale: float,
|
||||
gate_input_scale: float,
|
||||
up_input_scale: float,
|
||||
down_input_scale: float,
|
||||
):
|
||||
config = module.self_attn.config
|
||||
int8_decoder_layer = LlamaSmoothquantDecoderLayer(config)
|
||||
|
||||
int8_decoder_layer.input_layernorm = LlamaLayerNormQ.from_float(module.input_layernorm, attn_input_scale)
|
||||
int8_decoder_layer.self_attn = LLamaSmoothquantAttention.pack(
|
||||
module.self_attn,
|
||||
attn_input_scale,
|
||||
q_output_scale,
|
||||
k_output_scale,
|
||||
v_output_scale,
|
||||
q_rotary_output_scale,
|
||||
k_rotary_output_scale,
|
||||
out_input_scale,
|
||||
)
|
||||
|
||||
int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float(
|
||||
module.post_attention_layernorm, gate_input_scale
|
||||
)
|
||||
|
||||
int8_decoder_layer.mlp = LlamaSmoothquantMLP.pack(
|
||||
module.mlp,
|
||||
gate_input_scale,
|
||||
up_input_scale,
|
||||
down_input_scale,
|
||||
)
|
||||
|
||||
return int8_decoder_layer
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
rotary_emb: Tuple[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
infer_state: Optional[BatchInferState] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
rotary_emb=rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
infer_state=infer_state,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states, None, None
|
||||
|
||||
|
||||
class LlamaApplyRotary(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, cos, sin, position_ids):
|
||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
x_embed = (x * cos) + (rotate_half(x) * sin)
|
||||
|
||||
return x_embed
|
||||
|
||||
|
||||
def llama_decoder_layer_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states = self.q_apply_rotary(query_states, cos, sin, position_ids)
|
||||
key_states = self.k_apply_rotary(key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def init_to_get_rotary(config, base=10000, use_elem=False):
|
||||
"""
|
||||
This function initializes the rotary positional embedding; it is compatible with all models and is called in ShardFormer.
|
||||
Args:
|
||||
base : calculation arg
|
||||
use_elem : activated when using chatglm-based models
|
||||
"""
|
||||
config.head_dim_ = config.hidden_size // config.num_attention_heads
|
||||
if not hasattr(config, "rope_scaling"):
|
||||
rope_scaling_factor = 1.0
|
||||
else:
|
||||
rope_scaling_factor = config.rope_scaling.factor if config.rope_scaling is not None else 1.0
|
||||
|
||||
if hasattr(config, "max_sequence_length"):
|
||||
max_seq_len = config.max_sequence_length
|
||||
elif hasattr(config, "max_position_embeddings"):
|
||||
max_seq_len = config.max_position_embeddings * rope_scaling_factor
|
||||
else:
|
||||
max_seq_len = 2048 * rope_scaling_factor
|
||||
base = float(base)
|
||||
|
||||
# NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||
try:
|
||||
ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1))
|
||||
assert ntk_alpha >= 1
|
||||
if ntk_alpha > 1:
|
||||
print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
|
||||
max_seq_len *= ntk_alpha
|
||||
base = base * (ntk_alpha ** (config.head_dim_ / (config.head_dim_ - 2))) # Base change formula
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
n_elem = config.head_dim_
|
||||
if use_elem:
|
||||
n_elem //= 2
|
||||
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
|
||||
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
|
||||
_cos_cached = torch.cos(freqs).to(torch.float)
|
||||
_sin_cached = torch.sin(freqs).to(torch.float)
|
||||
return _cos_cached, _sin_cached
|
||||
|
||||
|
||||
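init_to_get_rotary above precomputes the full cos/sin tables once per model; llama_model_forward below then just gathers rows by position id and hands them to the int8 rotary kernel. A small sketch of that lookup, with init_to_get_rotary in scope and a stand-in config:

import torch
from transformers import LlamaConfig

config = LlamaConfig(hidden_size=4096, num_attention_heads=32)            # stand-in config
cos_cached, sin_cached = init_to_get_rotary(config, base=10000)

position_ids = torch.arange(0, 8).unsqueeze(0)                            # first 8 positions of one sequence
position_cos = torch.index_select(cos_cached, 0, position_ids.view(-1))   # (8, head_dim // 2)
position_sin = torch.index_select(sin_cached, 0, position_ids.view(-1))
# (position_cos, position_sin) is the rotary_emb tuple each decoder layer receives.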
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||
def llama_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
infer_state = self.infer_state
|
||||
|
||||
if past_key_values is not None:
|
||||
# NOT READY FOR PRIME TIME
|
||||
# dummy but work, revise it
|
||||
past_key_values_length = infer_state.cache_manager.past_key_values_length
|
||||
# past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
# NOTE: differentiate with prefill stage
|
||||
# block_loc require different value-assigning method for two different stage
|
||||
# NOTE: differentiate with prefill stage
|
||||
# block_loc require different value-assigning method for two different stage
|
||||
if infer_state.is_context_stage:
|
||||
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
|
||||
infer_state.init_block_loc(
|
||||
infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
|
||||
)
|
||||
else:
|
||||
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
|
||||
if alloc_mem is not None:
|
||||
infer_state.decode_is_contiguous = True
|
||||
infer_state.decode_mem_index = alloc_mem[0]
|
||||
infer_state.decode_mem_start = alloc_mem[1]
|
||||
infer_state.decode_mem_end = alloc_mem[2]
|
||||
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
|
||||
else:
|
||||
print(f" *** Encountered allocation non-contiguous")
|
||||
print(
|
||||
f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
|
||||
)
|
||||
infer_state.decode_is_contiguous = False
|
||||
alloc_mem = infer_state.cache_manager.alloc(batch_size)
|
||||
infer_state.decode_mem_index = alloc_mem
|
||||
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device)
|
||||
padding_mask = None
|
||||
else:
|
||||
if 0 in attention_mask:
|
||||
padding_mask = attention_mask
|
||||
else:
|
||||
padding_mask = None
|
||||
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
raise NotImplementedError("not implement gradient_checkpointing and training options ")
|
||||
|
||||
if past_key_values_length == 0:
|
||||
position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
|
||||
position_ids.view(-1).shape[0], -1
|
||||
)
|
||||
position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
|
||||
position_ids.view(-1).shape[0], -1
|
||||
)
|
||||
else:
|
||||
position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
|
||||
position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
infer_state.decode_layer_id = 0
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
rotary_emb=(position_cos, position_sin),
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
infer_state=infer_state,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
infer_state.decode_layer_id += 1
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
infer_state.is_context_stage = False
|
||||
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
|
||||
infer_state.seq_len += 1
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class SmoothLlamaForCausalLM(BaseSmoothForCausalLM):
|
||||
layer_type = "LlamaDecoderLayer"
|
||||
|
||||
def __init__(self, model: PreTrainedModel, quantized: bool = False):
|
||||
super().__init__(model, quantized)
|
||||
|
||||
def get_act_dict(
|
||||
self,
|
||||
tokenizer,
|
||||
dataset,
|
||||
num_samples=512,
|
||||
seq_len=512,
|
||||
):
|
||||
llama_model = self.model
|
||||
|
||||
llama_model.eval()
|
||||
device = next(llama_model.parameters()).device
|
||||
# print("model:", llama_model)
|
||||
act_dict = defaultdict(dict)
|
||||
|
||||
def stat_io_hook(m, x, y, name):
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
if name not in act_dict or "input" not in act_dict[name]:
|
||||
act_dict[name]["input"] = x.detach().abs().max().item()
|
||||
else:
|
||||
act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item())
|
||||
if isinstance(y, tuple):
|
||||
y = y[0]
|
||||
if name not in act_dict or "output" not in act_dict[name]:
|
||||
act_dict[name]["output"] = y.detach().abs().max().item()
|
||||
else:
|
||||
act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item())
|
||||
|
||||
for name, m in llama_model.named_modules():
|
||||
if isinstance(m, LlamaAttention):
|
||||
setattr(m, "q_apply_rotary", LlamaApplyRotary())
|
||||
setattr(m, "k_apply_rotary", LlamaApplyRotary())
|
||||
m.forward = types.MethodType(llama_decoder_layer_forward, m)
|
||||
|
||||
hooks = []
|
||||
for name, m in llama_model.named_modules():
|
||||
if isinstance(m, LlamaApplyRotary):
|
||||
hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name)))
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name)))
|
||||
|
||||
self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len)
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
return act_dict
|
||||
|
||||
def smooth_fn(self, scales, alpha=0.5):
|
||||
model = self.model
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaDecoderLayer):
|
||||
attn_ln = module.input_layernorm
|
||||
qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj]
|
||||
qkv_input_scales = scales[name + ".self_attn.q_proj"]
|
||||
self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)
|
||||
|
||||
def create_quantized_model(model):
|
||||
llama_config = model.config
|
||||
for i, layer in enumerate(model.model.layers):
|
||||
model.model.layers[i] = LlamaSmoothquantDecoderLayer(llama_config)
|
||||
|
||||
model.model.forward = types.MethodType(llama_model_forward, model.model)
|
||||
cos, sin = init_to_get_rotary(llama_config)
|
||||
model.model.register_buffer("_cos_cached", cos)
|
||||
model.model.register_buffer("_sin_cached", sin)
|
||||
|
||||
def quantized(
|
||||
self,
|
||||
tokenizer,
|
||||
dataset,
|
||||
num_samples=512,
|
||||
seq_len=512,
|
||||
alpha=0.5,
|
||||
):
|
||||
llama_model = self.model
|
||||
llama_config = llama_model.config
|
||||
|
||||
act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len)
|
||||
|
||||
self.smooth_fn(act_scales, alpha)
|
||||
|
||||
act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len)
|
||||
decoder_layer_scales = []
|
||||
|
||||
for idx in range(llama_config.num_hidden_layers):
|
||||
scale_dict = {}
|
||||
scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127
|
||||
scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127
|
||||
scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127
|
||||
scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127
|
||||
|
||||
scale_dict["q_rotary_output_scale"] = (
|
||||
act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127
|
||||
)
|
||||
scale_dict["k_rotary_output_scale"] = (
|
||||
act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127
|
||||
)
|
||||
|
||||
scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127
|
||||
|
||||
scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127
|
||||
scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127
|
||||
scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127
|
||||
|
||||
decoder_layer_scales.append(scale_dict)
|
||||
|
||||
for i, layer in enumerate(llama_model.model.layers):
|
||||
orig_layer = layer
|
||||
llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i])
|
||||
|
||||
llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model)
|
||||
|
||||
cos, sin = init_to_get_rotary(llama_config)
|
||||
llama_model.model.register_buffer("_cos_cached", cos.to(self.model.device))
|
||||
llama_model.model.register_buffer("_sin_cached", sin.to(self.model.device))
|
|
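Each entry in decoder_layer_scales above is a per-tensor scale computed as the calibrated maximum absolute activation divided by 127, so an int8 value multiplied by its scale recovers roughly the original float magnitude. A minimal sketch of that round trip (a hypothetical helper for illustration, not part of this patch):

import torch

def fake_quant(x: torch.Tensor) -> torch.Tensor:
    # per-tensor scale, mirroring act_dict[...] / 127 in quantized() above
    scale = x.abs().max().item() / 127
    # quantize to int8 ...
    x_int8 = torch.clamp((x / scale).round(), -128, 127).to(torch.int8)
    # ... and dequantize; the round trip keeps the absolute error within scale / 2
    return x_int8.to(torch.float) * scale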
@ -0,0 +1,8 @@
#include <torch/extension.h>

#include "linear.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32,
        "Linear SiLU (INT8)");
}
|
|
@ -0,0 +1,162 @@
|
|||
// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu
|
||||
|
||||
#include "linear.h"
|
||||
#include <cutlass/core_io.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/half.h>
|
||||
|
||||
#include <cutlass/gemm/device/gemm.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/util/host_tensor.h>
|
||||
#include <cutlass/epilogue/thread/linear_combination_silu.h>
|
||||
#include <cstdint>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <iostream>
|
||||
#include <torch/torch.h>
|
||||
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8
|
||||
torch::Tensor weight, // INT8
|
||||
torch::Tensor bias, // FP32
|
||||
float alpha, // FP32
|
||||
float beta // FP32
|
||||
) {
|
||||
auto M = input.size(0);
|
||||
auto N = weight.size(0);
|
||||
auto K = input.size(1);
|
||||
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementComputeEpilogue = float;
|
||||
using ElementInputA = int8_t; // <- data type of elements in input matrix A
|
||||
using ElementInputB = int8_t; // <- data type of elements in input matrix B
|
||||
|
||||
// The code section below describes matrix layout of input and output
|
||||
// matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major
|
||||
// for Matrix C
|
||||
using LayoutInputA = cutlass::layout::RowMajor;
|
||||
using LayoutInputB = cutlass::layout::ColumnMajor;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
|
||||
#if CUDA_ARCH >= 800
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
|
||||
ElementOutput, // <- data type of output matrix
|
||||
128 / cutlass::sizeof_bits<
|
||||
ElementOutput>::value, // <- this is the number of elements per
|
||||
// vectorized memory access. For half
|
||||
// precision, it's 8 elements. This
|
||||
// becomes the vector width of math
|
||||
// instructions in epilogue too
|
||||
ElementAccumulator, // <- data type of accumulator
|
||||
ElementComputeEpilogue // <- data type for alpha in linear combination
|
||||
// function
|
||||
>;
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
|
||||
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
|
||||
EpilogueOp,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
#elif CUDA_ARCH >= 750
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
|
||||
ElementOutput, // <- data type of output matrix
|
||||
128 / cutlass::sizeof_bits<
|
||||
ElementOutput>::value, // <- this is the number of elements per
|
||||
// vectorized memory access. For half
|
||||
// precision, it's 8 elements. This
|
||||
// becomes the vector width of math
|
||||
// instructions in epilogue too
|
||||
ElementAccumulator, // <- data type of accumulator
|
||||
ElementComputeEpilogue // <- data type for alpha in linear combination
|
||||
// function
|
||||
>;
|
||||
|
||||
using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
|
||||
ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
|
||||
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
|
||||
DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
|
||||
DefaultGemmCfg::InstructionShape,
|
||||
EpilogueOp>;
|
||||
#elif CUDA_ARCH >= 700
|
||||
#define USE_TORCH_SILU
|
||||
using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
|
||||
cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
|
||||
ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
|
||||
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
|
||||
DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
|
||||
DefaultGemmCfg::InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>;
|
||||
#else
|
||||
#error "Unsupported cuda arch"
|
||||
#endif
|
||||
|
||||
auto input_size = cutlass::MatrixCoord(M, K);
|
||||
auto weight_size = cutlass::MatrixCoord(K, N);
|
||||
auto output_size = cutlass::MatrixCoord(M, N);
|
||||
|
||||
auto device = input.device();
|
||||
// use the broadcasted bias as the output
|
||||
auto out = bias.to(device).view({1, -1}).repeat({M, 1});
|
||||
|
||||
// constexpr int kSparse = Gemm::kSparse;
|
||||
// How many elements of A are covered per ElementE
|
||||
// constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
|
||||
// The size of individual meta data
|
||||
// constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
|
||||
cutlass::gemm::GemmCoord problem_size(M, N, K);
|
||||
|
||||
cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
|
||||
input.data_ptr<ElementInputA>(), LayoutInputA::packed(input_size));
|
||||
cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
|
||||
weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
|
||||
cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
|
||||
out.data_ptr<ElementOutput>(), LayoutOutput::packed(output_size));
|
||||
|
||||
typename Gemm::Arguments arguments{
|
||||
problem_size, // <- problem size of matrix multiplication
|
||||
input_ref, // <- reference to matrix A on device
|
||||
weight_ref, // <- reference to matrix B on device
|
||||
out_ref, // <- reference to matrix C on device
|
||||
out_ref, // <- reference to matrix D on device
|
||||
{alpha, beta}, 1};
|
||||
Gemm gemm_op;
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix
|
||||
// multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
throw std::runtime_error("cutlass cannot implement");
|
||||
}
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
throw std::runtime_error("cutlass cannot initialize");
|
||||
}
|
||||
|
||||
status = gemm_op();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
throw std::runtime_error("cutlass cannot run");
|
||||
}
|
||||
#ifdef USE_TORCH_SILU
|
||||
#undef USE_TORCH_SILU
|
||||
out = torch::silu(out);
|
||||
#endif
|
||||
return out;
|
||||
}
|
|
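The CUDA_ARCH macro that selects the Sm80 / Sm75 / Sm70 branch above is injected at compile time rather than detected inside the kernel; SmoothquantBuilder (later in this diff) derives it from the device's compute capability. A short sketch of that mapping, assuming a visible CUDA device:

import torch

major, minor = torch.cuda.get_device_capability()
cuda_arch = major * 100 + minor * 10  # e.g. (8, 0) -> 800, which takes the Sm80 tensor-op path
print(f"-DCUDA_ARCH={cuda_arch}")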
@ -0,0 +1,12 @@
#include <torch/torch.h>
#include <torch/types.h>

#include <cstdint>
#include <iostream>

torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input,   // INT8
                                            torch::Tensor weight,  // INT8
                                            torch::Tensor bias,    // FP32
                                            float alpha,           // FP32
                                            float beta             // FP32
);
|
|
@ -13,8 +13,10 @@ if HAS_TRITON:
    from .copy_kv_cache_dest import copy_kv_cache_to_dest
    from .fused_layernorm import layer_norm
    from .gptq_triton import gptq_fused_linear_triton
    from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd
    from .rms_norm import rmsnorm_forward
    from .rotary_embedding_kernel import rotary_embedding_fwd
    from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd
    from .softmax import softmax
    from .token_attention_kernel import token_attention_fwd
|
||||
|
||||
|
@ -29,4 +31,7 @@ if HAS_TRITON:
        "rotary_embedding_fwd",
        "token_attention_fwd",
        "gptq_fused_linear_triton",
        "int8_rotary_embedding_fwd",
        "smooth_llama_context_attn_fwd",
        "smooth_token_attention_fwd",
    ]
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
# Adapted from ModelTC https://github.com/ModelTC/lightllm
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rotary_kernel(
|
||||
q,
|
||||
input_scale,
|
||||
output_scale,
|
||||
Cos,
|
||||
Sin,
|
||||
q_bs_stride,
|
||||
q_h_stride,
|
||||
q_d_stride,
|
||||
cos_bs_stride,
|
||||
cos_d_stride,
|
||||
total_len,
|
||||
HEAD_NUM: tl.constexpr,
|
||||
BLOCK_HEAD: tl.constexpr,
|
||||
BLOCK_SEQ: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
):
|
||||
current_head_index = tl.program_id(0)
|
||||
current_seq_index = tl.program_id(1)
|
||||
|
||||
dim_range0 = tl.arange(0, HEAD_DIM // 2)
|
||||
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
|
||||
|
||||
current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
|
||||
current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
|
||||
|
||||
off_q0 = (
|
||||
current_seq_range[:, None, None] * q_bs_stride
|
||||
+ current_head_range[None, :, None] * q_h_stride
|
||||
+ dim_range0[None, None, :] * q_d_stride
|
||||
)
|
||||
off_q1 = (
|
||||
current_seq_range[:, None, None] * q_bs_stride
|
||||
+ current_head_range[None, :, None] * q_h_stride
|
||||
+ dim_range1[None, None, :] * q_d_stride
|
||||
)
|
||||
|
||||
off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride
|
||||
|
||||
q0 = tl.load(
|
||||
q + off_q0,
|
||||
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
|
||||
other=0.0,
|
||||
)
|
||||
q1 = tl.load(
|
||||
q + off_q1,
|
||||
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
|
||||
sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
|
||||
|
||||
q0 = q0.to(tl.float32) * input_scale
|
||||
q1 = q1.to(tl.float32) * input_scale
|
||||
|
||||
out0 = (q0 * cos - q1 * sin) / output_scale
|
||||
out1 = (q0 * sin + q1 * cos) / output_scale
|
||||
|
||||
out0 = out0.to(tl.int8)
|
||||
out1 = out1.to(tl.int8)
|
||||
|
||||
tl.store(
|
||||
q + off_q0,
|
||||
out0,
|
||||
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
|
||||
)
|
||||
tl.store(
|
||||
q + off_q1,
|
||||
out1,
|
||||
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale):
|
||||
total_len = q.shape[0]
|
||||
head_num = q.shape[1]
|
||||
head_dim = q.shape[2]
|
||||
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
|
||||
BLOCK_HEAD = 4
|
||||
BLOCK_SEQ = 32
|
||||
grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
|
||||
if head_dim >= 128:
|
||||
num_warps = 8
|
||||
else:
|
||||
num_warps = 4
|
||||
|
||||
_rotary_kernel[grid](
|
||||
q,
|
||||
input_scale,
|
||||
output_scale,
|
||||
cos,
|
||||
sin,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
cos.stride(0),
|
||||
cos.stride(1),
|
||||
total_len,
|
||||
HEAD_NUM=head_num,
|
||||
BLOCK_HEAD=BLOCK_HEAD,
|
||||
BLOCK_SEQ=BLOCK_SEQ,
|
||||
HEAD_DIM=head_dim,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return
|
|
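For reference, the per-token arithmetic of _rotary_kernel can be written in plain PyTorch as below (a sketch for clarity only, not used by the kernel); q is an int8 tensor of shape (total_len, head_num, head_dim) and cos/sin have shape (total_len, head_dim // 2):

import torch

def int8_rotary_reference(q_int8, cos, sin, input_scale, output_scale):
    q = q_int8.to(torch.float) * input_scale       # dequantize
    d = q.shape[-1] // 2
    q0, q1 = q[..., :d], q[..., d:]                # split head_dim into halves
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads
    out0 = (q0 * cos - q1 * sin) / output_scale    # rotate and rescale
    out1 = (q0 * sin + q1 * cos) / output_scale
    # the Triton kernel casts straight to int8 when storing; do the same here
    return torch.cat((out0, out1), dim=-1).to(torch.int8)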
@ -0,0 +1,652 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
if HAS_TRITON:
|
||||
"""
|
||||
this function is modified from
|
||||
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
|
||||
"""
|
||||
|
||||
@triton.jit
|
||||
def _context_flash_attention_kernel(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
q_input_scale,
|
||||
k_input_scale,
|
||||
v_input_scale,
|
||||
pv_output_scale,
|
||||
sm_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
TMP,
|
||||
alibi_ptr,
|
||||
Out,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_tmp_b,
|
||||
stride_tmp_h,
|
||||
stride_tmp_s,
|
||||
# suggested set-up: 64, 128, 256, 512
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
batch_id = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
|
||||
# get batch info
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
|
||||
cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
load_p_ptrs = (
|
||||
Q
|
||||
+ (cur_batch_start_index + offs_m[:, None]) * stride_qbs
|
||||
+ cur_head * stride_qh
|
||||
+ offs_d[None, :] * stride_qd
|
||||
)
|
||||
q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
|
||||
q = q.to(tl.float16) * q_input_scale.to(tl.float16)
|
||||
|
||||
k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
|
||||
v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
|
||||
t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s
|
||||
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
if alibi_ptr is not None:
|
||||
alibi_m = tl.load(alibi_ptr + cur_head)
|
||||
|
||||
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
|
||||
|
||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
k = tl.load(
|
||||
k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
|
||||
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
|
||||
other=0.0,
|
||||
)
|
||||
k = k.to(tl.float16) * k_input_scale.to(tl.float16)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
|
||||
if alibi_ptr is not None:
|
||||
alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
|
||||
qk -= alibi_loc * alibi_m
|
||||
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
tl.store(t_ptrs, acc_scale)
|
||||
acc_scale = tl.load(t_ptrs)
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(
|
||||
v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
|
||||
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
v = v.to(tl.float16) * v_input_scale.to(tl.float16)
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8)
|
||||
off_o = (
|
||||
(cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
|
||||
)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
|
||||
return
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_llama_context_attn_fwd(
|
||||
q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len
|
||||
):
|
||||
|
||||
BLOCK = 128
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk, "context process only supports equal query, key, value length"
|
||||
assert Lk == Lv, "context process only supports equal query, key, value length"
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
BLOCK_N = 128
|
||||
sm_scale = 1.0 / math.sqrt(Lk)
|
||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
||||
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
|
||||
|
||||
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
|
||||
_context_flash_attention_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
q_input_scale,
|
||||
k_input_scale,
|
||||
v_input_scale,
|
||||
pv_output_scale,
|
||||
sm_scale,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
tmp,
|
||||
None,
|
||||
o,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
tmp.stride(0),
|
||||
tmp.stride(1),
|
||||
tmp.stride(2),
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return
|
||||
|
||||
@triton.jit
|
||||
def _token_attn_1_kernel(
|
||||
Q,
|
||||
K,
|
||||
q_input_scale,
|
||||
k_input_scale,
|
||||
sm_scale,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seqlen,
|
||||
max_kv_cache_len,
|
||||
attn_out,
|
||||
kv_cache_loc_b_stride,
|
||||
kv_cache_loc_s_stride,
|
||||
q_batch_stride,
|
||||
q_head_stride,
|
||||
q_head_dim_stride,
|
||||
k_batch_stride,
|
||||
k_head_stride,
|
||||
k_head_dim_stride,
|
||||
attn_head_stride,
|
||||
attn_batch_stride,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
current_batch = tl.program_id(0)
|
||||
current_head = tl.program_id(1)
|
||||
start_n = tl.program_id(2)
|
||||
|
||||
offs_d = tl.arange(0, HEAD_DIM)
|
||||
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
|
||||
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
|
||||
|
||||
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
|
||||
current_batch_end_index = max_kv_cache_len
|
||||
|
||||
off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride
|
||||
|
||||
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
block_start_index = start_n * BLOCK_N
block_mask = tl.where(block_start_index < current_batch_seq_len, 1, 0)
|
||||
|
||||
for start_mark in range(0, block_mask, 1):
|
||||
q = tl.load(Q + off_q + start_mark)
|
||||
q = q.to(tl.float16) * q_input_scale.to(tl.float16)
|
||||
offs_n_new = current_batch_start_index + offs_n
|
||||
k_loc = tl.load(
|
||||
kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
|
||||
mask=offs_n_new < current_batch_end_index,
|
||||
other=0,
|
||||
)
|
||||
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
|
||||
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
|
||||
k = k.to(tl.float16) * k_input_scale.to(tl.float16)
|
||||
att_value = tl.sum(q[None, :] * k, 1)
|
||||
att_value *= sm_scale
|
||||
off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
|
||||
tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
|
||||
return
|
||||
|
||||
@triton.jit
|
||||
def _token_attn_1_alibi_kernel(
|
||||
Q,
|
||||
K,
|
||||
q_input_scale,
|
||||
k_input_scale,
|
||||
sm_scale,
|
||||
alibi,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seqlen,
|
||||
max_kv_cache_len,
|
||||
attn_out,
|
||||
kv_cache_loc_b_stride,
|
||||
kv_cache_loc_s_stride,
|
||||
q_batch_stride,
|
||||
q_head_stride,
|
||||
q_head_dim_stride,
|
||||
k_batch_stride,
|
||||
k_head_stride,
|
||||
k_head_dim_stride,
|
||||
attn_head_stride,
|
||||
attn_batch_stride,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
current_batch = tl.program_id(0)
|
||||
current_head = tl.program_id(1)
|
||||
start_n = tl.program_id(2)
|
||||
|
||||
offs_d = tl.arange(0, HEAD_DIM)
|
||||
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
|
||||
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
|
||||
|
||||
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
|
||||
current_batch_end_index = max_kv_cache_len
|
||||
|
||||
off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride
|
||||
|
||||
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
block_start_index = start_n * BLOCK_N
block_mask = tl.where(block_start_index < current_batch_seq_len, 1, 0)
|
||||
|
||||
for start_mark in range(0, block_mask, 1):
|
||||
alibi_m = tl.load(alibi + current_head)
|
||||
q = tl.load(Q + off_q + start_mark)
|
||||
q = q.to(tl.float16) * q_input_scale.to(tl.float16)
|
||||
|
||||
offs_n_new = current_batch_start_index + offs_n
|
||||
k_loc = tl.load(
|
||||
kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
|
||||
mask=offs_n_new < current_batch_end_index,
|
||||
other=0,
|
||||
)
|
||||
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
|
||||
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
|
||||
k = k.to(tl.float16) * k_input_scale.to(tl.float16)
|
||||
att_value = tl.sum(q[None, :] * k, 1)
|
||||
att_value *= sm_scale
|
||||
att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n)
|
||||
off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
|
||||
tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def token_attn_fwd_1(
|
||||
q,
|
||||
k,
|
||||
attn_out,
|
||||
q_input_scale,
|
||||
k_input_scale,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seqlen,
|
||||
max_kv_cache_len,
|
||||
alibi=None,
|
||||
):
|
||||
BLOCK = 32
|
||||
# shape constraints
|
||||
q_head_dim, k_head_dim = q.shape[-1], k.shape[-1]
|
||||
assert q_head_dim == k_head_dim
|
||||
assert k_head_dim in {16, 32, 64, 128}
|
||||
sm_scale = 1.0 / (k_head_dim**0.5)
|
||||
|
||||
batch, head_num = kv_cache_loc.shape[0], q.shape[1]
|
||||
|
||||
grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK))
|
||||
|
||||
num_warps = 4 if k_head_dim <= 64 else 8
|
||||
num_warps = 2
|
||||
|
||||
if alibi is not None:
|
||||
_token_attn_1_alibi_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
q_input_scale,
|
||||
k_input_scale,
|
||||
sm_scale,
|
||||
alibi,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seqlen,
|
||||
max_kv_cache_len,
|
||||
attn_out,
|
||||
kv_cache_loc.stride(0),
|
||||
kv_cache_loc.stride(1),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
attn_out.stride(0),
|
||||
attn_out.stride(1),
|
||||
HEAD_DIM=k_head_dim,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
else:
|
||||
_token_attn_1_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
q_input_scale,
|
||||
k_input_scale,
|
||||
sm_scale,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seqlen,
|
||||
max_kv_cache_len,
|
||||
attn_out,
|
||||
kv_cache_loc.stride(0),
|
||||
kv_cache_loc.stride(1),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
attn_out.stride(0),
|
||||
attn_out.stride(1),
|
||||
HEAD_DIM=k_head_dim,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return
|
||||
|
||||
@triton.jit
|
||||
def _token_attn_softmax_fwd(
|
||||
softmax_logics,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seqlen,
|
||||
softmax_prob_out,
|
||||
logics_head_dim_stride,
|
||||
logics_batch_stride,
|
||||
prob_head_dim_stride,
|
||||
prob_batch_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
current_batch = tl.program_id(0)
|
||||
current_head = tl.program_id(1)
|
||||
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
|
||||
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
|
||||
|
||||
row = tl.load(
|
||||
softmax_logics
|
||||
+ current_head * logics_head_dim_stride
|
||||
+ (current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
|
||||
mask=col_offsets < current_batch_seq_len,
|
||||
other=-float("inf"),
|
||||
).to(tl.float32)
|
||||
|
||||
row_minus_max = row - tl.max(row, axis=0)
|
||||
numerator = tl.exp(row_minus_max)
|
||||
denominator = tl.sum(numerator, axis=0)
|
||||
softmax_output = numerator / denominator
|
||||
|
||||
tl.store(
|
||||
softmax_prob_out
|
||||
+ current_head * prob_head_dim_stride
|
||||
+ (current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
|
||||
softmax_output,
|
||||
mask=col_offsets < current_batch_seq_len,
|
||||
)
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len):
|
||||
BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len)
|
||||
batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0]
|
||||
|
||||
num_warps = 4
|
||||
if BLOCK_SIZE >= 2048:
|
||||
num_warps = 8
|
||||
if BLOCK_SIZE >= 4096:
|
||||
num_warps = 16
|
||||
|
||||
_token_attn_softmax_fwd[(batch, head_num)](
|
||||
softmax_logics,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seqlen,
|
||||
softmax_prob_out,
|
||||
softmax_logics.stride(0),
|
||||
softmax_logics.stride(1),
|
||||
softmax_prob_out.stride(0),
|
||||
softmax_prob_out.stride(1),
|
||||
num_warps=num_warps,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
return
|
||||
|
||||
@triton.jit
|
||||
def _token_attn_2_kernel(
|
||||
Prob,
|
||||
V,
|
||||
attn_out,
|
||||
v_input_scale,
|
||||
pv_output_scale,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seqlen,
|
||||
max_kv_cache_len,
|
||||
kv_cache_loc_b_stride,
|
||||
kv_cache_loc_s_stride,
|
||||
prob_head_dim_stride,
|
||||
prob_batch_stride,
|
||||
v_batch_stride,
|
||||
v_head_stride,
|
||||
v_head_dim_stride,
|
||||
attn_out_batch_stride,
|
||||
attn_out_head_stride,
|
||||
attn_out_head_dim_stride,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
current_batch = tl.program_id(0)
|
||||
current_head = tl.program_id(1)
|
||||
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, HEAD_DIM)
|
||||
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
|
||||
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
|
||||
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
|
||||
|
||||
v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride
|
||||
p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride
|
||||
v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride
|
||||
|
||||
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
|
||||
for start_n in range(0, current_batch_seq_len, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
p_value = tl.load(
|
||||
Prob + p_offs + start_n * kv_cache_loc_s_stride,
|
||||
mask=(start_n + offs_n) < current_batch_seq_len,
|
||||
other=0.0,
|
||||
)
|
||||
v_loc = tl.load(
|
||||
kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
|
||||
mask=(start_n + offs_n) < current_batch_seq_len,
|
||||
other=0.0,
|
||||
)
|
||||
v_value = tl.load(
|
||||
V + v_offs + v_loc[:, None] * v_batch_stride,
|
||||
mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
|
||||
other=0.0,
|
||||
)
|
||||
v_value = v_value.to(tl.float16) * v_input_scale.to(tl.float16)
|
||||
acc += tl.sum(p_value[:, None] * v_value, 0)
|
||||
|
||||
acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8)
|
||||
off_o = (
|
||||
current_batch * attn_out_batch_stride
|
||||
+ current_head * attn_out_head_stride
|
||||
+ offs_d * attn_out_head_dim_stride
|
||||
)
|
||||
out_ptrs = attn_out + off_o
|
||||
tl.store(out_ptrs, acc)
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def token_attn_fwd_2(
|
||||
prob,
|
||||
v,
|
||||
attn_out,
|
||||
v_input_scale,
|
||||
pv_output_scale,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seqlen,
|
||||
max_kv_cache_len,
|
||||
):
|
||||
if triton.__version__ >= "2.1.0":
|
||||
BLOCK = 128
|
||||
else:
|
||||
BLOCK = 64
|
||||
batch, head = kv_cache_loc.shape[0], v.shape[1]
|
||||
grid = (batch, head)
|
||||
num_warps = 4
|
||||
dim = v.shape[-1]
|
||||
|
||||
_token_attn_2_kernel[grid](
|
||||
prob,
|
||||
v,
|
||||
attn_out,
|
||||
v_input_scale,
|
||||
pv_output_scale,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seqlen,
|
||||
max_kv_cache_len,
|
||||
kv_cache_loc.stride(0),
|
||||
kv_cache_loc.stride(1),
|
||||
prob.stride(0),
|
||||
prob.stride(1),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
attn_out.stride(0),
|
||||
attn_out.stride(1),
|
||||
attn_out.stride(2),
|
||||
HEAD_DIM=dim,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_token_attention_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_out,
|
||||
q_input_scale,
|
||||
k_input_scale,
|
||||
v_input_scale,
|
||||
pv_output_scale,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seq_len,
|
||||
max_len_in_batch,
|
||||
alibi=None,
|
||||
):
|
||||
head_num = k.shape[1]
|
||||
batch_size = kv_cache_seq_len.shape[0]
|
||||
calcu_shape1 = (batch_size, head_num, k.shape[2])
|
||||
total_token_num = k.shape[0]
|
||||
|
||||
att_m_tensor = torch.empty((head_num, total_token_num), dtype=torch.float32, device="cuda")
|
||||
|
||||
token_attn_fwd_1(
|
||||
q.view(calcu_shape1),
|
||||
k,
|
||||
att_m_tensor,
|
||||
q_input_scale,
|
||||
k_input_scale,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seq_len,
|
||||
max_len_in_batch,
|
||||
alibi=alibi,
|
||||
)
|
||||
|
||||
prob = torch.empty_like(att_m_tensor)
|
||||
|
||||
token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
|
||||
att_m_tensor = None
|
||||
token_attn_fwd_2(
|
||||
prob,
|
||||
v,
|
||||
attn_out.view(calcu_shape1),
|
||||
v_input_scale,
|
||||
pv_output_scale,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seq_len,
|
||||
max_len_in_batch,
|
||||
)
|
||||
|
||||
prob = None
|
||||
|
||||
return
|
|
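smooth_token_attention_fwd above chains three kernels: per-token QK^T scores, a softmax over each sequence's cached length, and a probability-weighted sum over V. Ignoring the KV-cache indirection and the int8 scales, the per-head math reduces to the following plain PyTorch sketch (illustration only, not part of the patch):

import torch

def token_attention_reference(q, k, v):
    # q: (head_dim,), k and v: (seq_len, head_dim), all float, for one head of one sequence
    scores = (k @ q) / (q.shape[-1] ** 0.5)  # stage 1: token_attn_fwd_1
    prob = torch.softmax(scores, dim=0)      # stage 2: token_attn_softmax_fwd
    return prob @ v                          # stage 3: token_attn_fwd_2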
@ -0,0 +1,69 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
|
||||
|
||||
|
||||
def build_model_and_tokenizer(model_name):
|
||||
tokenizer = LlamaTokenizer.from_pretrained(model_name, model_max_length=512)
|
||||
kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"}
|
||||
model = SmoothLlamaForCausalLM.from_pretrained(model_name, **kwargs)
|
||||
model = model.to(torch.float32)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model-name", type=str, help="model name")
|
||||
parser.add_argument(
|
||||
"--output-path",
|
||||
type=str,
|
||||
help="where to save the checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-path",
|
||||
type=str,
|
||||
help="location of the calibration dataset",
|
||||
)
|
||||
parser.add_argument("--num-samples", type=int, default=512)
|
||||
parser.add_argument("--seq-len", type=int, default=512)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = parse_args()
|
||||
model_path = args.model_name
|
||||
dataset_path = args.dataset_path
|
||||
output_path = args.output_path
|
||||
num_samples = args.num_samples
seq_len = args.seq_len
|
||||
|
||||
model, tokenizer = build_model_and_tokenizer(model_path)
|
||||
if not os.path.exists(dataset_path):
|
||||
print(f"Cannot find the dataset at {args.dataset_path}")
|
||||
raise FileNotFoundError
|
||||
dataset = load_dataset("json", data_files=dataset_path, split="train")
|
||||
|
||||
model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len)
|
||||
model = model.cuda()
|
||||
|
||||
model.save_quantized(output_path, model_basename="llama-7b")
|
||||
|
||||
model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b")
|
||||
model = model.cuda()
|
||||
|
||||
generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True)
|
||||
input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda")
|
||||
out = model.generate(**input_tokens, **generate_kwargs)
|
||||
text = tokenizer.batch_decode(out)
|
||||
print("out is:", text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,52 @@
import warnings

import torch

from .builder import Builder
from .utils import append_nvcc_threads


class SmoothquantBuilder(Builder):
    NAME = "cu_smoothquant"
    PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant"

    def __init__(self):
        super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH)

    def include_dirs(self):
        ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()]
        return ret

    def sources_files(self):
        ret = [
            self.csrc_abs_path(fname)
            for fname in [
                "smoothquant/binding.cpp",
                "smoothquant/linear.cu",
            ]
        ]
        return ret

    def cxx_flags(self):
        return ["-O3"] + self.version_dependent_macros

    def nvcc_flags(self):
        compute_capability = torch.cuda.get_device_capability()
        cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10

        extra_cuda_flags = [
            "-v",
            f"-DCUDA_ARCH={cuda_arch}",
            "-std=c++17",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_HALF2_OPERATORS__",
            "-DTHRUST_IGNORE_CUB_VERSION_CHECK",
        ]

        ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
        return append_nvcc_threads(ret)

    def builder(self):
        try:
            super().builder()
        except Exception:
            warnings.warn("build smoothquant lib not successful")
|
|
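As the unit tests below also do, the extension is JIT-built and loaded through this builder, and the resulting module exposes the kernel bound in binding.cpp:

from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

smoothquant_cuda = SmoothquantBuilder().load()
# smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(input_int8, weight_int8, bias_fp32, alpha, beta)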
@ -0,0 +1,136 @@
|
|||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
try:
|
||||
from colossalai.kernel.triton import int8_rotary_embedding_fwd
|
||||
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
try:
|
||||
from colossalai.inference.quant.smoothquant.models import LLamaSmoothquantAttention
|
||||
|
||||
HAS_TORCH_INT = True
|
||||
except ImportError:
|
||||
HAS_TORCH_INT = False
|
||||
print("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
|
||||
|
||||
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
|
||||
"""
|
||||
adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
|
||||
"""
|
||||
xq = xq.view(bs, seqlen, num_head, head_dim)
|
||||
xk = xk.view(bs, seqlen, num_head, head_dim)
|
||||
xv = xv.view(bs, seqlen, num_head, head_dim)
|
||||
mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
|
||||
mask[mask == 0.0] = -100000000.0
|
||||
mask = mask.repeat(bs, num_head, 1, 1)
|
||||
keys = xk
|
||||
values = xv
|
||||
xq = xq.transpose(1, 2)
|
||||
keys = keys.transpose(1, 2)
|
||||
values = values.transpose(1, 2)
|
||||
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
|
||||
scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq)
|
||||
output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_TORCH_INT,
|
||||
reason="triton requires cuda version to be higher than 11.4 or not install torch_int",
|
||||
)
|
||||
def test_llama_context_attention():
|
||||
head_num = 2
|
||||
seq_len = 32
|
||||
head_dim = 64
|
||||
dtype = torch.float
|
||||
hidden_size = head_num * head_dim
|
||||
|
||||
smooth_attn = LLamaSmoothquantAttention(head_num * head_dim, head_num)
|
||||
|
||||
smooth_attn.q_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
|
||||
smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
|
||||
smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
|
||||
smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
|
||||
smooth_attn.out_proj.weight[:, 1:hidden_size] = torch.zeros(hidden_size - 1, device="cuda").to(torch.int8)
|
||||
|
||||
qkv_weight_scale = 1.0
|
||||
|
||||
ones = torch.ones(hidden_size, hidden_size, dtype=torch.float, device="cuda")
|
||||
|
||||
smooth_attn = smooth_attn.to("cuda")
|
||||
|
||||
input = torch.randint(-20, 20, (1, seq_len, head_num * head_dim), dtype=torch.int8, device="cuda")
|
||||
input_scale = 1 / 20.0
|
||||
|
||||
output = torch.matmul(input.to(torch.float) * input_scale, ones)
|
||||
qkv_max_out = torch.max(torch.abs(output)) / 127
|
||||
smooth_attn.q_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
|
||||
smooth_attn.k_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
|
||||
smooth_attn.v_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
|
||||
|
||||
q = smooth_attn.q_proj(input)
|
||||
k = smooth_attn.k_proj(input)
|
||||
v = smooth_attn.v_proj(input)
|
||||
|
||||
cos_shape = (seq_len, head_dim // 2)
|
||||
cos = torch.ones(cos_shape, dtype=dtype, device="cuda")
|
||||
sin = torch.zeros(cos_shape, dtype=dtype, device="cuda")
|
||||
in_scale = torch.tensor([qkv_max_out], device="cuda")
|
||||
out_scale = torch.tensor([qkv_max_out], device="cuda")
|
||||
int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item())
|
||||
int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item())
|
||||
|
||||
q = q.to(torch.float) * out_scale
|
||||
k = k.to(torch.float) * out_scale
|
||||
v = v.to(torch.float) * out_scale
|
||||
torch_out = torch_context_attention(q.clone(), k.clone(), v.clone(), 1, seq_len, head_num, head_dim)
|
||||
attn_out_max = torch.max(torch.abs(torch_out)) / 127
|
||||
|
||||
output = torch.matmul(torch_out.view(-1, seq_len, head_num * head_dim), ones)
|
||||
smooth_attn.q_output_scale = torch.tensor(qkv_max_out)
|
||||
smooth_attn.k_output_scale = torch.tensor(qkv_max_out)
|
||||
|
||||
smooth_attn.v_output_scale = torch.tensor(qkv_max_out)
|
||||
smooth_attn.q_rotary_output_scale = torch.tensor(qkv_max_out)
|
||||
smooth_attn.k_rotary_output_scale = torch.tensor(qkv_max_out)
|
||||
|
||||
smooth_attn.attn_output_scale = torch.tensor(attn_out_max)
|
||||
smooth_attn.out_proj.a = torch.tensor([attn_out_max])
|
||||
|
||||
torch_out = (
|
||||
(torch_out / smooth_attn.attn_output_scale)
|
||||
.round()
|
||||
.clamp(-128, 127)
|
||||
.to(torch.int8)
|
||||
.view(-1, seq_len, head_num * head_dim)
|
||||
)
|
||||
|
||||
torch_out = smooth_attn.out_proj(torch_out)
|
||||
torch_out = torch_out.to(torch.float)
|
||||
|
||||
smooth_attn = smooth_attn.to("cuda")
|
||||
smooth_out, _, _ = smooth_attn(input, (cos, sin))
|
||||
smooth_out = smooth_out.to(torch.float)
|
||||
|
||||
assert torch.allclose(
|
||||
torch_out.cpu(), smooth_out.cpu(), rtol=1e-1, atol=1e-1
|
||||
), "outputs from triton and torch are not matched"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_llama_context_attention()
|
|
@ -0,0 +1,84 @@
|
|||
import warnings
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
try:
|
||||
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
|
||||
|
||||
smoothquant_cuda = SmoothquantBuilder().load()
|
||||
HAS_SMOOTHQUANT_CUDA = True
|
||||
except:
|
||||
warnings.warn("CUDA smoothquant linear is not installed")
|
||||
HAS_SMOOTHQUANT_CUDA = False
|
||||
|
||||
|
||||
try:
|
||||
from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP
|
||||
|
||||
HAS_TORCH_INT = True
|
||||
except:
|
||||
HAS_TORCH_INT = False
|
||||
warnings.warn("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
|
||||
|
||||
|
||||
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||
|
||||
|
||||
def torch_llama_mlp(gate_proj, up_proj, down_proj, x):
|
||||
gate_out = torch.mm(x, gate_proj)
|
||||
silu = torch.nn.SiLU()
|
||||
gate_out = silu(gate_out)
|
||||
up_out = torch.mm(x, up_proj)
|
||||
|
||||
o_out = gate_out * up_out
|
||||
|
||||
max_up = torch.max(torch.abs(o_out))
|
||||
min_up = torch.min(torch.abs(o_out))
|
||||
|
||||
torch_out = torch.mm(o_out, down_proj)
|
||||
|
||||
return (torch_out, max_up, min_up)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not CUDA_SUPPORT or not HAS_SMOOTHQUANT_CUDA or not HAS_TORCH_INT,
|
||||
reason="smoothquant linear not installed properly or not install torch_int",
|
||||
)
|
||||
def test_llama_mlp():
|
||||
hidden_size = 256
|
||||
intermediate_size = 512
|
||||
|
||||
smooth_mlp = LlamaSmoothquantMLP(intermediate_size, hidden_size)
|
||||
|
||||
smooth_mlp.gate_proj.weight = torch.ones((intermediate_size, hidden_size), dtype=torch.int8, device="cuda")
|
||||
|
||||
smooth_mlp.up_proj.weight = torch.randint(
|
||||
-10, 10, (intermediate_size, hidden_size), dtype=torch.int8, device="cuda"
|
||||
)
|
||||
smooth_mlp.down_proj.weight = torch.randint(
|
||||
-10, 10, (hidden_size, intermediate_size), dtype=torch.int8, device="cuda"
|
||||
)
|
||||
|
||||
x = torch.ones((1, 256), dtype=torch.int8, device="cuda")
|
||||
|
||||
torch_out, max_inter, min_inter = torch_llama_mlp(
|
||||
smooth_mlp.gate_proj.weight.transpose(0, 1).to(torch.float) / hidden_size,
|
||||
smooth_mlp.up_proj.weight.transpose(0, 1).to(torch.float) / 127,
|
||||
smooth_mlp.down_proj.weight.transpose(0, 1).to(torch.float) / 127,
|
||||
x.to(torch.float),
|
||||
)
|
||||
|
||||
smooth_mlp.down_proj_input_scale = torch.tensor(max_inter.item() / 127)
|
||||
smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size)
|
||||
smooth_mlp.up_proj.a = torch.tensor(1 / 127)
|
||||
smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127))
|
||||
|
||||
smooth_out = smooth_mlp(x)
|
||||
|
||||
assert torch.allclose(torch_out, smooth_out, rtol=1e-02, atol=1e-01)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_llama_mlp()
|
|
@ -0,0 +1,39 @@
|
|||
import warnings
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
try:
|
||||
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
|
||||
|
||||
smoothquant_cuda = SmoothquantBuilder().load()
|
||||
HAS_SMOOTHQUANT_CUDA = True
|
||||
except:
|
||||
warnings.warn("CUDA smoothquant linear is not installed")
|
||||
HAS_SMOOTHQUANT_CUDA = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not HAS_SMOOTHQUANT_CUDA,
|
||||
reason="smoothquant linear not installed properly",
|
||||
)
|
||||
def test_linear():
|
||||
a = torch.randint(-127, 127, (128, 512), dtype=torch.int8, device="cuda")
|
||||
b = torch.randint(-127, 127, (512, 256), dtype=torch.int8, device="cuda")
|
||||
c = torch.rand(256, dtype=torch.float, device="cuda")
|
||||
|
||||
alpha = 1 / 127
|
||||
beta = 1.0
|
||||
torch_out = torch.mm(a.to(torch.float) * alpha, b.to(torch.float)) + c
|
||||
|
||||
silu = torch.nn.SiLU()
|
||||
torch_out = silu(torch_out)
|
||||
|
||||
b = b.transpose(0, 1).contiguous()
|
||||
cuda_out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(a, b, c, alpha, beta)
|
||||
|
||||
assert torch.allclose(torch_out, cuda_out, rtol=1e-02, atol=1e-02)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_linear()
|
|
@ -0,0 +1,59 @@
|
|||
# Adapted from ModelTC https://github.com/ModelTC/lightllm
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
try:
|
||||
from colossalai.kernel.triton import int8_rotary_embedding_fwd
|
||||
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||
|
||||
|
||||
def torch_rotary_emb(x, cos, sin):
|
||||
seq_len, h, dim = x.shape
|
||||
x0 = x[:, :, 0 : dim // 2]
|
||||
x1 = x[:, :, dim // 2 : dim]
|
||||
cos = cos.view((seq_len, 1, dim // 2))
|
||||
sin = sin.view((seq_len, 1, dim // 2))
|
||||
o0 = x0 * cos - x1 * sin
|
||||
o1 = x0 * sin + x1 * cos
|
||||
return torch.cat((o0, o1), dim=-1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
|
||||
)
|
||||
def test_rotary_emb():
|
||||
SEQ_LEN = 1
|
||||
HEAD_NUM = 32
|
||||
HEAD_DIM = 128
|
||||
dtype = torch.float
|
||||
# create data
|
||||
x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
|
||||
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
||||
cos_shape = (SEQ_LEN, HEAD_DIM // 2)
|
||||
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
# forward pass
|
||||
y_torch = torch_rotary_emb(x, cos, sin)
|
||||
|
||||
input_scale = torch.max(torch.abs(x)) / 127
|
||||
output_scale = torch.max(torch.abs(y_torch)) / 127
|
||||
|
||||
x = x / input_scale
|
||||
x = x.to(torch.int8)
|
||||
|
||||
int8_rotary_embedding_fwd(x, cos, sin, input_scale.item(), output_scale.item())
|
||||
y_triton = x.to(torch.float) * output_scale
|
||||
assert torch.allclose(y_triton, y_torch, atol=2e-1, rtol=1e-2, equal_nan=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rotary_emb()
|