from .gemini_plugin import GeminiPlugin from .hybrid_parallel_plugin import HybridParallelPlugin from .low_level_zero_plugin import LowLevelZeroPlugin from .plugin_base import Plugin from .torch_ddp_plugin import TorchDDPPlugin __all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"] import torch from packaging import version if version.parse(torch.__version__) >= version.parse("1.12.0"): from .torch_fsdp_plugin import TorchFSDPPlugin __all__.append("TorchFSDPPlugin")