2023-10-16 03:28:44 +00:00
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ
# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
import os
import warnings
from abc import abstractmethod
from functools import partial
from os . path import isdir , isfile , join
from typing import Dict , List , Optional , Union
import numpy as np
import torch
import torch . nn as nn
import transformers
from safetensors . torch import save_file as safe_save
from tqdm import tqdm
from transformers import AutoConfig , AutoModelForCausalLM , PreTrainedModel
from transformers . modeling_utils import no_init_weights
from transformers . utils . generic import ContextManagers
from transformers . utils . hub import PushToHubMixin , cached_file
2023-11-19 13:05:05 +00:00
from colossalai . inference . kv_cache . batch_infer_state import BatchInferState , MemoryManager
try :
import accelerate
HAS_ACCELERATE = True
except ImportError :
HAS_ACCELERATE = False
print ( " accelerate is not installed. " )
2023-10-16 03:28:44 +00:00
# Model types (``config.model_type``) that this SmoothQuant wrapper accepts.
SUPPORTED_MODELS = ["llama"]


class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
    """Base wrapper adding SmoothQuant calibration, saving and loading to a HF causal LM.

    Subclasses are expected to set ``layer_type`` (the decoder-layer class name passed to
    accelerate as ``no_split_module_classes``), implement ``quantize`` and override
    ``create_quantized_model``. Attribute lookups that miss on the wrapper fall through to
    the wrapped ``self.model`` (see ``__getattr__``).
    """

    # Decoder-layer class name that accelerate must not split across devices.
    layer_type: str = None

    def __init__(self, model: PreTrainedModel, quantized: bool = False):
        super().__init__()

        self.model = model
        self.model_type = self.model.config.model_type
        self._quantized = quantized
        self.config = self.model.config
        # KV-cache manager, allocated lazily by ``init_cache_manager``.
        self.cache_manager = None
        # Capacity (in tokens) of the current cache manager; 0 means "not allocated yet".
        self.max_total_token_num = 0

    @property
    def quantized(self):
        """Whether the wrapped model has already been converted to its quantized form."""
        return self._quantized

    def init_cache_manager(self, max_total_token_num=2048):
        """(Re)allocate an int8 KV cache able to hold ``max_total_token_num`` tokens.

        Raises:
            NotImplementedError: if ``config.model_type`` is not ``"llama"``.
        """
        if self.config.model_type != "llama":
            raise NotImplementedError(f"KV cache is not supported for model type {self.config.model_type}.")

        # One cache slot per key/value head (fewer than the attention heads under
        # grouped-query attention); each head is hidden_size // num_attention_heads wide.
        head_num = self.config.num_key_value_heads
        layer_num = self.config.num_hidden_layers
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        self.cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
        self.max_total_token_num = max_total_token_num

    def init_batch_state(self, max_output_len=256, **kwargs):
        """Build the ``BatchInferState`` used by the model during ``generate``.

        Args:
            max_output_len: tokens reserved per sequence for generation; replaced by
                ``kwargs["max_new_tokens"]`` when that key is present.
            **kwargs: the keyword arguments given to ``generate``. ``input_ids`` is required
                (a sequence of per-sample token sequences); ``max_total_token_num``, when
                present, forces a fresh cache of that size.

        Returns:
            A ``BatchInferState`` in context stage whose cache manager has been emptied.
            The cache is re-allocated (with a warning) when it cannot hold
            ``batch_size * (longest_prompt + max_output_len)`` tokens.
        """
        input_ids = kwargs["input_ids"]
        batch_size = len(input_ids)

        # Compute lengths/offsets on the host and upload each tensor once, instead of
        # writing CUDA tensors element by element (one host->device copy per write).
        lengths = [len(ids) for ids in input_ids]
        starts = []
        offset = 0
        for length in lengths:
            starts.append(offset)
            offset += length
        max_len_in_batch = max(lengths, default=-1)

        seq_lengths = torch.tensor(lengths, dtype=torch.int32, device="cuda")
        seq_start_indexes = torch.tensor(starts, dtype=torch.int32, device="cuda")

        if "max_total_token_num" in kwargs:
            self.init_cache_manager(kwargs["max_total_token_num"])

        if "max_new_tokens" in kwargs:
            max_output_len = kwargs["max_new_tokens"]

        required_tokens = batch_size * (max_len_in_batch + max_output_len)
        if required_tokens > self.max_total_token_num:
            warnings.warn(f"reset max tokens to {required_tokens}")
            self.init_cache_manager(required_tokens)

        block_loc = torch.empty((batch_size, max_len_in_batch + max_output_len), dtype=torch.long, device="cuda")
        batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
        batch_infer_state.seq_len = seq_lengths
        batch_infer_state.start_loc = seq_start_indexes
        batch_infer_state.block_loc = block_loc
        batch_infer_state.decode_layer_id = 0
        batch_infer_state.is_context_stage = True
        batch_infer_state.set_cache_manager(self.cache_manager)
        batch_infer_state.cache_manager.free_all()
        return batch_infer_state

    @abstractmethod
    @torch.inference_mode()
    def quantize(
        self,
        examples: List[Dict[str, Union[List[int], torch.LongTensor]]],
    ):
        """Calibrate and quantize ``self.model``; subclasses must implement this.

        The base implementation only refuses to quantize an already-quantized model.
        """
        if self.quantized:
            raise EnvironmentError("can't execute quantize because the model is quantized.")

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped model's forward."""
        return self.model(*args, **kwargs)

    def generate(self, **kwargs):
        """shortcut for model.generate

        Builds a fresh batch state (see ``init_batch_state``), attaches it to the llama
        backbone as ``infer_state`` and runs generation under ``torch.inference_mode``.
        """
        batch_infer_state = self.init_batch_state(**kwargs)
        if self.config.model_type == "llama":
            setattr(self.model.model, "infer_state", batch_infer_state)

        with torch.inference_mode():
            return self.model.generate(**kwargs)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """shortcut for model.prepare_inputs_for_generation"""
        return self.model.prepare_inputs_for_generation(*args, **kwargs)

    @staticmethod
    def _calibration_progress(dataset, num_samples):
        """Return a tqdm bar over at most ``num_samples`` items of ``dataset`` (all if None)."""
        if num_samples is None:
            return tqdm(dataset)
        from itertools import islice

        try:
            total = min(len(dataset), num_samples)
        except TypeError:  # dataset without __len__, e.g. a generator
            total = num_samples
        return tqdm(islice(dataset, num_samples), total=total)

    def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512):
        """Feed calibration texts through ``model`` so registered forward hooks see them.

        Uses at most ``num_samples`` items of ``dataset`` (``None`` = all); each item is
        tokenized and truncated to ``seq_len`` tokens. Autograd is disabled.
        """
        with torch.no_grad():
            for text in self._calibration_progress(dataset, num_samples):
                input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
                model(input_ids)

    def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512):
        """Like ``collect_act_scales``, also showing the mean of ``act_dict[*]["input"]``.

        ``act_dict`` is expected to be filled by hooks registered by the caller; each of
        its values must have an ``"input"`` entry accepted by ``np.mean``.
        """
        pbar = self._calibration_progress(dataset, num_samples)
        with torch.no_grad():
            for text in pbar:
                input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
                model(input_ids)
                mean_scale = np.mean([v["input"] for v in act_dict.values()])
                pbar.set_description(f"Mean input scale: {mean_scale:.2f}")

    # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
    def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
        """Return ``{linear_module_name: per-input-channel max |activation|}`` (float32, CPU).

        Temporarily hooks every ``nn.Linear`` in ``model``; the hooks are removed even if
        calibration fails.
        """
        model.eval()
        device = next(model.parameters()).device
        act_scales = {}

        def stat_tensor(name, tensor):
            hidden_dim = tensor.shape[-1]
            # reshape (not view) so non-contiguous activations are accepted too.
            tensor = tensor.reshape(-1, hidden_dim).abs().detach()
            coming_max = torch.max(tensor, dim=0)[0].float().cpu()
            if name in act_scales:
                act_scales[name] = torch.max(act_scales[name], coming_max)
            else:
                act_scales[name] = coming_max

        def stat_input_hook(m, x, y, name):
            if isinstance(x, tuple):
                x = x[0]
            stat_tensor(name, x)

        hooks = [
            m.register_forward_hook(partial(stat_input_hook, name=name))
            for name, m in model.named_modules()
            if isinstance(m, nn.Linear)
        ]
        try:
            self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len)
        finally:
            for h in hooks:
                h.remove()

        return act_scales

    # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
    @torch.no_grad()
    def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
        """Migrate activation outliers from ``ln``'s output into the following linear layers.

        Computes ``s = act_scales**alpha / max_weight**(1 - alpha)`` per input channel
        (both clamped to >= 1e-5), divides ``ln``'s weight (and bias, if it has one) by
        ``s`` and multiplies every ``fc``'s weight columns by ``s``, in place.
        """
        if not isinstance(fcs, list):
            fcs = [fcs]
        for fc in fcs:
            assert isinstance(fc, nn.Linear)
            assert ln.weight.numel() == fc.in_features == act_scales.numel()

        device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
        act_scales = act_scales.to(device=device, dtype=dtype)
        weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
        weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)

        # Both operands already live on (device, dtype), so no further casts are needed.
        scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5)

        ln.weight.div_(scales)
        # RMSNorm has no bias and LayerNorm(bias=False) has bias=None: skip both.
        bias = getattr(ln, "bias", None)
        if bias is not None:
            bias.div_(scales)

        for fc in fcs:
            fc.weight.mul_(scales.view(1, -1))

    @staticmethod
    def create_quantized_model(model):
        """Swap ``model``'s layers for quantized ones in place; subclasses must override.

        Called as ``cls.create_quantized_model(model)`` from ``from_quantized``.
        """
        raise NotImplementedError("Not implement create_quantized_model method")

    @staticmethod
    def _normalize_safetensors_metadata(safetensors_metadata):
        """Return a new ``{str: str}`` metadata dict, including ``format="pt"``, for safe_save.

        Raises:
            TypeError: if the metadata is not a dict or a key/value cannot be stringified.
        """
        if safetensors_metadata is None:
            safetensors_metadata = {}
        elif not isinstance(safetensors_metadata, dict):
            raise TypeError("safetensors_metadata must be a dictionary.")
        else:
            print(f"Received safetensors_metadata: {safetensors_metadata}")
            new_safetensors_metadata = {}
            converted_keys = False
            for key, value in safetensors_metadata.items():
                if not isinstance(key, str) or not isinstance(value, str):
                    converted_keys = True
                try:
                    new_key = str(key)
                    new_value = str(value)
                except Exception as e:
                    raise TypeError(
                        f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}"
                    ) from e
                if new_key in new_safetensors_metadata:
                    print(
                        f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
                    )
                new_safetensors_metadata[new_key] = new_value
            safetensors_metadata = new_safetensors_metadata
            if converted_keys:
                print(
                    f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
                )

        # Format is required to enable Accelerate to load the metadata
        # otherwise it raises an OSError
        safetensors_metadata["format"] = "pt"
        return safetensors_metadata

    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
    def save_quantized(
        self,
        save_dir: str,
        model_basename: str,
        use_safetensors: bool = False,
        safetensors_metadata: Optional[Dict[str, str]] = None,
    ):
        """save quantized model and configs to local disk

        Writes ``<save_dir>/<model_basename>.safetensors`` (``use_safetensors=True``) or
        ``<save_dir>/<model_basename>.bin``, plus the model config. The wrapped model is
        moved to CPU as a side effect.

        Raises:
            EnvironmentError: if the model has not been quantized yet (checked before
                anything is written to disk).
            TypeError: for invalid ``safetensors_metadata``.
        """
        if not self.quantized:
            raise EnvironmentError("can only save quantized model, please execute .quantize first.")
        os.makedirs(save_dir, exist_ok=True)

        self.model.to("cpu")

        if use_safetensors:
            model_save_name = model_basename + ".safetensors"
            # Independent contiguous copies of every tensor; presumably to satisfy
            # safetensors' restrictions on shared / non-contiguous storage.
            state_dict = {k: v.clone().contiguous() for k, v in self.model.state_dict().items()}
            safetensors_metadata = self._normalize_safetensors_metadata(safetensors_metadata)
            safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
        else:
            model_save_name = model_basename + ".bin"
            torch.save(self.model.state_dict(), join(save_dir, model_save_name))

        self.model.config.save_pretrained(save_dir)

    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
    def save_pretrained(
        self,
        save_dir: str,
        use_safetensors: bool = False,
        safetensors_metadata: Optional[Dict[str, str]] = None,
        model_basename: str = "pytorch_model",
        **kwargs,
    ):
        """alias of save_quantized

        ``model_basename`` names the weights file (``<model_basename>.bin`` or
        ``.safetensors``); pass the same name to ``from_quantized``. Other ``kwargs``
        are ignored.
        """
        warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.")
        # Bind by keyword: save_quantized's second positional parameter is model_basename.
        self.save_quantized(
            save_dir,
            model_basename=model_basename,
            use_safetensors=use_safetensors,
            safetensors_metadata=safetensors_metadata,
        )

    @staticmethod
    def _require_accelerate(purpose):
        """Raise ImportError if the optional ``accelerate`` package could not be imported."""
        if not HAS_ACCELERATE:
            raise ImportError(f"accelerate is required to {purpose}; please install it (pip install accelerate).")

    @staticmethod
    def _pop_hub_kwargs(kwargs):
        """Remove Hugging Face Hub download options from ``kwargs`` (in place) and return them."""
        return {
            "cache_dir": kwargs.pop("cache_dir", None),
            "force_download": kwargs.pop("force_download", False),
            "proxies": kwargs.pop("proxies", None),
            "resume_download": kwargs.pop("resume_download", False),
            "local_files_only": kwargs.pop("local_files_only", False),
            "use_auth_token": kwargs.pop("use_auth_token", None),
            "revision": kwargs.pop("revision", None),
            "subfolder": kwargs.pop("subfolder", ""),
        }

    @staticmethod
    def _set_seqlen(model):
        """Set ``model.seqlen`` from the first known length key in its config (else 4096)."""
        model_config = model.config.to_dict()
        for key in ("max_position_embeddings", "seq_length", "n_positions"):
            if key in model_config:
                model.seqlen = model_config[key]
                return
        warnings.warn("can't get model's sequence length from model config, will set to 4096.")
        model.seqlen = 4096

    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        max_memory: Optional[dict] = None,
        trust_remote_code: bool = False,
        torch_dtype: torch.dtype = torch.float16,
        **model_init_kwargs,
    ):
        """Load an unquantized causal LM (for calibration/quantization) wrapped in ``cls``.

        Args:
            pretrained_model_name_or_path: local directory or Hub model id.
            max_memory: optional per-device memory budget; when given, accelerate builds a
                balanced device map (disk offload is not supported).
            trust_remote_code: forwarded to every transformers loader call.
            torch_dtype: dtype the weights are loaded in (overrides ``model_init_kwargs``).
            **model_init_kwargs: Hub download options plus extra ``from_pretrained`` args.

        Raises:
            EnvironmentError: if CUDA is not available.
            TypeError: if the model type is not in ``SUPPORTED_MODELS``.
            ImportError: if ``max_memory`` is given but accelerate is not installed.
        """
        if not torch.cuda.is_available():
            raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.")

        def skip(*args, **kwargs):
            pass

        # NOTE(review): replaces these torch.nn.init functions process-wide and never
        # restores them (same trick as AutoGPTQ) to skip random weight initialisation.
        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip

        # Parameters related to loading from Hugging Face Hub
        cached_file_kwargs = cls._pop_hub_kwargs(model_init_kwargs)
        model_init_kwargs.pop("_commit_hash", None)

        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs
        )
        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")

        # enforce some values despite user specified
        model_init_kwargs["torch_dtype"] = torch_dtype
        model_init_kwargs["trust_remote_code"] = trust_remote_code
        if max_memory:
            if "disk" in max_memory:
                raise NotImplementedError("disk offload not support yet.")
            cls._require_accelerate("use max_memory")
            with accelerate.init_empty_weights():
                model = AutoModelForCausalLM.from_config(config, trust_remote_code=trust_remote_code)
            model.tie_weights()

            max_memory = accelerate.utils.get_balanced_memory(
                model,
                max_memory=max_memory,
                no_split_module_classes=[cls.layer_type],
                dtype=model_init_kwargs["torch_dtype"],
                low_zero=False,
            )
            model_init_kwargs["device_map"] = accelerate.infer_auto_device_map(
                model,
                max_memory=max_memory,
                no_split_module_classes=[cls.layer_type],
                dtype=model_init_kwargs["torch_dtype"],
            )
            model_init_kwargs["low_cpu_mem_usage"] = True

            del model
        else:
            model_init_kwargs["device_map"] = None
            model_init_kwargs["low_cpu_mem_usage"] = False

        torch.cuda.empty_cache()

        merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)

        cls._set_seqlen(model)
        model.eval()

        return cls(model, False)

    # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
    @classmethod
    def from_quantized(
        cls,
        model_name_or_path: Optional[str],
        model_basename: Optional[str] = None,
        device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
        max_memory: Optional[dict] = None,
        device: Optional[Union[str, int]] = None,
        low_cpu_mem_usage: bool = False,
        torch_dtype: Optional[torch.dtype] = None,
        use_safetensors: bool = False,
        trust_remote_code: bool = False,
        **kwargs,
    ):
        """load quantized model from local disk

        Looks for ``<model_basename>.safetensors`` (``use_safetensors=True``) or
        ``<model_basename>.bin`` / ``.pt`` in a local directory or on the Hub, builds the
        model skeleton, converts it with ``create_quantized_model`` and loads the weights.

        NOTE(review): ``device_map``, ``max_memory`` and ``device`` are currently ignored.

        Raises:
            ValueError: if ``model_basename`` is not given.
            ImportError: if accelerate is not installed.
            TypeError: if the model type is not in ``SUPPORTED_MODELS``.
            FileNotFoundError: if no checkpoint file matching the basename is found.
        """
        if model_basename is None:
            raise ValueError("model_basename is required to locate the quantized checkpoint file.")
        cls._require_accelerate("load a quantized model")

        # Parameters related to loading from Hugging Face Hub
        cached_file_kwargs = cls._pop_hub_kwargs(kwargs)
        cached_file_kwargs["_raise_exceptions_for_missing_entries"] = False
        cached_file_kwargs["_commit_hash"] = kwargs.pop("_commit_hash", None)

        # == step1: prepare configs and file names == #
        config = AutoConfig.from_pretrained(
            model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs
        )
        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")

        extensions = [".safetensors"] if use_safetensors else [".bin", ".pt"]

        model_name_or_path = str(model_name_or_path)
        resolved_archive_file = None
        if isdir(model_name_or_path):
            model_save_name = join(model_name_or_path, model_basename)
            for ext in extensions:
                if isfile(model_save_name + ext):
                    resolved_archive_file = model_save_name + ext
                    break
        else:  # remote
            for ext in extensions:
                resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs)
                if resolved_archive_file is not None:
                    break

        if resolved_archive_file is None:  # Could not find a model file to use
            raise FileNotFoundError(f"Could not find model in {model_name_or_path}")

        # == step2: convert model to quantized-model (replace Linear) == #
        def skip(*args, **kwargs):
            pass

        # NOTE(review): like from_pretrained, these global patches (and the transformers
        # flag below) are never restored after loading.
        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip

        transformers.modeling_utils._init_weights = False

        init_contexts = [no_init_weights()]
        if low_cpu_mem_usage:
            init_contexts.append(accelerate.init_empty_weights(include_buffers=True))

        with ContextManagers(init_contexts):
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype
            )
            cls.create_quantized_model(model)
            model.tie_weights()

        # == step3: load checkpoint to quantized-model == #
        accelerate.utils.modeling.load_checkpoint_in_model(
            model, checkpoint=resolved_archive_file, offload_state_dict=True, offload_buffers=True
        )

        # == step4: set seqlen == #
        cls._set_seqlen(model)

        return cls(
            model,
            True,
        )

    def __getattr__(self, item):
        """Fall back to the wrapped model for attributes nn.Module cannot resolve."""
        try:
            return super().__getattr__(item)
        except AttributeError:
            # Without this guard a missing ``model`` would recurse forever.
            if item == "model":
                raise
            return getattr(self.model, item)


__all__ = ["BaseSmoothForCausalLM"]