# ColossalAI/colossalai/booster/plugin/__init__.py
# (original listing: 15 lines, 528 B, Python)

# Public entry point for the booster plugins.
#
# Every plugin listed in ``__all__`` below is always importable.
# ``TorchFSDPPlugin`` is only imported and exported when the installed torch
# is at least 1.12.0. On older torch the name is absent from this package and
# from ``__all__``.
import torch
from packaging import version

from .gemini_plugin import GeminiPlugin
from .hybrid_parallel_plugin import HybridParallelPlugin
from .low_level_zero_plugin import LowLevelZeroPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin

__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'HybridParallelPlugin']

# Gate the FSDP plugin on the torch version. The import is deferred into this
# branch so that older torch builds never load the module.
# NOTE(review): presumably 1.12.0 is the first release with the FSDP APIs that
# torch_fsdp_plugin relies on — confirm against that module.
if version.parse(torch.__version__) >= version.parse('1.12.0'):
    from .torch_fsdp_plugin import TorchFSDPPlugin

    __all__.append('TorchFSDPPlugin')