diff --git a/colossalai/quantization/bnb.py b/colossalai/quantization/bnb.py index fa214116a..3601ef62b 100644 --- a/colossalai/quantization/bnb.py +++ b/colossalai/quantization/bnb.py @@ -1,17 +1,25 @@ # adapted from Hugging Face accelerate/utils/bnb.py accelerate/utils/modeling.py +import importlib.metadata import logging import torch import torch.nn as nn +from packaging.version import Version from .bnb_config import BnbQuantizationConfig try: import bitsandbytes as bnb - IS_4BIT_BNB_AVAILABLE = bnb.__version__ >= "0.39.0" - IS_8BIT_BNB_AVAILABLE = bnb.__version__ >= "0.37.2" + try: + # in case lower version of bitsandbytes does not have __version__ attribute + BNB_VERSION = Version(bnb.__version__) + except AttributeError: + BNB_VERSION = Version(importlib.metadata.version("bitsandbytes")) + + IS_4BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.39.0") + IS_8BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.37.2") except ImportError: pass