# ColossalAI/colossalai/inference/quant/smoothquant/models/base_model.py
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ
# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
import os
import warnings
from abc import abstractmethod
from functools import partial
from os.path import isdir, isfile, join
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import transformers
from safetensors.torch import save_file as safe_save
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_utils import no_init_weights
from transformers.utils.generic import ContextManagers
from transformers.utils.hub import PushToHubMixin, cached_file

from colossalai.inference.kv_cache.batch_infer_state import BatchInferState, MemoryManager

try:
    import accelerate

    HAS_ACCELERATE = True
except ImportError:
    HAS_ACCELERATE = False
    print("accelerate is not installed.")

SUPPORTED_MODELS = ["llama"]

class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
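    """Base class wrapping a Hugging Face causal LM for SmoothQuant-style inference.

    Subclasses implement the model-specific quantization (see `quantize` and
    `create_quantized_model`); this base class provides KV-cache management,
    activation-scale calibration, weight smoothing, and (de)serialization of
    quantized checkpoints.
    """
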
    layer_type: str = None

    def __init__(self, model: PreTrainedModel, quantized: bool = False):
        super().__init__()

        self.model = model
        self.model_type = self.model.config.model_type
        self._quantized = quantized
        self.config = self.model.config
        self.cache_manager = None
        self.max_total_token_num = 0

    @property
    def quantized(self):
        return self._quantized

    def init_cache_manager(self, max_total_token_num=2048):
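        """Allocate an int8 KV-cache pool able to hold `max_total_token_num` tokens.

        The cache geometry (number of KV heads, per-head dim, layer count) is read
        from the model config; only the "llama" model type is handled here.
        """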
        if self.config.model_type == "llama":
            head_num = self.config.num_key_value_heads
            layer_num = self.config.num_hidden_layers
            # the per-head dimension follows the attention head count (equal to
            # num_key_value_heads for LLaMA models without grouped-query attention)
            head_dim = self.config.hidden_size // self.config.num_attention_heads

            self.cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
            self.max_total_token_num = max_total_token_num

    def init_batch_state(self, max_output_len=256, **kwargs):
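        """Build a `BatchInferState` for a batch of prompts.

        Records per-sequence lengths and start offsets, grows the KV-cache pool if
        the batch would not fit, allocates the block-location table, and resets the
        cache manager before the prefill (context) stage.
        """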
        input_ids = kwargs["input_ids"]
        batch_size = len(input_ids)

        seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        start_index = 0

        max_len_in_batch = -1
        for i in range(batch_size):
            seq_len = len(input_ids[i])
            seq_lengths[i] = seq_len
            seq_start_indexes[i] = start_index
            start_index += seq_len
            max_len_in_batch = seq_len if seq_len > max_len_in_batch else max_len_in_batch

        if "max_total_token_num" in kwargs.keys():
            max_total_token_num = kwargs["max_total_token_num"]
            self.init_cache_manager(max_total_token_num)

        if "max_new_tokens" in kwargs.keys():
            max_output_len = kwargs["max_new_tokens"]

        if batch_size * (max_len_in_batch + max_output_len) > self.max_total_token_num:
            max_total_token_num = batch_size * (max_len_in_batch + max_output_len)
            warnings.warn(f"reset max tokens to {max_total_token_num}")
            self.init_cache_manager(max_total_token_num)

        block_loc = torch.empty((batch_size, max_len_in_batch + max_output_len), dtype=torch.long, device="cuda")
        batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
        batch_infer_state.seq_len = seq_lengths.to("cuda")
        batch_infer_state.start_loc = seq_start_indexes.to("cuda")
        batch_infer_state.block_loc = block_loc
        batch_infer_state.decode_layer_id = 0
        batch_infer_state.is_context_stage = True
        batch_infer_state.set_cache_manager(self.cache_manager)
        batch_infer_state.cache_manager.free_all()
        return batch_infer_state

    @abstractmethod
    @torch.inference_mode()
    def quantize(
        self,
        examples: List[Dict[str, Union[List[int], torch.LongTensor]]],
    ):
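        """Quantize the wrapped model in place using the given calibration examples.

        Subclasses implement the actual SmoothQuant conversion; the base class only
        guards against quantizing an already-quantized model.
        """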
        if self.quantized:
            raise EnvironmentError("can't execute quantize because the model is already quantized.")

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def generate(self, **kwargs):
        """shortcut for model.generate"""
        batch_infer_state = self.init_batch_state(**kwargs)
        if self.config.model_type == "llama":
            setattr(self.model.model, "infer_state", batch_infer_state)

        with torch.inference_mode():
            return self.model.generate(**kwargs)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """shortcut for model.prepare_inputs_for_generation"""
        return self.model.prepare_inputs_for_generation(*args, **kwargs)

    def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512):
        for text in tqdm(dataset):
            input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
            model(input_ids)

    def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512):
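        """Run calibration texts through `model` so that forward hooks registered by
        the caller can populate `act_dict`; the progress bar reports the running mean
        of the recorded input scales.
        """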
        pbar = tqdm(dataset)
        for text in pbar:
            input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
            model(input_ids)

            mean_scale = np.mean([v["input"] for v in act_dict.values()])
            pbar.set_description(f"Mean input scale: {mean_scale:.2f}")

    # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
    def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
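        """Collect per-input-channel activation scales for every `nn.Linear` module.

        A forward hook records the maximum absolute value seen on each input channel
        during calibration; the result maps module name to a 1-D tensor of maxima.
        """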
        model.eval()
        device = next(model.parameters()).device
        act_scales = {}

        def stat_tensor(name, tensor):
            hidden_dim = tensor.shape[-1]
            tensor = tensor.view(-1, hidden_dim).abs().detach()
            coming_max = torch.max(tensor, dim=0)[0].float().cpu()
            if name in act_scales:
                act_scales[name] = torch.max(act_scales[name], coming_max)
            else:
                act_scales[name] = coming_max

        def stat_input_hook(m, x, y, name):
            if isinstance(x, tuple):
                x = x[0]
            stat_tensor(name, x)

        hooks = []
        for name, m in model.named_modules():
            if isinstance(m, nn.Linear):
                hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name)))

        self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len)

        for h in hooks:
            h.remove()

        return act_scales

    # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
    @torch.no_grad()
    def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
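        """Migrate quantization difficulty from activations to weights (SmoothQuant).

        For each input channel j, the smoothing factor is
            s_j = max|X_j|^alpha / max|W_j|^(1 - alpha),
        after which the preceding norm's weight (and bias, if present) is divided by
        s and the weight columns of every following linear layer are multiplied by s.
        """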
        if not isinstance(fcs, list):
            fcs = [fcs]
        for fc in fcs:
            assert isinstance(fc, nn.Linear)
            assert ln.weight.numel() == fc.in_features == act_scales.numel()

        device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
        act_scales = act_scales.to(device=device, dtype=dtype)
        weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
        weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)

        scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)

        ln.weight.div_(scales)
        # RMSNorm has no bias; also guard against a bias attribute that is None
        if getattr(ln, "bias", None) is not None:
            ln.bias.div_(scales)

        for fc in fcs:
            fc.weight.mul_(scales.view(1, -1))

    @classmethod
    def create_quantized_model(cls, model):
        raise NotImplementedError("create_quantized_model is not implemented; subclasses must override it.")

    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
    def save_quantized(
        self,
        save_dir: str,
        model_basename: str,
        use_safetensors: bool = False,
        safetensors_metadata: Optional[Dict[str, str]] = None,
    ):
"""save quantized model and configs to local disk"""
os.makedirs(save_dir, exist_ok=True)
if not self.quantized:
raise EnvironmentError("can only save quantized model, please execute .quantize first.")
self.model.to("cpu")
model_base_name = model_basename # or f"smooth-"
if use_safetensors:
model_save_name = model_base_name + ".safetensors"
state_dict = self.model.state_dict()
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
if safetensors_metadata is None:
safetensors_metadata = {}
elif not isinstance(safetensors_metadata, dict):
raise TypeError("safetensors_metadata must be a dictionary.")
else:
print(f"Received safetensors_metadata: {safetensors_metadata}")
new_safetensors_metadata = {}
converted_keys = False
for key, value in safetensors_metadata.items():
if not isinstance(key, str) or not isinstance(value, str):
converted_keys = True
try:
new_key = str(key)
new_value = str(value)
except Exception as e:
raise TypeError(
f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}"
)
if new_key in new_safetensors_metadata:
print(
f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
)
new_safetensors_metadata[new_key] = new_value
safetensors_metadata = new_safetensors_metadata
if converted_keys:
print(
f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
)
# Format is required to enable Accelerate to load the metadata
# otherwise it raises an OSError
safetensors_metadata["format"] = "pt"
safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
else:
model_save_name = model_base_name + ".bin"
torch.save(self.model.state_dict(), join(save_dir, model_save_name))
self.model.config.save_pretrained(save_dir)

    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
    def save_pretrained(
        self,
        save_dir: str,
        use_safetensors: bool = False,
        safetensors_metadata: Optional[Dict[str, str]] = None,
        **kwargs,
    ):
        """alias of save_quantized"""
        warnings.warn("you are using save_pretrained, which will redirect to save_quantized.")
        # Forward model_basename explicitly so the positional arguments of
        # save_quantized are not mis-assigned; "model" is only a fallback basename.
        self.save_quantized(save_dir, kwargs.pop("model_basename", "model"), use_safetensors, safetensors_metadata)

    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        max_memory: Optional[dict] = None,
        trust_remote_code: bool = False,
        torch_dtype: torch.dtype = torch.float16,
        **model_init_kwargs,
    ):
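        """Load a full-precision Hugging Face checkpoint to be quantized afterwards.

        Weight-initialization functions are stubbed out to speed up loading, and an
        Accelerate device map is inferred when `max_memory` is given.
        """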
        if not torch.cuda.is_available():
            raise EnvironmentError("Loading a pretrained model for quantization requires CUDA to be available.")

        def skip(*args, **kwargs):
            pass

        # skip weight initialization: the weights are overwritten by the checkpoint anyway
        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip

        # Parameters related to loading from Hugging Face Hub
        cache_dir = model_init_kwargs.pop("cache_dir", None)
        force_download = model_init_kwargs.pop("force_download", False)
        resume_download = model_init_kwargs.pop("resume_download", False)
        proxies = model_init_kwargs.pop("proxies", None)
        local_files_only = model_init_kwargs.pop("local_files_only", False)
        use_auth_token = model_init_kwargs.pop("use_auth_token", None)
        revision = model_init_kwargs.pop("revision", None)
        subfolder = model_init_kwargs.pop("subfolder", "")
        model_init_kwargs.pop("_commit_hash", None)

        cached_file_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "resume_download": resume_download,
            "local_files_only": local_files_only,
            "use_auth_token": use_auth_token,
            "revision": revision,
            "subfolder": subfolder,
        }

        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs)
        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")

        # enforce some values regardless of what the user specified
        model_init_kwargs["torch_dtype"] = torch_dtype
        model_init_kwargs["trust_remote_code"] = trust_remote_code

        if max_memory:
            if "disk" in max_memory:
                raise NotImplementedError("disk offload is not supported yet.")
            with accelerate.init_empty_weights():
                model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
            model.tie_weights()

            max_memory = accelerate.utils.get_balanced_memory(
                model,
                max_memory=max_memory,
                no_split_module_classes=[cls.layer_type],
                dtype=model_init_kwargs["torch_dtype"],
                low_zero=False,
            )
            model_init_kwargs["device_map"] = accelerate.infer_auto_device_map(
                model,
                max_memory=max_memory,
                no_split_module_classes=[cls.layer_type],
                dtype=model_init_kwargs["torch_dtype"],
            )
            model_init_kwargs["low_cpu_mem_usage"] = True

            del model
        else:
            model_init_kwargs["device_map"] = None
            model_init_kwargs["low_cpu_mem_usage"] = False

        torch.cuda.empty_cache()

        merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)

        model_config = model.config.to_dict()
        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
        if any([k in model_config for k in seq_len_keys]):
            for key in seq_len_keys:
                if key in model_config:
                    model.seqlen = model_config[key]
                    break
        else:
            warnings.warn("can't get model's sequence length from model config, will set to 4096.")
            model.seqlen = 4096
        model.eval()

        return cls(model, False)

    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
    @classmethod
    def from_quantized(
        cls,
        model_name_or_path: Optional[str],
        model_basename: Optional[str] = None,
        device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
        max_memory: Optional[dict] = None,
        device: Optional[Union[str, int]] = None,
        low_cpu_mem_usage: bool = False,
        torch_dtype: Optional[torch.dtype] = None,
        use_safetensors: bool = False,
        trust_remote_code: bool = False,
        **kwargs,
    ):
"""load quantized model from local disk"""
# Parameters related to loading from Hugging Face Hub
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"use_auth_token": use_auth_token,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}

        # == step1: prepare configs and file names == #
        config = AutoConfig.from_pretrained(
            model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs
        )
        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")

        extensions = []
        if use_safetensors:
            extensions.append(".safetensors")
        else:
            extensions += [".bin", ".pt"]

        model_name_or_path = str(model_name_or_path)
        is_local = isdir(model_name_or_path)

        resolved_archive_file = None
        if is_local:
            model_save_name = join(model_name_or_path, model_basename)
            for ext in extensions:
                if isfile(model_save_name + ext):
                    resolved_archive_file = model_save_name + ext
                    break
        else:  # remote
            for ext in extensions:
                resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs)
                if resolved_archive_file is not None:
                    break

        if resolved_archive_file is None:  # Could not find a model file to use
            raise FileNotFoundError(f"Could not find model in {model_name_or_path}")

        model_save_name = resolved_archive_file

        # == step2: convert model to quantized-model (replace Linear) == #
        def skip(*args, **kwargs):
            pass

        # skip weight initialization: real weights are loaded from the checkpoint below
        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip

        transformers.modeling_utils._init_weights = False

        init_contexts = [no_init_weights()]
        if low_cpu_mem_usage:
            init_contexts.append(accelerate.init_empty_weights(include_buffers=True))

        with ContextManagers(init_contexts):
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype
            )

            cls.create_quantized_model(model)
            model.tie_weights()

        # == step3: load checkpoint to quantized-model == #
        accelerate.utils.modeling.load_checkpoint_in_model(
            model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True
        )

        # == step4: set seqlen == #
        model_config = model.config.to_dict()
        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
        if any([k in model_config for k in seq_len_keys]):
            for key in seq_len_keys:
                if key in model_config:
                    model.seqlen = model_config[key]
                    break
        else:
            warnings.warn("can't get model's sequence length from model config, will set to 4096.")
            model.seqlen = 4096

        return cls(
            model,
            True,
        )

    def __getattr__(self, item):
        try:
            return super().__getattr__(item)
        except AttributeError:
            # fall back to the wrapped Hugging Face model for unknown attributes
            return getattr(self.model, item)


__all__ = ["BaseSmoothForCausalLM"]
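
# Minimal usage sketch (illustrative only): the concrete quantization logic lives in
# model-specific subclasses; `SmoothLlamaForCausalLM`, its import path, and the
# calibration data below are assumptions, not part of this file.
#
#   from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
#
#   model = SmoothLlamaForCausalLM.from_pretrained("path/to/llama")   # fp16 weights
#   model.quantize(calibration_examples)                              # SmoothQuant W8A8
#   model.save_quantized("quantized_dir", "llama-7b-smoothquant")
#
#   model = SmoothLlamaForCausalLM.from_quantized("quantized_dir", "llama-7b-smoothquant")
#   model.init_cache_manager(max_total_token_num=4096)
#   output_ids = model.generate(input_ids=input_ids, max_new_tokens=128)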