# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py

import importlib.metadata
import logging

import torch
import torch.nn as nn
from packaging.version import Version

from .bnb_config import BnbQuantizationConfig

try:
    import bitsandbytes as bnb

    try:
        # older versions of bitsandbytes may not expose a __version__ attribute
        BNB_VERSION = Version(bnb.__version__)
    except AttributeError:
        BNB_VERSION = Version(importlib.metadata.version("bitsandbytes"))

    IS_4BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.39.0")
    IS_8BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.37.2")
except ImportError:
    # bitsandbytes is not installed (or failed to import): mark both paths unavailable so the
    # checks in `quantize_model` raise a clear error instead of a NameError
    IS_4BIT_BNB_AVAILABLE = False
    IS_8BIT_BNB_AVAILABLE = False


logger = logging.getLogger(__name__)


def quantize_model(
    model: torch.nn.Module,
    bnb_quantization_config: BnbQuantizationConfig,
):
    """
    Quantize the input model with the parameters in `bnb_quantization_config`, then move it to the GPU.

    Args:
        model (`torch.nn.Module`):
            Input model. The model must already be loaded.
        bnb_quantization_config (`BnbQuantizationConfig`):
            The bitsandbytes quantization parameters.

    Returns:
        `torch.nn.Module`: The quantized model.
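
    Example:

    A minimal usage sketch; `MyModel` is a placeholder for any `torch.nn.Module` with `nn.Linear`
    layers, the `load_in_8bit` constructor kwarg is an assumption about `BnbQuantizationConfig`,
    and a CUDA device with a compatible `bitsandbytes` build is required:

    ```py
    >>> config = BnbQuantizationConfig(load_in_8bit=True)  # assumed kwargs
    >>> model = quantize_model(MyModel(), config)
    ```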
    """

    load_in_4bit = bnb_quantization_config.load_in_4bit
    load_in_8bit = bnb_quantization_config.load_in_8bit

    if load_in_8bit and not IS_8BIT_BNB_AVAILABLE:
        raise ImportError(
            "You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
            " make sure you have the latest version of `bitsandbytes` installed."
        )
    if load_in_4bit and not IS_4BIT_BNB_AVAILABLE:
        raise ImportError(
            "You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
            " make sure you have the latest version of `bitsandbytes` installed."
        )

    # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
    if bnb_quantization_config.skip_modules is None:
        bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)

    modules_to_not_convert = bnb_quantization_config.skip_modules

    # We add the modules we want to keep in full precision
    if bnb_quantization_config.keep_in_fp32_modules is None:
        bnb_quantization_config.keep_in_fp32_modules = []
    keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules

    # compatibility with peft
    model.is_loaded_in_4bit = load_in_4bit
    model.is_loaded_in_8bit = load_in_8bit

    # record the original device; a CUDA device is required for quantization
    model_device = next(model.parameters()).device

    model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)

    # cast params to the right dtype. NOTE: `Tensor.to` is out-of-place, so the result must be
    # written back through `param.data` for the cast to take effect.
    dtype = bnb_quantization_config.torch_dtype
    for name, param in model.named_parameters():
        if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
            param.data = param.data.to(torch.float32)
        elif torch.is_floating_point(param):
            param.data = param.data.to(dtype)

    if model_device.type == "cuda":
        # moving the weights to CUDA is what triggers the actual quantization inside bitsandbytes
        model.cuda(torch.cuda.current_device())
        torch.cuda.empty_cache()
    elif torch.cuda.is_available():
        model.to(torch.cuda.current_device())
        logger.info(
            f"The model device type is {model_device.type}. However, cuda is needed for quantization."
            " We move the model to cuda."
        )
    else:
        raise RuntimeError("No GPU found. A GPU is needed for quantization.")
    return model


def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):
    """
    A helper function that replaces all `torch.nn.Linear` modules with `bnb.nn.Linear8bitLt` or `bnb.nn.Linear4bit`
    modules from the `bitsandbytes` library. The function runs recursively, replacing every `torch.nn.Linear` module
    it finds.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module`, as the function is run recursively.
        bnb_quantization_config (`BnbQuantizationConfig`):
            The bitsandbytes quantization parameters.
        modules_to_not_convert (`List[str]`):
            Names of the modules to not convert. In practice we keep the `lm_head` in full precision for
            numerical stability reasons.
        current_key_name (`List[str]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key (or
            part of it) is in the list of modules to not convert.
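
    Example:

    A minimal sketch of calling this helper directly (it is normally invoked through `quantize_model`);
    the `load_in_8bit` constructor kwarg is an assumption about `BnbQuantizationConfig`:

    ```py
    >>> import torch.nn as nn
    >>> config = BnbQuantizationConfig(load_in_8bit=True)  # assumed kwargs
    >>> model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
    >>> model = replace_with_bnb_layers(model, config, modules_to_not_convert=["2"])  # keep the last linear
    ```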
    """

    if modules_to_not_convert is None:
        modules_to_not_convert = []

    model, has_been_replaced = _replace_with_bnb_layers(
        model, bnb_quantization_config, modules_to_not_convert, current_key_name
    )
    if not has_been_replaced:
        logger.warning(
            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
            " This can happen for some architectures, such as gpt2, that use Conv1D instead of Linear layers."
            " Please double check your model architecture, or submit an issue on GitHub if you think this is"
            " a bug."
        )
    return model


def _replace_with_bnb_layers(
    model,
    bnb_quantization_config,
    modules_to_not_convert=None,
    current_key_name=None,
):
    """
    Private method that wraps the recursion for module replacement.

    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
    """
    has_been_replaced = False
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)
        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Skip conversion if the module's full dotted name matches an entry in
            # `modules_to_not_convert`, either exactly or as a dotted sub-path
            current_key_name_str = ".".join(current_key_name)
            proceed = True
            for key in modules_to_not_convert:
                if (
                    (key in current_key_name_str) and (key + "." in current_key_name_str)
                ) or key == current_key_name_str:
                    proceed = False
                    break
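            # Illustrative matching behavior (module names here are hypothetical): with
            # modules_to_not_convert=["lm_head"], the exact name "lm_head" and dotted paths that
            # contain "lm_head." (e.g. "transformer.lm_head.dense") are skipped, while a name like
            # "my_lm_head_alt" is still converted since "lm_head" never appears followed by a dot.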
            if proceed:
                # Load the bnb module with empty weights and replace the `nn.Linear` module
                if bnb_quantization_config.load_in_8bit:
                    bnb_module = bnb.nn.Linear8bitLt(
                        module.in_features,
                        module.out_features,
                        module.bias is not None,
                        has_fp16_weights=False,
                        threshold=bnb_quantization_config.llm_int8_threshold,
                    )
                elif bnb_quantization_config.load_in_4bit:
                    bnb_module = bnb.nn.Linear4bit(
                        module.in_features,
                        module.out_features,
                        module.bias is not None,
                        bnb_quantization_config.bnb_4bit_compute_dtype,
                        compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
                        quant_type=bnb_quantization_config.bnb_4bit_quant_type,
                    )
                else:
                    raise ValueError("load_in_8bit and load_in_4bit can't both be False")
                bnb_module.weight.data = module.weight.data
                bnb_module.weight.skip_zero_check = True
                if module.bias is not None:
                    bnb_module.bias.data = module.bias.data
                    bnb_module.bias.skip_zero_check = True
                bnb_module.requires_grad_(False)
                setattr(model, name, bnb_module)
                has_been_replaced = True
        if len(list(module.children())) > 0:
            _, _has_been_replaced = _replace_with_bnb_layers(
                module, bnb_quantization_config, modules_to_not_convert, current_key_name
            )
            has_been_replaced = has_been_replaced | _has_been_replaced
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced


def get_keys_to_not_convert(model):
    r"""
    A utility function to get the keys of the modules to keep in full precision, if any. For example, for CausalLM
    modules we may want to keep the lm_head in full precision for numerical stability reasons. For other
    architectures, we want to keep the tied weights of the model. The function will return a list of the keys of the
    modules to not convert to int8.

    Parameters:
        model (`torch.nn.Module`):
            Input model
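
    Example:

    A minimal sketch with a toy model whose output projection is tied to its input embedding
    (`TinyLM` is illustrative, not part of this library); `sorted(set(...))` is used because the
    returned list is unordered and may contain duplicates:

    ```py
    >>> import torch.nn as nn
    >>> class TinyLM(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.embed = nn.Embedding(10, 4)
    ...         self.lm_head = nn.Linear(4, 10, bias=False)
    ...         self.lm_head.weight = self.embed.weight  # tie the weights
    >>> sorted(set(get_keys_to_not_convert(TinyLM())))
    ['embed', 'lm_head']
    ```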
    """
    # The original accelerate implementation deepcopies the model under `init_empty_weights`
    # (a zero-cost copy); here we inspect the model directly.
    tied_model = model

    tied_params = find_tied_parameters(tied_model)
    # For compatibility with Accelerate < 0.18
    if isinstance(tied_params, dict):
        tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
    else:
        tied_keys = sum(tied_params, [])
    has_tied_params = len(tied_keys) > 0

    # Check if it is a base model
    is_base_model = False
    if hasattr(model, "base_model_prefix"):
        is_base_model = not hasattr(model, model.base_model_prefix)

    # Ignore this for base models (BertModel, GPT2Model, etc.)
    if (not has_tied_params) and is_base_model:
        return []

    # otherwise, the model has an attached head
    list_modules = list(model.named_children())
    list_last_module = [list_modules[-1][0]]

    # add the last module together with the tied weights
    difference = set(list_last_module) - set(tied_keys)
    list_untouched = list(set(tied_keys)) + list(difference)

    # strip ".weight" and ".bias" from the keys
    names_to_remove = [".weight", ".bias"]
    filtered_module_names = []
    for name in list_untouched:
        for name_to_remove in names_to_remove:
            if name_to_remove in name:
                name = name.replace(name_to_remove, "")
        filtered_module_names.append(name)

    return filtered_module_names


def find_tied_parameters(model: nn.Module, **kwargs):
    """
    Find the tied parameters in a given model.

    <Tip warning={true}>

    The signature accepts keyword arguments, but they are for the recursive part of this function and you should
    ignore them.

    </Tip>

    Args:
        model (`torch.nn.Module`): The model to inspect.

    Returns:
        List[List[str]]: A list of lists of parameter names being all tied together.

    Example:

    ```py
    >>> from collections import OrderedDict
    >>> import torch.nn as nn

    >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
    >>> model.linear2.weight = model.linear1.weight
    >>> find_tied_parameters(model)
    [['linear1.weight', 'linear2.weight']]
    ```
    """
    # Initialize result and named_parameters before recursing.
    named_parameters = kwargs.get("named_parameters", None)
    prefix = kwargs.get("prefix", "")
    result = kwargs.get("result", {})

    if named_parameters is None:
        named_parameters = {n: p for n, p in model.named_parameters()}
    else:
        # A tied parameter will not be in the full `named_parameters` seen above but will be in the
        # `named_parameters` of the submodule it belongs to. So while recursing we track the names that are not in
        # the initial `named_parameters`.
        for name, parameter in model.named_parameters():
            full_name = name if prefix == "" else f"{prefix}.{name}"
            if full_name not in named_parameters:
                # When we find one, it has to be one of the existing parameters.
                for new_name, new_param in named_parameters.items():
                    if new_param is parameter:
                        if new_name not in result:
                            result[new_name] = []
                        result[new_name].append(full_name)

    # Once we have treated direct parameters, we move to the child modules.
    for name, child in model.named_children():
        child_name = name if prefix == "" else f"{prefix}.{name}"
        find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result)

    return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()])


class FindTiedParametersResult(list):
    """
    This is a subclass of `list` to handle backward compatibility for Transformers. Do not rely on the fact that
    this is not a plain list, or on the `values` method, as both will be removed in the future.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def values(self):
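        """
        Return the names in each tied group except the first one.

        Example (illustrative):

        ```py
        >>> FindTiedParametersResult([["linear1.weight", "linear2.weight"]]).values()
        ['linear2.weight']
        ```
        """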
        return sum([x[1:] for x in self], [])