diff --git a/colossalai/gemini/__init__.py b/colossalai/gemini/__init__.py index a82640d67..9c7407eb5 100644 --- a/colossalai/gemini/__init__.py +++ b/colossalai/gemini/__init__.py @@ -1,6 +1,8 @@ -from .chunk import TensorInfo, TensorState +from .chunk import ChunkManager, TensorInfo, TensorState +from .gemini_mgr import GeminiManager from .stateful_tensor_mgr import StatefulTensorMgr from .tensor_placement_policy import TensorPlacementPolicyFactory -from .gemini_mgr import GeminiManager -__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState'] +__all__ = [ + 'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager' +]