|
|
|
from .bf16 import BF16MixedPrecision
|
|
|
|
from .fp8 import FP8MixedPrecision
|
|
|
|
from .fp16_apex import FP16ApexMixedPrecision
|
|
|
|
from .fp16_naive import FP16NaiveMixedPrecision
|
|
|
|
from .fp16_torch import FP16TorchMixedPrecision
|
|
|
|
from .mixed_precision_base import MixedPrecision
|
|
|
|
|
|
|
|
# Public API of this package. Every entry must be a name that is actually bound
# in this module; a missing name makes ``from <package> import *`` raise
# AttributeError.
__all__ = [
    'MixedPrecision',
    'mixed_precision_factory',
    'FP16ApexMixedPrecision',
    'FP16TorchMixedPrecision',
    'FP16NaiveMixedPrecision',
    'BF16MixedPrecision',
    'FP8MixedPrecision',
]
|
|
|
|
|
|
|
|
# Registry used by ``mixed_precision_factory``: maps the user-facing type string
# to the class that is instantiated for it.
_mixed_precision_mapping = {
    'fp16': FP16TorchMixedPrecision,
    'fp16_apex': FP16ApexMixedPrecision,
    'fp16_naive': FP16NaiveMixedPrecision,
    'bf16': BF16MixedPrecision,
    'fp8': FP8MixedPrecision,
}
|
|
|
|
|
|
|
|
|
|
|
|
def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
    """
    Factory method to create mixed precision object

    Args:
        mixed_precision_type (str): mixed precision type. Must be one of the keys of
            ``_mixed_precision_mapping``: 'fp16', 'fp16_apex', 'fp16_naive', 'bf16' or 'fp8'.

    Returns:
        MixedPrecision: a new object created by calling, with no arguments, the class
            registered under ``mixed_precision_type``.

    Raises:
        ValueError: if ``mixed_precision_type`` is not a key of ``_mixed_precision_mapping``
            (this includes ``None``, which has no registered class).
    """
    # Reject unknown types first so the happy path reads as a single return.
    if mixed_precision_type not in _mixed_precision_mapping:
        raise ValueError(
            f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}'
        )

    # Look up the registered class and instantiate it; constructor errors propagate unchanged.
    return _mixed_precision_mapping[mixed_precision_type]()
|