Browse Source

[Feature] qlora support (#5586)

* [feature] qlora support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* qlora follow commit

* migrate qutization folder to colossalai/

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/5670/head
linsj20 7 months ago committed by Hongxin Liu
parent
commit
91fa553775
  1. 15
      LICENSE
  2. 16
      applications/Colossal-LLaMA/colossal_llama/dataset/loader.py
  3. 8
      applications/Colossal-LLaMA/train.py
  4. 23
      colossalai/booster/booster.py
  5. 10
      colossalai/booster/plugin/low_level_zero_plugin.py
  6. 10
      colossalai/booster/plugin/torch_ddp_plugin.py
  7. 38
      colossalai/inference/README.md
  8. 7
      colossalai/quantization/__init__.py
  9. 321
      colossalai/quantization/bnb.py
  10. 113
      colossalai/quantization/bnb_config.py
  11. 7
      colossalai/zero/low_level/low_level_optim.py
  12. 1
      requirements/requirements.txt
  13. 106
      tests/test_lora/test_lora.py
  14. 108
      tests/test_lora/test_torch_ddp_lora.py

15
LICENSE

@ -552,3 +552,18 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
---------------- LICENSE FOR Hugging Face accelerate ----------------
Copyright 2021 The HuggingFace Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

16
applications/Colossal-LLaMA/colossal_llama/dataset/loader.py

@ -80,15 +80,19 @@ class DataCollatorForSupervisedDataset(object):
# `List[torch.Tensor]`
batch_input_ids = [
torch.LongTensor(instance["input_ids"][: self.max_length])
if len(instance["input_ids"]) > self.max_length
else torch.LongTensor(instance["input_ids"])
(
torch.LongTensor(instance["input_ids"][: self.max_length])
if len(instance["input_ids"]) > self.max_length
else torch.LongTensor(instance["input_ids"])
)
for instance in instances
]
batch_labels = [
torch.LongTensor(instance["labels"][: self.max_length])
if len(instance["labels"]) > self.max_length
else torch.LongTensor(instance["labels"])
(
torch.LongTensor(instance["labels"][: self.max_length])
if len(instance["labels"]) > self.max_length
else torch.LongTensor(instance["labels"])
)
for instance in instances
]

8
applications/Colossal-LLaMA/train.py

@ -253,9 +253,11 @@ def main() -> None:
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
optimizer = HybridAdam(
model_params=filter(lambda p: p.requires_grad, model.parameters())
if args.freeze_non_embeds_params
else model.parameters(),
model_params=(
filter(lambda p: p.requires_grad, model.parameters())
if args.freeze_non_embeds_params
else model.parameters()
),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,

23
colossalai/booster/booster.py

@ -19,6 +19,7 @@ except ImportError:
import colossalai.interface.pretrained as pretrained_utils
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig
from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
@ -230,7 +231,12 @@ class Booster:
return self.plugin.no_sync(model, optimizer)
def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
self,
model: nn.Module,
pretrained_dir: Optional[str] = None,
lora_config: "peft.LoraConfig" = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
quantize=False,
) -> nn.Module:
"""
Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory.
@ -259,7 +265,20 @@ class Booster:
assert (
pretrained_dir is not None
), "Please provide pretrained directory path if not passing in lora configuration."
return self.plugin.enable_lora(model, pretrained_dir, lora_config)
if quantize is True:
if bnb_quantization_config is not None:
warnings.warn(
"User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
)
else:
bnb_quantization_config = BnbQuantizationConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
"""Load model from checkpoint.

10
colossalai/booster/plugin/low_level_zero_plugin.py

@ -28,6 +28,7 @@ from colossalai.checkpoint_io.utils import (
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.zero import LowLevelZeroOptimizer
from .dp_plugin_base import DPPluginBase
@ -338,7 +339,11 @@ class LowLevelZeroPlugin(DPPluginBase):
return True
def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
self,
model: nn.Module,
pretrained_dir: Optional[str] = None,
lora_config: Optional[Dict] = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> nn.Module:
from peft import PeftModel, get_peft_model
@ -346,6 +351,9 @@ class LowLevelZeroPlugin(DPPluginBase):
self.lora_enabled = True
warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
if bnb_quantization_config is not None:
model = quantize_model(model, bnb_quantization_config)
if pretrained_dir is None:
peft_model = get_peft_model(model, lora_config)
else:

10
colossalai/booster/plugin/torch_ddp_plugin.py

@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from .dp_plugin_base import DPPluginBase
@ -237,10 +238,17 @@ class TorchDDPPlugin(DPPluginBase):
return model.module.no_sync()
def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
self,
model: nn.Module,
pretrained_dir: Optional[str] = None,
lora_config: Optional[Dict] = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> nn.Module:
from peft import PeftModel, get_peft_model
if bnb_quantization_config is not None:
model = quantize_model(model, bnb_quantization_config)
assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model."
if pretrained_dir is None:
return get_peft_model(model, lora_config)

38
colossalai/inference/README.md

@ -165,7 +165,7 @@ Currently the stats below are calculated based on A100 (single GPU), and we calc
##### Llama
| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
|:-----------------------:|:------:|:------:|:------:|
| hugging-face torch fp16 | 199.12 | 246.56 | 278.4 |
| colossal-inference | 326.4 | 582.72 | 816.64 |
@ -174,7 +174,7 @@ Currently the stats below are calculated based on A100 (single GPU), and we calc
#### Bloom
| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
|:-----------------------:|:------:|:------:|:------:|
| hugging-face torch fp16 | 189.68 | 226.66 | 249.61 |
| colossal-inference | 323.28 | 538.52 | 611.64 |
@ -187,40 +187,40 @@ We conducted multiple benchmark tests to evaluate the performance. We compared t
#### A10 7b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
| :-------------------------: | :---: | :---:| :---: | :---: | :---: | :---: |
| Pipeline Inference | 40.35 | 77.10| 139.03| 232.70| 257.81| OOM |
| Hugging Face | 41.43 | 65.30| 91.93 | 114.62| OOM | OOM |
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16) |
|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|:------:|
| Pipeline Inference | 40.35 | 77.10 | 139.03 | 232.70 | 257.81 | OOM |
| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM | OOM |
![ppllama7b](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a10-llama7b.png)
#### A10 13b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(4) |
|:----------------------------:|:-----:|:-----:|:-----:|:-----:|
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
![ppllama13](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a10-llama13b.png)
#### A800 7b, fp16
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
|:----------------------------:|:-----:|:------:|:------:|:------:|:------:|
| Pipeline Inference | 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
![ppllama7b_a800](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a800-llama7b.png)
### Quantization LLama
| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
| auto-gptq | 199.20 | 232.56 | 253.26 |
| smooth-quant | 142.28 | 222.96 | 300.59 |
| colossal-gptq | 231.98 | 388.87 | 573.03 |
| batch_size | 8 | 16 | 32 |
|:-------------:|:------:|:------:|:------:|
| auto-gptq | 199.20 | 232.56 | 253.26 |
| smooth-quant | 142.28 | 222.96 | 300.59 |
| colossal-gptq | 231.98 | 388.87 | 573.03 |
![bloom](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-quant.png)

7
colossalai/quantization/__init__.py

@ -0,0 +1,7 @@
from .bnb import quantize_model
from .bnb_config import BnbQuantizationConfig
__all__ = [
"BnbQuantizationConfig",
"quantize_model",
]

321
colossalai/quantization/bnb.py

@ -0,0 +1,321 @@
# adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py
import logging
import torch
import torch.nn as nn
from .bnb_config import BnbQuantizationConfig
try:
import bitsandbytes as bnb
IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0"
IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2"
except ImportError:
pass
logger = logging.getLogger(__name__)
def quantize_model(
model: torch.nn.Module,
bnb_quantization_config: BnbQuantizationConfig,
):
"""
This function will quantize the input loaded model with the associated config passed in `bnb_quantization_config`.
We will quantize the model and put the model on the GPU.
Args:
model (`torch.nn.Module`):
Input model. The model already loaded
bnb_quantization_config (`BnbQuantizationConfig`):
The bitsandbytes quantization parameters
Returns:
`torch.nn.Module`: The quantized model
"""
load_in_4bit = bnb_quantization_config.load_in_4bit
load_in_8bit = bnb_quantization_config.load_in_8bit
if load_in_8bit and not IS_8BIT_BNB_AVAILABLE:
raise ImportError(
"You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
" make sure you have the latest version of `bitsandbytes` installed."
)
if load_in_4bit and not IS_4BIT_BNB_AVAILABLE:
raise ValueError(
"You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
"make sure you have the latest version of `bitsandbytes` installed."
)
# We keep some modules such as the lm_head in their original dtype for numerical stability reasons
if bnb_quantization_config.skip_modules is None:
bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)
modules_to_not_convert = bnb_quantization_config.skip_modules
# We add the modules we want to keep in full precision
if bnb_quantization_config.keep_in_fp32_modules is None:
bnb_quantization_config.keep_in_fp32_modules = []
keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules
# compatibility with peft
model.is_loaded_in_4bit = load_in_4bit
model.is_loaded_in_8bit = load_in_8bit
# assert model_device is cuda
model_device = next(model.parameters()).device
model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)
# convert param to the right dtype
dtype = bnb_quantization_config.torch_dtype
for name, param in model.state_dict().items():
if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
param.to(torch.float32)
if param.dtype != torch.float32:
name = name.replace(".weight", "").replace(".bias", "")
param = getattr(model, name, None)
if param is not None:
param.to(torch.float32)
elif torch.is_floating_point(param):
param.to(dtype)
if model_device.type == "cuda":
# move everything to cpu in the first place because we can't do quantization if the weights are already on cuda
model.cuda(torch.cuda.current_device())
torch.cuda.empty_cache()
elif torch.cuda.is_available():
model.to(torch.cuda.current_device())
logger.info(
f"The model device type is {model_device.type}. However, cuda is needed for quantization."
"We move the model to cuda."
)
else:
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
return model
def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):
"""
A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules or by `bnb.nn.Linear4bit`
modules from the `bitsandbytes`library. The function will be run recursively and replace `torch.nn.Linear` modules.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
modules_to_not_convert (`List[str]`):
Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for
numerical stability reasons.
current_key_name (`List[str]`, *optional*):
An array to track the current key of the recursion. This is used to check whether the current key (part of
it) is not in the list of modules to not convert.
"""
if modules_to_not_convert is None:
modules_to_not_convert = []
model, has_been_replaced = _replace_with_bnb_layers(
model, bnb_quantization_config, modules_to_not_convert, current_key_name
)
if not has_been_replaced:
logger.warning(
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
" this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)
return model
def _replace_with_bnb_layers(
model,
bnb_quantization_config,
modules_to_not_convert=None,
current_key_name=None,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
"""
# bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
has_been_replaced = False
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
current_key_name_str = ".".join(current_key_name)
proceed = True
for key in modules_to_not_convert:
if (
(key in current_key_name_str) and (key + "." in current_key_name_str)
) or key == current_key_name_str:
proceed = False
break
if proceed:
# Load bnb module with empty weight and replace ``nn.Linear` module
if bnb_quantization_config.load_in_8bit:
bnb_module = bnb.nn.Linear8bitLt(
module.in_features,
module.out_features,
module.bias is not None,
has_fp16_weights=False,
threshold=bnb_quantization_config.llm_int8_threshold,
)
elif bnb_quantization_config.load_in_4bit:
bnb_module = bnb.nn.Linear4bit(
module.in_features,
module.out_features,
module.bias is not None,
bnb_quantization_config.bnb_4bit_compute_dtype,
compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
quant_type=bnb_quantization_config.bnb_4bit_quant_type,
)
else:
raise ValueError("load_in_8bit and load_in_4bit can't be both False")
bnb_module.weight.data = module.weight.data
bnb_module.weight.skip_zero_check = True
if module.bias is not None:
bnb_module.bias.data = module.bias.data
bnb_module.bias.skip_zero_check = True
bnb_module.requires_grad_(False)
setattr(model, name, bnb_module)
has_been_replaced = True
if len(list(module.children())) > 0:
_, _has_been_replaced = _replace_with_bnb_layers(
module, bnb_quantization_config, modules_to_not_convert, current_key_name
)
has_been_replaced = has_been_replaced | _has_been_replaced
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
def get_keys_to_not_convert(model):
r"""
An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
int8.
Parameters:
model (`torch.nn.Module`):
Input model
"""
# Create a copy of the model
# with init_empty_weights():
# tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
tied_model = model
tied_params = find_tied_parameters(tied_model)
# For compatibility with Accelerate < 0.18
if isinstance(tied_params, dict):
tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
else:
tied_keys = sum(tied_params, [])
has_tied_params = len(tied_keys) > 0
# Check if it is a base model
is_base_model = False
if hasattr(model, "base_model_prefix"):
is_base_model = not hasattr(model, model.base_model_prefix)
# Ignore this for base models (BertModel, GPT2Model, etc.)
if (not has_tied_params) and is_base_model:
return []
# otherwise they have an attached head
list_modules = list(model.named_children())
list_last_module = [list_modules[-1][0]]
# add last module together with tied weights
intersection = set(list_last_module) - set(tied_keys)
list_untouched = list(set(tied_keys)) + list(intersection)
# remove ".weight" from the keys
names_to_remove = [".weight", ".bias"]
filtered_module_names = []
for name in list_untouched:
for name_to_remove in names_to_remove:
if name_to_remove in name:
name = name.replace(name_to_remove, "")
filtered_module_names.append(name)
return filtered_module_names
def find_tied_parameters(model: nn.Module, **kwargs):
"""
Find the tied parameters in a given model.
<Tip warning={true}>
The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
them.
</Tip>
Args:
model (`torch.nn.Module`): The model to inspect.
Returns:
List[List[str]]: A list of lists of parameter names being all tied together.
Example:
```py
>>> from collections import OrderedDict
>>> import torch.nn as nn
>>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
>>> model.linear2.weight = model.linear1.weight
>>> find_tied_parameters(model)
[['linear1.weight', 'linear2.weight']]
```
"""
# Initialize result and named_parameters before recursing.
named_parameters = kwargs.get("named_parameters", None)
prefix = kwargs.get("prefix", "")
result = kwargs.get("result", {})
if named_parameters is None:
named_parameters = {n: p for n, p in model.named_parameters()}
else:
# A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters`
# of the submodule it belongs to. So while recursing we track the names that are not in the initial
# `named_parameters`.
for name, parameter in model.named_parameters():
full_name = name if prefix == "" else f"{prefix}.{name}"
if full_name not in named_parameters:
# When we find one, it has to be one of the existing parameters.
for new_name, new_param in named_parameters.items():
if new_param is parameter:
if new_name not in result:
result[new_name] = []
result[new_name].append(full_name)
# Once we have treated direct parameters, we move to the child modules.
for name, child in model.named_children():
child_name = name if prefix == "" else f"{prefix}.{name}"
find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result)
return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()])
class FindTiedParametersResult(list):
"""
This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not
a list or on the `values` method as in the future this will be removed.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def values(self):
return sum([x[1:] for x in self], [])

113
colossalai/quantization/bnb_config.py

@ -0,0 +1,113 @@
# adapted from Hugging Face accelerate/utils/dataclasses.py
import warnings
from dataclasses import dataclass, field
from typing import List
import torch
@dataclass
class BnbQuantizationConfig:
"""
A plugin to enable BitsAndBytes 4bit and 8bit quantization
"""
load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."})
llm_int8_threshold: float = field(
default=6.0, metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"}
)
load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."})
bnb_4bit_quant_type: str = field(
default="fp4",
metadata={
"help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','np4'}."
},
)
bnb_4bit_use_double_quant: bool = field(
default=False,
metadata={
"help": "enable nested quantization where the quantization constants from the first quantization are quantized again."
},
)
bnb_4bit_compute_dtype: bool = field(
default="fp16",
metadata={
"help": "This sets the computational type which might be different than the input time. For example, inputs might be "
"fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}."
},
)
torch_dtype: torch.dtype = field(
default=None,
metadata={
"help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value"
"to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model "
},
)
skip_modules: List[str] = field(
default=None,
metadata={
"help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`."
},
)
keep_in_fp32_modules: List[str] = field(
default=None,
metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."},
)
def __post_init__(self):
if isinstance(self.bnb_4bit_compute_dtype, str):
if self.bnb_4bit_compute_dtype == "fp32":
self.bnb_4bit_compute_dtype = torch.float32
elif self.bnb_4bit_compute_dtype == "fp16":
self.bnb_4bit_compute_dtype = torch.float16
elif self.bnb_4bit_compute_dtype == "bf16":
self.bnb_4bit_compute_dtype = torch.bfloat16
else:
raise ValueError(
f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}"
)
elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
if self.skip_modules is not None and not isinstance(self.skip_modules, list):
raise ValueError("skip_modules must be a list of strings")
if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list):
raise ValueError("keep_in_fp_32_modules must be a list of strings")
if self.load_in_4bit:
self.target_dtype = "int4"
if self.load_in_8bit:
self.target_dtype = torch.int8
if self.load_in_4bit and self.llm_int8_threshold != 6.0:
warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit")
if isinstance(self.torch_dtype, str):
if self.torch_dtype == "fp32":
self.torch_dtype = torch.float32
elif self.torch_dtype == "fp16":
self.torch_dtype = torch.float16
elif self.torch_dtype == "bf16":
self.torch_dtype = torch.bfloat16
else:
raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}")
if self.load_in_8bit and self.torch_dtype is None:
self.torch_dtype = torch.float16
if self.load_in_4bit and self.torch_dtype is None:
self.torch_dtype = self.bnb_4bit_compute_dtype
if not isinstance(self.torch_dtype, torch.dtype):
raise ValueError("torch_dtype must be a torch.dtype")

7
colossalai/zero/low_level/low_level_optim.py

@ -235,9 +235,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for param_group in self.optim.param_groups:
group_params = param_group["params"]
for param in group_params:
assert (
param.dtype == self._dtype
), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
if not hasattr(param, "skip_zero_check") or param.skip_zero_check is False:
assert (
param.dtype == self._dtype
), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
def _create_master_param_current_rank(self, param_list):
# split each param evenly by world size

1
requirements/requirements.txt

@ -18,3 +18,4 @@ google
protobuf
transformers==4.36.2
peft>=0.7.1
bitsandbytes>=0.39.0

106
tests/test_lora/test_lora.py

@ -0,0 +1,106 @@
import copy
import os
from itertools import product
import torch
from peft import LoraConfig
from torch import distributed as dist
from torch.optim import AdamW
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir
@clear_cache_before_run()
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
model = model_fn()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
test_configs = [
{
"lora_config": lora_config,
"quantize": False,
},
{
"lora_config": lora_config,
"quantize": True,
},
]
for plugin, test_config in product(test_plugins, test_configs):
# checkpoint loaded model
model_save = model_fn()
model_load = copy.deepcopy(model_save)
optimizer = AdamW(model.parameters(), lr=0.001)
criterion = loss_fn
booster = Booster(plugin=plugin)
model_save = booster.enable_lora(model_save, **test_config)
model_save, optimizer, criterion, _, _ = booster.boost(model_save, optimizer, criterion)
with shared_tempdir() as tempdir:
lora_ckpt_path = os.path.join(tempdir, "ckpt")
booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
dist.barrier()
# The Lora checkpoint should be small in size
checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
assert checkpoint_size_mb < 1
model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path, **test_config)
model_load, _, _, _, _ = booster.boost(model_load)
check_state_dict_equal(model_save.state_dict(), model_load.state_dict())
# test fwd bwd correctness
test_model = model_load
model_copy = copy.deepcopy(model_load)
data = data_gen_fn()
data = {
k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
}
output = test_model(**data)
output = output_transform_fn(output)
loss = criterion(output)
booster.backward(loss, optimizer)
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
for (n1, p1), (n2, p2) in zip(test_model.named_parameters(), model_copy.named_parameters()):
if "lora_" in n1:
# lora modules require gradients, thus updated
assert p1.requires_grad
assert not torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)
else:
if not p1.requires_grad:
torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)
def run_lora_test():
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
task_type = None
if name == "transformers_llama_for_casual_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"
check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_lora_test()
@rerun_if_address_is_in_use()
def test_torch_ddp_lora():
spawn(run_dist, 2)

108
tests/test_lora/test_torch_ddp_lora.py

@ -1,108 +0,0 @@
import copy
import os
import torch
from peft import LoraConfig
from torch import distributed as dist
from torch.optim import AdamW
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.testing import (
assert_equal,
assert_not_equal,
check_state_dict_equal,
clear_cache_before_run,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir
@clear_cache_before_run()
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
model = model_fn()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = booster.enable_lora(model, lora_config=lora_config)
model_copy = copy.deepcopy(model)
optimizer = AdamW(model.parameters(), lr=0.001)
criterion = loss_fn
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
data = data_gen_fn()
data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
output = model(**data)
output = output_transform_fn(output)
loss = criterion(output)
booster.backward(loss, optimizer)
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()):
if "lora_" in n1:
# lora modules require gradients, thus updated
assert p1.requires_grad
assert_not_equal(p1.to(p2.device), p2)
else:
if not p1.requires_grad:
assert_equal(p1.to(p2.device), p2)
@clear_cache_before_run()
def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
plugin = TorchDDPPlugin()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
model_save = model_fn()
model_load = copy.deepcopy(model_save)
booster = Booster(plugin=plugin)
model_save = booster.enable_lora(model_save, lora_config=lora_config)
model_save, _, _, _, _ = booster.boost(model_save)
with shared_tempdir() as tempdir:
lora_ckpt_path = os.path.join(tempdir, "ckpt")
booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
dist.barrier()
# The Lora checkpoint should be small in size
checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
assert checkpoint_size_mb < 1
model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path)
model_load, _, _, _, _ = booster.boost(model_load)
check_state_dict_equal(model_save.state_dict(), model_load.state_dict())
def run_lora_test():
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
task_type = None
if name == "transformers_llama_for_casual_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"
check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_lora_test()
@rerun_if_address_is_in_use()
def test_torch_ddp_lora():
spawn(run_dist, 2)
Loading…
Cancel
Save