diff --git a/LICENSE b/LICENSE index 47197afe6..f0b2ffa97 100644 --- a/LICENSE +++ b/LICENSE @@ -552,3 +552,18 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ---------------- LICENSE FOR Hugging Face accelerate ---------------- + + Copyright 2021 The HuggingFace Team + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/applications/Colossal-LLaMA/colossal_llama/dataset/loader.py b/applications/Colossal-LLaMA/colossal_llama/dataset/loader.py index 327651f4e..abe0fd51a 100644 --- a/applications/Colossal-LLaMA/colossal_llama/dataset/loader.py +++ b/applications/Colossal-LLaMA/colossal_llama/dataset/loader.py @@ -80,15 +80,19 @@ class DataCollatorForSupervisedDataset(object): # `List[torch.Tensor]` batch_input_ids = [ - torch.LongTensor(instance["input_ids"][: self.max_length]) - if len(instance["input_ids"]) > self.max_length - else torch.LongTensor(instance["input_ids"]) + ( + torch.LongTensor(instance["input_ids"][: self.max_length]) + if len(instance["input_ids"]) > self.max_length + else torch.LongTensor(instance["input_ids"]) + ) for instance in instances ] batch_labels = [ - torch.LongTensor(instance["labels"][: self.max_length]) - if len(instance["labels"]) > self.max_length - else torch.LongTensor(instance["labels"]) + ( + torch.LongTensor(instance["labels"][: self.max_length]) + if len(instance["labels"]) > self.max_length + else torch.LongTensor(instance["labels"]) + ) for instance in instances ] diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py index dcd7be9f4..37e4fcc80 100644 --- a/applications/Colossal-LLaMA/train.py +++ b/applications/Colossal-LLaMA/train.py @@ -253,9 +253,11 @@ def main() -> None: coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") optimizer = HybridAdam( - model_params=filter(lambda p: p.requires_grad, model.parameters()) - if args.freeze_non_embeds_params - else model.parameters(), + model_params=( + filter(lambda p: p.requires_grad, model.parameters()) + if args.freeze_non_embeds_params + else model.parameters() + ), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay, diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index c2a724084..56d8a0935 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -19,6 +19,7 @@ except ImportError: import colossalai.interface.pretrained as pretrained_utils from colossalai.checkpoint_io import GeneralCheckpointIO from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.quantization import BnbQuantizationConfig from .accelerator import Accelerator from .mixed_precision import MixedPrecision, mixed_precision_factory @@ -230,7 +231,12 @@ class Booster: return self.plugin.no_sync(model, optimizer) def enable_lora( - self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None + self, + model: nn.Module, + pretrained_dir: Optional[str] = None, + lora_config: "peft.LoraConfig" = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, + quantize=False, ) -> nn.Module: """ Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory. @@ -259,7 +265,20 @@ class Booster: assert ( pretrained_dir is not None ), "Please provide pretrained directory path if not passing in lora configuration." - return self.plugin.enable_lora(model, pretrained_dir, lora_config) + if quantize is True: + if bnb_quantization_config is not None: + warnings.warn( + "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk." + ) + else: + bnb_quantization_config = BnbQuantizationConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) + + return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config) def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None: """Load model from checkpoint. diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 6bc9ba0e7..be75bebac 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -28,6 +28,7 @@ from colossalai.checkpoint_io.utils import ( sharded_optimizer_loading_epilogue, ) from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper +from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.zero import LowLevelZeroOptimizer from .dp_plugin_base import DPPluginBase @@ -338,7 +339,11 @@ class LowLevelZeroPlugin(DPPluginBase): return True def enable_lora( - self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + self, + model: nn.Module, + pretrained_dir: Optional[str] = None, + lora_config: Optional[Dict] = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, ) -> nn.Module: from peft import PeftModel, get_peft_model @@ -346,6 +351,9 @@ class LowLevelZeroPlugin(DPPluginBase): self.lora_enabled = True warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + if bnb_quantization_config is not None: + model = quantize_model(model, bnb_quantization_config) + if pretrained_dir is None: peft_model = get_peft_model(model, lora_config) else: diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 9ba520de2..482cc4e98 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -9,6 +9,7 @@ from torch.utils.data import DataLoader from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.quantization import BnbQuantizationConfig, quantize_model from .dp_plugin_base import DPPluginBase @@ -237,10 +238,17 @@ class TorchDDPPlugin(DPPluginBase): return model.module.no_sync() def enable_lora( - self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + self, + model: nn.Module, + pretrained_dir: Optional[str] = None, + lora_config: Optional[Dict] = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, ) -> nn.Module: from peft import PeftModel, get_peft_model + if bnb_quantization_config is not None: + model = quantize_model(model, bnb_quantization_config) + assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model." if pretrained_dir is None: return get_peft_model(model, lora_config) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index 287853a86..c2b808155 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -165,7 +165,7 @@ Currently the stats below are calculated based on A100 (single GPU), and we calc ##### Llama | batch_size | 8 | 16 | 32 | -| :---------------------: | :----: | :----: | :----: | +|:-----------------------:|:------:|:------:|:------:| | hugging-face torch fp16 | 199.12 | 246.56 | 278.4 | | colossal-inference | 326.4 | 582.72 | 816.64 | @@ -174,7 +174,7 @@ Currently the stats below are calculated based on A100 (single GPU), and we calc #### Bloom | batch_size | 8 | 16 | 32 | -| :---------------------: | :----: | :----: | :----: | +|:-----------------------:|:------:|:------:|:------:| | hugging-face torch fp16 | 189.68 | 226.66 | 249.61 | | colossal-inference | 323.28 | 538.52 | 611.64 | @@ -187,40 +187,40 @@ We conducted multiple benchmark tests to evaluate the performance. We compared t #### A10 7b, fp16 -| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)| -| :-------------------------: | :---: | :---:| :---: | :---: | :---: | :---: | -| Pipeline Inference | 40.35 | 77.10| 139.03| 232.70| 257.81| OOM | -| Hugging Face | 41.43 | 65.30| 91.93 | 114.62| OOM | OOM | +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16) | +|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|:------:| +| Pipeline Inference | 40.35 | 77.10 | 139.03 | 232.70 | 257.81 | OOM | +| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM | OOM | ![ppllama7b](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a10-llama7b.png) #### A10 13b, fp16 -| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) | -| :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 | -| Hugging Face | 23.48 | 37.59 | 53.44 | OOM | +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(4) | +|:----------------------------:|:-----:|:-----:|:-----:|:-----:| +| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 | +| Hugging Face | 23.48 | 37.59 | 53.44 | OOM | ![ppllama13](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a10-llama13b.png) #### A800 7b, fp16 -| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | -| :---: | :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 | -| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 | +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | +|:----------------------------:|:-----:|:------:|:------:|:------:|:------:| +| Pipeline Inference | 57.97 | 110.13 | 213.33 | 389.86 | 670.12 | +| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 | ![ppllama7b_a800](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a800-llama7b.png) ### Quantization LLama -| batch_size | 8 | 16 | 32 | -| :---------------------: | :----: | :----: | :----: | -| auto-gptq | 199.20 | 232.56 | 253.26 | -| smooth-quant | 142.28 | 222.96 | 300.59 | -| colossal-gptq | 231.98 | 388.87 | 573.03 | +| batch_size | 8 | 16 | 32 | +|:-------------:|:------:|:------:|:------:| +| auto-gptq | 199.20 | 232.56 | 253.26 | +| smooth-quant | 142.28 | 222.96 | 300.59 | +| colossal-gptq | 231.98 | 388.87 | 573.03 | ![bloom](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-quant.png) diff --git a/colossalai/quantization/__init__.py b/colossalai/quantization/__init__.py new file mode 100644 index 000000000..e9707b479 --- /dev/null +++ b/colossalai/quantization/__init__.py @@ -0,0 +1,7 @@ +from .bnb import quantize_model +from .bnb_config import BnbQuantizationConfig + +__all__ = [ + "BnbQuantizationConfig", + "quantize_model", +] diff --git a/colossalai/quantization/bnb.py b/colossalai/quantization/bnb.py new file mode 100644 index 000000000..fa214116a --- /dev/null +++ b/colossalai/quantization/bnb.py @@ -0,0 +1,321 @@ +# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py + +import logging + +import torch +import torch.nn as nn + +from .bnb_config import BnbQuantizationConfig + +try: + import bitsandbytes as bnb + + IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0" + IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2" +except ImportError: + pass + + +logger = logging.getLogger(__name__) + + +def quantize_model( + model: torch.nn.Module, + bnb_quantization_config: BnbQuantizationConfig, +): + """ + This function will quantize the input loaded model with the associated config passed in `bnb_quantization_config`. + We will quantize the model and put the model on the GPU. + + Args: + model (`torch.nn.Module`): + Input model. The model already loaded + bnb_quantization_config (`BnbQuantizationConfig`): + The bitsandbytes quantization parameters + + Returns: + `torch.nn.Module`: The quantized model + """ + + load_in_4bit = bnb_quantization_config.load_in_4bit + load_in_8bit = bnb_quantization_config.load_in_8bit + + if load_in_8bit and not IS_8BIT_BNB_AVAILABLE: + raise ImportError( + "You have a version of `bitsandbytes` that is not compatible with 8bit quantization," + " make sure you have the latest version of `bitsandbytes` installed." + ) + if load_in_4bit and not IS_4BIT_BNB_AVAILABLE: + raise ValueError( + "You have a version of `bitsandbytes` that is not compatible with 4bit quantization," + "make sure you have the latest version of `bitsandbytes` installed." + ) + + # We keep some modules such as the lm_head in their original dtype for numerical stability reasons + if bnb_quantization_config.skip_modules is None: + bnb_quantization_config.skip_modules = get_keys_to_not_convert(model) + + modules_to_not_convert = bnb_quantization_config.skip_modules + + # We add the modules we want to keep in full precision + if bnb_quantization_config.keep_in_fp32_modules is None: + bnb_quantization_config.keep_in_fp32_modules = [] + keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules + + # compatibility with peft + model.is_loaded_in_4bit = load_in_4bit + model.is_loaded_in_8bit = load_in_8bit + + # assert model_device is cuda + model_device = next(model.parameters()).device + + model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert) + + # convert param to the right dtype + dtype = bnb_quantization_config.torch_dtype + for name, param in model.state_dict().items(): + if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules): + param.to(torch.float32) + if param.dtype != torch.float32: + name = name.replace(".weight", "").replace(".bias", "") + param = getattr(model, name, None) + if param is not None: + param.to(torch.float32) + elif torch.is_floating_point(param): + param.to(dtype) + if model_device.type == "cuda": + # move everything to cpu in the first place because we can't do quantization if the weights are already on cuda + model.cuda(torch.cuda.current_device()) + torch.cuda.empty_cache() + elif torch.cuda.is_available(): + model.to(torch.cuda.current_device()) + logger.info( + f"The model device type is {model_device.type}. However, cuda is needed for quantization." + "We move the model to cuda." + ) + else: + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + return model + + +def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None): + """ + A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules or by `bnb.nn.Linear4bit` + modules from the `bitsandbytes`library. The function will be run recursively and replace `torch.nn.Linear` modules. + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + modules_to_not_convert (`List[str]`): + Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for + numerical stability reasons. + current_key_name (`List[str]`, *optional*): + An array to track the current key of the recursion. This is used to check whether the current key (part of + it) is not in the list of modules to not convert. + """ + + if modules_to_not_convert is None: + modules_to_not_convert = [] + + model, has_been_replaced = _replace_with_bnb_layers( + model, bnb_quantization_config, modules_to_not_convert, current_key_name + ) + if not has_been_replaced: + logger.warning( + "You are loading your model in 8bit or 4bit but no linear modules were found in your model." + " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + return model + + +def _replace_with_bnb_layers( + model, + bnb_quantization_config, + modules_to_not_convert=None, + current_key_name=None, +): + """ + Private method that wraps the recursion for module replacement. + + Returns the converted model and a boolean that indicates if the conversion has been successfull or not. + """ + # bitsandbytes will initialize CUDA on import, so it needs to be imported lazily + + has_been_replaced = False + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + proceed = True + for key in modules_to_not_convert: + if ( + (key in current_key_name_str) and (key + "." in current_key_name_str) + ) or key == current_key_name_str: + proceed = False + break + if proceed: + # Load bnb module with empty weight and replace ``nn.Linear` module + if bnb_quantization_config.load_in_8bit: + bnb_module = bnb.nn.Linear8bitLt( + module.in_features, + module.out_features, + module.bias is not None, + has_fp16_weights=False, + threshold=bnb_quantization_config.llm_int8_threshold, + ) + elif bnb_quantization_config.load_in_4bit: + bnb_module = bnb.nn.Linear4bit( + module.in_features, + module.out_features, + module.bias is not None, + bnb_quantization_config.bnb_4bit_compute_dtype, + compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant, + quant_type=bnb_quantization_config.bnb_4bit_quant_type, + ) + else: + raise ValueError("load_in_8bit and load_in_4bit can't be both False") + bnb_module.weight.data = module.weight.data + bnb_module.weight.skip_zero_check = True + if module.bias is not None: + bnb_module.bias.data = module.bias.data + bnb_module.bias.skip_zero_check = True + bnb_module.requires_grad_(False) + setattr(model, name, bnb_module) + has_been_replaced = True + if len(list(module.children())) > 0: + _, _has_been_replaced = _replace_with_bnb_layers( + module, bnb_quantization_config, modules_to_not_convert, current_key_name + ) + has_been_replaced = has_been_replaced | _has_been_replaced + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def get_keys_to_not_convert(model): + r""" + An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules + we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want + to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in + int8. + + Parameters: + model (`torch.nn.Module`): + Input model + """ + # Create a copy of the model + # with init_empty_weights(): + # tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` + tied_model = model + + tied_params = find_tied_parameters(tied_model) + # For compatibility with Accelerate < 0.18 + if isinstance(tied_params, dict): + tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys()) + else: + tied_keys = sum(tied_params, []) + has_tied_params = len(tied_keys) > 0 + + # Check if it is a base model + is_base_model = False + if hasattr(model, "base_model_prefix"): + is_base_model = not hasattr(model, model.base_model_prefix) + + # Ignore this for base models (BertModel, GPT2Model, etc.) + if (not has_tied_params) and is_base_model: + return [] + + # otherwise they have an attached head + list_modules = list(model.named_children()) + list_last_module = [list_modules[-1][0]] + + # add last module together with tied weights + intersection = set(list_last_module) - set(tied_keys) + list_untouched = list(set(tied_keys)) + list(intersection) + + # remove ".weight" from the keys + names_to_remove = [".weight", ".bias"] + filtered_module_names = [] + for name in list_untouched: + for name_to_remove in names_to_remove: + if name_to_remove in name: + name = name.replace(name_to_remove, "") + filtered_module_names.append(name) + + return filtered_module_names + + +def find_tied_parameters(model: nn.Module, **kwargs): + """ + Find the tied parameters in a given model. + + + + The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore + them. + + + + Args: + model (`torch.nn.Module`): The model to inspect. + + Returns: + List[List[str]]: A list of lists of parameter names being all tied together. + + Example: + + ```py + >>> from collections import OrderedDict + >>> import torch.nn as nn + + >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))])) + >>> model.linear2.weight = model.linear1.weight + >>> find_tied_parameters(model) + [['linear1.weight', 'linear2.weight']] + ``` + """ + # Initialize result and named_parameters before recursing. + named_parameters = kwargs.get("named_parameters", None) + prefix = kwargs.get("prefix", "") + result = kwargs.get("result", {}) + + if named_parameters is None: + named_parameters = {n: p for n, p in model.named_parameters()} + else: + # A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters` + # of the submodule it belongs to. So while recursing we track the names that are not in the initial + # `named_parameters`. + for name, parameter in model.named_parameters(): + full_name = name if prefix == "" else f"{prefix}.{name}" + if full_name not in named_parameters: + # When we find one, it has to be one of the existing parameters. + for new_name, new_param in named_parameters.items(): + if new_param is parameter: + if new_name not in result: + result[new_name] = [] + result[new_name].append(full_name) + + # Once we have treated direct parameters, we move to the child modules. + for name, child in model.named_children(): + child_name = name if prefix == "" else f"{prefix}.{name}" + find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result) + + return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()]) + + +class FindTiedParametersResult(list): + """ + This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not + a list or on the `values` method as in the future this will be removed. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def values(self): + return sum([x[1:] for x in self], []) diff --git a/colossalai/quantization/bnb_config.py b/colossalai/quantization/bnb_config.py new file mode 100644 index 000000000..98a30211b --- /dev/null +++ b/colossalai/quantization/bnb_config.py @@ -0,0 +1,113 @@ +# adapted from Hugging Face accelerate/utils/dataclasses.py + +import warnings +from dataclasses import dataclass, field +from typing import List + +import torch + + +@dataclass +class BnbQuantizationConfig: + """ + A plugin to enable BitsAndBytes 4bit and 8bit quantization + """ + + load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."}) + + llm_int8_threshold: float = field( + default=6.0, metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"} + ) + + load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."}) + + bnb_4bit_quant_type: str = field( + default="fp4", + metadata={ + "help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','np4'}." + }, + ) + + bnb_4bit_use_double_quant: bool = field( + default=False, + metadata={ + "help": "enable nested quantization where the quantization constants from the first quantization are quantized again." + }, + ) + + bnb_4bit_compute_dtype: bool = field( + default="fp16", + metadata={ + "help": "This sets the computational type which might be different than the input time. For example, inputs might be " + "fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}." + }, + ) + + torch_dtype: torch.dtype = field( + default=None, + metadata={ + "help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value" + "to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model " + }, + ) + + skip_modules: List[str] = field( + default=None, + metadata={ + "help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`." + }, + ) + + keep_in_fp32_modules: List[str] = field( + default=None, + metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."}, + ) + + def __post_init__(self): + if isinstance(self.bnb_4bit_compute_dtype, str): + if self.bnb_4bit_compute_dtype == "fp32": + self.bnb_4bit_compute_dtype = torch.float32 + elif self.bnb_4bit_compute_dtype == "fp16": + self.bnb_4bit_compute_dtype = torch.float16 + elif self.bnb_4bit_compute_dtype == "bf16": + self.bnb_4bit_compute_dtype = torch.bfloat16 + else: + raise ValueError( + f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}" + ) + elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype): + raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype") + + if self.skip_modules is not None and not isinstance(self.skip_modules, list): + raise ValueError("skip_modules must be a list of strings") + + if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list): + raise ValueError("keep_in_fp_32_modules must be a list of strings") + + if self.load_in_4bit: + self.target_dtype = "int4" + + if self.load_in_8bit: + self.target_dtype = torch.int8 + + if self.load_in_4bit and self.llm_int8_threshold != 6.0: + warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit") + + if isinstance(self.torch_dtype, str): + if self.torch_dtype == "fp32": + self.torch_dtype = torch.float32 + elif self.torch_dtype == "fp16": + self.torch_dtype = torch.float16 + elif self.torch_dtype == "bf16": + self.torch_dtype = torch.bfloat16 + else: + raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}") + + if self.load_in_8bit and self.torch_dtype is None: + self.torch_dtype = torch.float16 + + if self.load_in_4bit and self.torch_dtype is None: + self.torch_dtype = self.bnb_4bit_compute_dtype + + if not isinstance(self.torch_dtype, torch.dtype): + raise ValueError("torch_dtype must be a torch.dtype") diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index cbcf72697..345dfde73 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -235,9 +235,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for param_group in self.optim.param_groups: group_params = param_group["params"] for param in group_params: - assert ( - param.dtype == self._dtype - ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" + if not hasattr(param, "skip_zero_check") or param.skip_zero_check is False: + assert ( + param.dtype == self._dtype + ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" def _create_master_param_current_rank(self, param_list): # split each param evenly by world size diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 815b23fc7..8ab13c0ad 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -18,3 +18,4 @@ google protobuf transformers==4.36.2 peft>=0.7.1 +bitsandbytes>=0.39.0 diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py new file mode 100644 index 000000000..69febff38 --- /dev/null +++ b/tests/test_lora/test_lora.py @@ -0,0 +1,106 @@ +import copy +import os +from itertools import product + +import torch +from peft import LoraConfig +from torch import distributed as dist +from torch.optim import AdamW + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_checkpoint_io.utils import shared_tempdir + + +@clear_cache_before_run() +def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type): + model = model_fn() + lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) + + test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()] + test_configs = [ + { + "lora_config": lora_config, + "quantize": False, + }, + { + "lora_config": lora_config, + "quantize": True, + }, + ] + for plugin, test_config in product(test_plugins, test_configs): + # checkpoint loaded model + model_save = model_fn() + model_load = copy.deepcopy(model_save) + + optimizer = AdamW(model.parameters(), lr=0.001) + criterion = loss_fn + + booster = Booster(plugin=plugin) + model_save = booster.enable_lora(model_save, **test_config) + model_save, optimizer, criterion, _, _ = booster.boost(model_save, optimizer, criterion) + + with shared_tempdir() as tempdir: + lora_ckpt_path = os.path.join(tempdir, "ckpt") + booster.save_lora_as_pretrained(model_save, lora_ckpt_path) + dist.barrier() + + # The Lora checkpoint should be small in size + checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024) + assert checkpoint_size_mb < 1 + + model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path, **test_config) + model_load, _, _, _, _ = booster.boost(model_load) + + check_state_dict_equal(model_save.state_dict(), model_load.state_dict()) + + # test fwd bwd correctness + test_model = model_load + model_copy = copy.deepcopy(model_load) + + data = data_gen_fn() + data = { + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() + } + + output = test_model(**data) + output = output_transform_fn(output) + loss = criterion(output) + + booster.backward(loss, optimizer) + optimizer.clip_grad_by_norm(1.0) + optimizer.step() + + for (n1, p1), (n2, p2) in zip(test_model.named_parameters(), model_copy.named_parameters()): + if "lora_" in n1: + # lora modules require gradients, thus updated + assert p1.requires_grad + assert not torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3) + else: + if not p1.requires_grad: + torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3) + + +def run_lora_test(): + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + task_type = None + if name == "transformers_llama_for_casual_lm": + task_type = "CAUSAL_LM" + if name == "transformers_llama_for_sequence_classification": + task_type = "SEQ_CLS" + check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_lora_test() + + +@rerun_if_address_is_in_use() +def test_torch_ddp_lora(): + spawn(run_dist, 2) diff --git a/tests/test_lora/test_torch_ddp_lora.py b/tests/test_lora/test_torch_ddp_lora.py deleted file mode 100644 index b3169bf86..000000000 --- a/tests/test_lora/test_torch_ddp_lora.py +++ /dev/null @@ -1,108 +0,0 @@ -import copy -import os - -import torch -from peft import LoraConfig -from torch import distributed as dist -from torch.optim import AdamW - -import colossalai -from colossalai.booster import Booster -from colossalai.booster.plugin import TorchDDPPlugin -from colossalai.testing import ( - assert_equal, - assert_not_equal, - check_state_dict_equal, - clear_cache_before_run, - rerun_if_address_is_in_use, - spawn, -) -from tests.kit.model_zoo import model_zoo -from tests.test_checkpoint_io.utils import shared_tempdir - - -@clear_cache_before_run() -def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type): - model = model_fn() - lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) - - plugin = TorchDDPPlugin() - booster = Booster(plugin=plugin) - - model = booster.enable_lora(model, lora_config=lora_config) - model_copy = copy.deepcopy(model) - - optimizer = AdamW(model.parameters(), lr=0.001) - criterion = loss_fn - - model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - - data = data_gen_fn() - data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} - - output = model(**data) - output = output_transform_fn(output) - loss = criterion(output) - - booster.backward(loss, optimizer) - optimizer.clip_grad_by_norm(1.0) - optimizer.step() - - for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()): - if "lora_" in n1: - # lora modules require gradients, thus updated - assert p1.requires_grad - assert_not_equal(p1.to(p2.device), p2) - else: - if not p1.requires_grad: - assert_equal(p1.to(p2.device), p2) - - -@clear_cache_before_run() -def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type): - plugin = TorchDDPPlugin() - lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) - - model_save = model_fn() - model_load = copy.deepcopy(model_save) - - booster = Booster(plugin=plugin) - model_save = booster.enable_lora(model_save, lora_config=lora_config) - model_save, _, _, _, _ = booster.boost(model_save) - - with shared_tempdir() as tempdir: - lora_ckpt_path = os.path.join(tempdir, "ckpt") - booster.save_lora_as_pretrained(model_save, lora_ckpt_path) - dist.barrier() - - # The Lora checkpoint should be small in size - checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024) - assert checkpoint_size_mb < 1 - - model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path) - model_load, _, _, _, _ = booster.boost(model_load) - - check_state_dict_equal(model_save.state_dict(), model_load.state_dict()) - - -def run_lora_test(): - sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - task_type = None - if name == "transformers_llama_for_casual_lm": - task_type = "CAUSAL_LM" - if name == "transformers_llama_for_sequence_classification": - task_type = "SEQ_CLS" - check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type) - check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type) - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_lora_test() - - -@rerun_if_address_is_in_use() -def test_torch_ddp_lora(): - spawn(run_dist, 2)