# adapted from Hugging Face accelerate/utils/bnb.py and accelerate/utils/modeling.py

import logging
import re

import torch
import torch.nn as nn

from .bnb_config import BnbQuantizationConfig


def _version_at_least(version, minimum):
    """Return True when the dotted `version` is numerically >= the dotted `minimum`.

    Only the leading numeric release components are compared, so a suffix such as
    `.post2` or `.dev0` is ignored. Comparing the raw strings would order
    "0.100.0" before "0.39.0", which is why this helper exists.
    """

    def release(text):
        numbers = []
        for piece in text.split("."):
            digits = re.match(r"\d+", piece)
            if digits is None:
                break
            numbers.append(int(digits.group()))
        return numbers

    have, need = release(version), release(minimum)
    # Pad to a common length so that "0.39" compares equal to "0.39.0".
    width = max(len(have), len(need))
    have += [0] * (width - len(have))
    need += [0] * (width - len(need))
    return have >= need


# Both flags default to False so that a missing `bitsandbytes` install leads to the
# explicit errors raised in `quantize_model` instead of a NameError.
IS_4BIT_BNB_AVAILABLE = False
IS_8BIT_BNB_AVAILABLE = False
try:
    import bitsandbytes as bnb
except ImportError:
    pass
else:
    IS_4BIT_BNB_AVAILABLE = _version_at_least(bnb.__version__, "0.39.0")
    IS_8BIT_BNB_AVAILABLE = _version_at_least(bnb.__version__, "0.37.2")


logger = logging.getLogger(__name__)


def quantize_model(
    model: torch.nn.Module,
    bnb_quantization_config: BnbQuantizationConfig,
):
    """
    This function will quantize the input loaded model with the associated config passed in
    `bnb_quantization_config`. We will quantize the model and put the model on the GPU.

    Note that `bnb_quantization_config.skip_modules` and `bnb_quantization_config.keep_in_fp32_modules` are filled
    in on the passed config object when they are `None`.

    Args:
        model (`torch.nn.Module`):
            Input model. The model already loaded
        bnb_quantization_config (`BnbQuantizationConfig`):
            The bitsandbytes quantization parameters

    Returns:
        `torch.nn.Module`: The quantized model

    Raises:
        ImportError: If 8bit quantization is requested but the installed `bitsandbytes` does not support it.
        ValueError: If 4bit quantization is requested but the installed `bitsandbytes` does not support it.
        RuntimeError: If no CUDA device is available.
    """
    load_in_4bit = bnb_quantization_config.load_in_4bit
    load_in_8bit = bnb_quantization_config.load_in_8bit

    if load_in_8bit and not IS_8BIT_BNB_AVAILABLE:
        raise ImportError(
            "You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
            " make sure you have the latest version of `bitsandbytes` installed."
        )
    if load_in_4bit and not IS_4BIT_BNB_AVAILABLE:
        raise ValueError(
            "You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
            " make sure you have the latest version of `bitsandbytes` installed."
        )

    # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
    if bnb_quantization_config.skip_modules is None:
        bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)
    modules_to_not_convert = bnb_quantization_config.skip_modules

    # We add the modules we want to keep in full precision
    if bnb_quantization_config.keep_in_fp32_modules is None:
        bnb_quantization_config.keep_in_fp32_modules = []
    keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules

    # compatibility with peft
    model.is_loaded_in_4bit = load_in_4bit
    model.is_loaded_in_8bit = load_in_8bit

    # Remember where the model lived before the layers are swapped; this decides how it is moved to cuda below.
    model_device = next(model.parameters()).device

    model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)

    # convert param to the right dtype
    # NOTE(review): `Tensor.to` returns a converted tensor instead of modifying its input, so the calls made on the
    # `state_dict` tensors below leave the model untouched. Only the `getattr` fallback can have an effect, and only
    # when the stripped name resolves to a direct attribute of `model`. Kept as-is to avoid changing numerics for
    # existing callers -- confirm the intended behaviour before "fixing" it.
    dtype = bnb_quantization_config.torch_dtype
    for name, param in model.state_dict().items():
        if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
            param.to(torch.float32)
            if param.dtype != torch.float32:
                name = name.replace(".weight", "").replace(".bias", "")
                param = getattr(model, name, None)
                if param is not None:
                    param.to(torch.float32)
        elif torch.is_floating_point(param):
            param.to(dtype)

    if model_device.type == "cuda":
        # The model already lives on a cuda device: re-issue the move to the current device and release cached
        # memory. (Presumably the bitsandbytes layers quantize their weights during this transfer -- confirm.)
        model.cuda(torch.cuda.current_device())
        torch.cuda.empty_cache()
    elif torch.cuda.is_available():
        model.to(torch.cuda.current_device())
        logger.info(
            "The model device type is %s. However, cuda is needed for quantization. We move the model to cuda.",
            model_device.type,
        )
    else:
        raise RuntimeError("No GPU found. A GPU is needed for quantization.")
    return model


def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):
    """
    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules or by `bnb.nn.Linear4bit`
    modules from the `bitsandbytes`library. The function will be run recursively and replace `torch.nn.Linear` modules.
    A warning is logged when no module was replaced.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        bnb_quantization_config (`BnbQuantizationConfig`):
            The bitsandbytes quantization parameters used to build the replacement layers.
        modules_to_not_convert (`List[str]`):
            Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for
            numerical stability reasons.
        current_key_name (`List[str]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key (part of
            it) is not in the list of modules to not convert.

    Returns:
        `torch.nn.Module`: The same `model` object, with its eligible linear layers replaced in place.
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = []

    model, has_been_replaced = _replace_with_bnb_layers(
        model, bnb_quantization_config, modules_to_not_convert, current_key_name
    )
    if not has_been_replaced:
        logger.warning(
            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
            " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )
    return model


def _replace_with_bnb_layers(
    model,
    bnb_quantization_config,
    modules_to_not_convert=None,
    current_key_name=None,
):
    """
    Private method that wraps the recursion for module replacement.

    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
    """
    # Normalise the optional arguments so that the membership tests below never hit `None`.
    if modules_to_not_convert is None:
        modules_to_not_convert = []
    if current_key_name is None:
        current_key_name = []

    has_been_replaced = False
    for name, module in model.named_children():
        current_key_name.append(name)
        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Skip the layer when its dotted path equals an excluded key or contains an excluded key followed by
            # a dot.
            current_key_name_str = ".".join(current_key_name)
            excluded = any(
                key == current_key_name_str or (key + ".") in current_key_name_str for key in modules_to_not_convert
            )
            if not excluded:
                # Build the bnb module and hand it the weights of the `nn.Linear` module it replaces
                if bnb_quantization_config.load_in_8bit:
                    bnb_module = bnb.nn.Linear8bitLt(
                        module.in_features,
                        module.out_features,
                        module.bias is not None,
                        has_fp16_weights=False,
                        threshold=bnb_quantization_config.llm_int8_threshold,
                    )
                elif bnb_quantization_config.load_in_4bit:
                    bnb_module = bnb.nn.Linear4bit(
                        module.in_features,
                        module.out_features,
                        module.bias is not None,
                        bnb_quantization_config.bnb_4bit_compute_dtype,
                        compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
                        quant_type=bnb_quantization_config.bnb_4bit_quant_type,
                    )
                else:
                    raise ValueError("load_in_8bit and load_in_4bit can't be both False")
                bnb_module.weight.data = module.weight.data
                bnb_module.weight.skip_zero_check = True
                if module.bias is not None:
                    bnb_module.bias.data = module.bias.data
                    bnb_module.bias.skip_zero_check = True
                bnb_module.requires_grad_(False)
                setattr(model, name, bnb_module)
                has_been_replaced = True
        if len(list(module.children())) > 0:
            _, _has_been_replaced = _replace_with_bnb_layers(
                module, bnb_quantization_config, modules_to_not_convert, current_key_name
            )
            has_been_replaced = has_been_replaced | _has_been_replaced
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced


def get_keys_to_not_convert(model):
    r"""
    An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
    we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
    to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
    int8.

    Parameters:
        model (`torch.nn.Module`):
            Input model

    Returns:
        `List[str]`: Tied parameter names and the name of the last child module (if any), with `.weight` and `.bias`
        removed. Empty for a base model without tied parameters.
    """
    tied_params = find_tied_parameters(model)
    # For compatibility with Accelerate < 0.18
    if isinstance(tied_params, dict):
        tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
    else:
        tied_keys = sum(tied_params, [])
    has_tied_params = len(tied_keys) > 0

    # Check if it is a base model
    is_base_model = False
    if hasattr(model, "base_model_prefix"):
        is_base_model = not hasattr(model, model.base_model_prefix)

    # Ignore this for base models (BertModel, GPT2Model, etc.)
    if (not has_tied_params) and is_base_model:
        return []

    # otherwise they have an attached head; a model without any child module simply has no "last module"
    list_modules = list(model.named_children())
    list_last_module = [list_modules[-1][0]] if list_modules else []

    # add last module together with tied weights (de-duplicated, in a deterministic order)
    list_untouched = list(dict.fromkeys(tied_keys))
    list_untouched += [name for name in list_last_module if name not in tied_keys]

    # remove ".weight" and ".bias" from the keys
    return [name.replace(".weight", "").replace(".bias", "") for name in list_untouched]


def find_tied_parameters(model: nn.Module, **kwargs):
    """
    Find the tied parameters in a given model.

    The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
    them.

    Args:
        model (`torch.nn.Module`): The model to inspect.

    Returns:
        List[List[str]]: A list of lists of parameter names being all tied together.

    Example:

    ```py
    >>> from collections import OrderedDict
    >>> import torch.nn as nn

    >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
    >>> model.linear2.weight = model.linear1.weight
    >>> find_tied_parameters(model)
    [['linear1.weight', 'linear2.weight']]
    ```
    """
    # Initialize result and named_parameters before recursing.
    named_parameters = kwargs.get("named_parameters", None)
    prefix = kwargs.get("prefix", "")
    result = kwargs.get("result", {})

    if named_parameters is None:
        # Top-level call: record the parameter names reported for the full model.
        named_parameters = {n: p for n, p in model.named_parameters()}
    else:
        # A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters`
        # of the submodule it belongs to. So while recursing we track the names that are not in the initial
        # `named_parameters`.
        for name, parameter in model.named_parameters():
            full_name = name if prefix == "" else f"{prefix}.{name}"
            if full_name not in named_parameters:
                # When we find one, it has to be one of the existing parameters.
                for new_name, new_param in named_parameters.items():
                    if new_param is parameter:
                        if new_name not in result:
                            result[new_name] = []
                        result[new_name].append(full_name)

    # Once we have treated direct parameters, we move to the child modules.
    for name, child in model.named_children():
        child_name = name if prefix == "" else f"{prefix}.{name}"
        find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result)

    return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()])


class FindTiedParametersResult(list):
    """
    This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not
    a list or on the `values` method as in the future this will be removed.
    """

    def values(self):
        """Return a flat list of every name except the first one of each tied group."""
        return sum([x[1:] for x in self], [])