From 7afbc81d6292f1a44cb5c2f89571c6c1c6d74691 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 4 Jul 2024 11:33:23 +0800 Subject: [PATCH] [quant] fix bitsandbytes version check (#5882) * [quant] fix bitsandbytes version check * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/quantization/bnb.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/colossalai/quantization/bnb.py b/colossalai/quantization/bnb.py index fa214116a..3601ef62b 100644 --- a/colossalai/quantization/bnb.py +++ b/colossalai/quantization/bnb.py @@ -1,17 +1,25 @@ # adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py +import importlib.metadata import logging import torch import torch.nn as nn +from packaging.version import Version from .bnb_config import BnbQuantizationConfig try: import bitsandbytes as bnb - IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0" - IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2" + try: + # in case lower version of bitsandbytes does not have __version__ attribute + BNB_VERSION = Version(bnb.__version__) + except AttributeError: + BNB_VERSION = Version(importlib.metadata.version("bitsandbytes")) + + IS_4BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.39.0") + IS_8BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.37.2") except ImportError: pass