"""Mixed-precision strategy classes and a factory that builds one from a string key."""

from .bf16 import BF16MixedPrecision
from .fp8 import FP8MixedPrecision
from .fp16_apex import FP16ApexMixedPrecision
from .fp16_torch import FP16TorchMixedPrecision
from .mixed_precision_base import MixedPrecision

# Every entry must name something actually bound in this module; a stale or
# misspelled entry makes `from <package> import *` raise AttributeError.
__all__ = [
    'MixedPrecision',
    'mixed_precision_factory',
    'FP16ApexMixedPrecision',
    'FP16TorchMixedPrecision',
    'BF16MixedPrecision',
    'FP8MixedPrecision',
]

# String key accepted by `mixed_precision_factory` -> class it instantiates.
# Plain 'fp16' selects FP16TorchMixedPrecision; FP16ApexMixedPrecision must be
# requested explicitly via 'fp16_apex'.
_mixed_precision_mapping = {
    'fp16': FP16TorchMixedPrecision,
    'fp16_apex': FP16ApexMixedPrecision,
    'bf16': BF16MixedPrecision,
    'fp8': FP8MixedPrecision,
}


def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
    """
    Factory method to create a mixed precision object.

    Args:
        mixed_precision_type (str): mixed precision type, one of 'fp16',
            'fp16_apex', 'bf16', and 'fp8'.

    Returns:
        MixedPrecision: a new instance of the class registered for
        ``mixed_precision_type``, constructed with no arguments.

    Raises:
        ValueError: if ``mixed_precision_type`` is not one of the supported
            keys. ``None`` is not a supported key and also raises.
    """
    try:
        mixed_precision_cls = _mixed_precision_mapping[mixed_precision_type]
    except KeyError:
        # `from None` hides the internal KeyError; callers only see ValueError.
        raise ValueError(
            f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}'
        ) from None
    # Instantiate outside the `try` so a KeyError raised inside a constructor
    # is not misreported as an unsupported type.
    return mixed_precision_cls()