ColossalAI/colossalai/booster/mixed_precision/__init__.py

42 lines
1.3 KiB
Python
Raw Normal View History

from .bf16 import BF16MixedPrecision
from .fp8 import FP8MixedPrecision
from .fp16_apex import FP16ApexMixedPrecision
from .fp16_naive import FP16NaiveMixedPrecision
from .fp16_torch import FP16TorchMixedPrecision
from .mixed_precision_base import MixedPrecision
# Public API of this package. Every entry must name an object that is actually
# bound in this module: a misspelled or stale entry makes
# ``from <this package> import *`` raise AttributeError.
__all__ = [
    "MixedPrecision",
    "mixed_precision_factory",
    "FP16ApexMixedPrecision",
    "FP16TorchMixedPrecision",
    "FP16NaiveMixedPrecision",
    "BF16MixedPrecision",
    "FP8MixedPrecision",
]
# Registry used by ``mixed_precision_factory``: maps each accepted precision
# string to the MixedPrecision subclass that gets instantiated for it.
# Plain "fp16" resolves to FP16TorchMixedPrecision; the apex and naive fp16
# variants must be requested by their explicit suffixed names.
_mixed_precision_mapping = dict(
    fp16=FP16TorchMixedPrecision,
    fp16_apex=FP16ApexMixedPrecision,
    fp16_naive=FP16NaiveMixedPrecision,
    bf16=BF16MixedPrecision,
    fp8=FP8MixedPrecision,
)
def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
    """
    Factory method to create a mixed precision object.

    Args:
        mixed_precision_type (str): mixed precision type. Supported values are the
            keys of ``_mixed_precision_mapping``: 'fp16', 'fp16_apex', 'fp16_naive',
            'bf16' and 'fp8'.

    Returns:
        MixedPrecision: a new instance of the class registered for
        ``mixed_precision_type``, constructed with no arguments.

    Raises:
        ValueError: if ``mixed_precision_type`` is not one of the supported types
            (this includes ``None``).
    """
    try:
        mixed_precision_cls = _mixed_precision_mapping[mixed_precision_type]
    except KeyError:
        # Translate the lookup failure into the documented ValueError; the
        # KeyError context adds nothing for the caller, so suppress it.
        raise ValueError(
            f"Mixed precision type {mixed_precision_type} is not supported, "
            f"support types include {list(_mixed_precision_mapping.keys())}"
        ) from None
    # Instantiate outside the try so a KeyError raised inside a constructor is
    # not misreported as an unsupported type.
    return mixed_precision_cls()