diff --git a/colossalai/gemini/__init__.py b/colossalai/gemini/__init__.py index 9c7407eb5..7a5a44ebb 100644 --- a/colossalai/gemini/__init__.py +++ b/colossalai/gemini/__init__.py @@ -1,8 +1,9 @@ -from .chunk import ChunkManager, TensorInfo, TensorState +from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration from .gemini_mgr import GeminiManager from .stateful_tensor_mgr import StatefulTensorMgr from .tensor_placement_policy import TensorPlacementPolicyFactory __all__ = [ - 'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager' + 'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', + 'search_chunk_configuration' ]