# Mirror of https://github.com/hpcaitech/ColossalAI (Python source, 42 lines, 1.3 KiB).
from .bf16 import BF16MixedPrecision
|
|
from .fp8 import FP8MixedPrecision
|
|
from .fp16_apex import FP16ApexMixedPrecision
|
|
from .fp16_naive import FP16NaiveMixedPrecision
|
|
from .fp16_torch import FP16TorchMixedPrecision
|
|
from .mixed_precision_base import MixedPrecision
|
|
|
|
# Public API of this package. Every entry must name an attribute that actually
# exists in this module: ``from <package> import *`` raises AttributeError on
# any name listed here that is not defined/imported above.
__all__ = [
    "MixedPrecision",
    "mixed_precision_factory",
    "FP16ApexMixedPrecision",
    "FP16TorchMixedPrecision",
    "BF16MixedPrecision",
    "FP8MixedPrecision",
    "FP16NaiveMixedPrecision",
]
|
|
|
|
# Registry of the strings accepted by ``mixed_precision_factory`` and the
# MixedPrecision subclass each one resolves to. Insertion order matters only
# for the list of supported types shown in the factory's error message.
# Note that plain "fp16" resolves to FP16TorchMixedPrecision.
_mixed_precision_mapping = dict(
    fp16=FP16TorchMixedPrecision,
    fp16_apex=FP16ApexMixedPrecision,
    fp16_naive=FP16NaiveMixedPrecision,
    bf16=BF16MixedPrecision,
    fp8=FP8MixedPrecision,
)
|
|
|
|
|
|
def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
    """
    Factory function to create a mixed precision object.

    Args:
        mixed_precision_type (str): mixed precision type, one of 'fp16', 'fp16_apex',
            'fp16_naive', 'bf16' and 'fp8'. ``None`` is not a registered type and is
            rejected with ``ValueError`` like any other unknown value.

    Returns:
        MixedPrecision: a new instance of the class registered for
        ``mixed_precision_type``, constructed with no arguments.

    Raises:
        ValueError: if ``mixed_precision_type`` is not one of the registered types.
    """
    # Single dict lookup (EAFP); keep the try body minimal so that a KeyError
    # raised inside a constructor is never misreported as "unsupported type".
    try:
        mixed_precision_cls = _mixed_precision_mapping[mixed_precision_type]
    except KeyError:
        # `from None` hides the internal KeyError so the traceback shows only the
        # user-facing ValueError, as the original membership-test version did.
        raise ValueError(
            f"Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}"
        ) from None
    return mixed_precision_cls()
|