from .bf16 import BF16MixedPrecision
from .fp8 import FP8MixedPrecision
from .fp16_apex import FP16ApexMixedPrecision
from .fp16_torch import FP16TorchMixedPrecision
from .mixed_precision_base import MixedPrecision

# Public API of this package. Every entry must name an object that is actually
# bound in this module; otherwise ``from <package> import *`` raises AttributeError.
__all__ = [
    'MixedPrecision', 'mixed_precision_factory', 'FP16ApexMixedPrecision', 'FP16TorchMixedPrecision',
    'BF16MixedPrecision', 'FP8MixedPrecision'
]

# Registry used by ``mixed_precision_factory``: maps a user-facing type string
# to the class that implements it. Insertion order is significant because the
# factory lists these keys, in order, in its "not supported" error message.
_mixed_precision_mapping = dict(
    fp16=FP16TorchMixedPrecision,
    fp16_apex=FP16ApexMixedPrecision,
    bf16=BF16MixedPrecision,
    fp8=FP8MixedPrecision,
)


def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
    """
    Factory method to create a mixed precision object.

    Args:
        mixed_precision_type (str): mixed precision type, one of 'fp16', 'fp16_apex', 'bf16' and 'fp8'.
            ``None`` is not accepted here and is rejected like any other unknown type.

    Returns:
        MixedPrecision: a new instance (constructed with no arguments) of the class
        registered under ``mixed_precision_type``.

    Raises:
        ValueError: if ``mixed_precision_type`` is not one of the supported types.
    """
    # Single lookup instead of a membership test followed by indexing; the
    # registry values are classes, so ``None`` unambiguously means "missing".
    mixed_precision_cls = _mixed_precision_mapping.get(mixed_precision_type)
    if mixed_precision_cls is None:
        raise ValueError(
            f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}'
        )
    return mixed_precision_cls()