diff --git a/colossalai/gemini/__init__.py b/colossalai/gemini/__init__.py
new file mode 100644
index 000000000..8fe68cbb3
--- /dev/null
+++ b/colossalai/gemini/__init__.py
@@ -0,0 +1,4 @@
+from .stateful_tensor_mgr import StatefulTensorMgr
+from .tensor_placement_policy import TensorPlacementPolicyFactory
+
+__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory']
\ No newline at end of file
diff --git a/colossalai/zero/utils/stateful_tensor_mgr.py b/colossalai/gemini/stateful_tensor_mgr.py
similarity index 97%
rename from colossalai/zero/utils/stateful_tensor_mgr.py
rename to colossalai/gemini/stateful_tensor_mgr.py
index 107d14f5b..6344e3b6d 100644
--- a/colossalai/zero/utils/stateful_tensor_mgr.py
+++ b/colossalai/gemini/stateful_tensor_mgr.py
@@ -5,7 +5,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
-from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicy
+from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
 from typing import List
 from colossalai.logging import get_dist_logger
 
diff --git a/colossalai/zero/utils/tensor_placement_policy.py b/colossalai/gemini/tensor_placement_policy.py
similarity index 100%
rename from colossalai/zero/utils/tensor_placement_policy.py
rename to colossalai/gemini/tensor_placement_policy.py
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 4ace8a4d3..1db81991c 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -22,8 +22,8 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
-from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
+from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
 
 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index c3c1723d2..680f86962 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -21,7 +21,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
-from colossalai.zero.utils.tensor_placement_policy import AutoTensorPlacementPolicy
+from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
 
 
 class OptimState(Enum):
diff --git a/colossalai/zero/utils/__init__.py b/colossalai/zero/utils/__init__.py
index 02bc21873..c4e687228 100644
--- a/colossalai/zero/utils/__init__.py
+++ b/colossalai/zero/utils/__init__.py
@@ -1,5 +1,3 @@
-from .stateful_tensor_mgr import StatefulTensorMgr
-from .tensor_placement_policy import TensorPlacementPolicyFactory
 from .zero_hook import ZeroHook
 
-__all__ = ['StatefulTensorMgr', 'ZeroHook', 'TensorPlacementPolicyFactory']
\ No newline at end of file
+__all__ = ['ZeroHook']
\ No newline at end of file
diff --git a/colossalai/zero/utils/zero_hook.py b/colossalai/zero/utils/zero_hook.py
index 40b44fc12..fbbb7fd2d 100644
--- a/colossalai/zero/utils/zero_hook.py
+++ b/colossalai/zero/utils/zero_hook.py
@@ -9,8 +9,7 @@ from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param.tensorful_state import TensorState
-from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
+from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
 
 from colossalai.engine.ophooks import BaseOpHook
 
diff --git a/tests/test_zero/test_stateful_tensor_mgr.py b/tests/test_zero/test_stateful_tensor_mgr.py
index 15ca1cc5c..6449c285f 100644
--- a/tests/test_zero/test_stateful_tensor_mgr.py
+++ b/tests/test_zero/test_stateful_tensor_mgr.py
@@ -6,7 +6,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory_tracer import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory import colo_set_process_memory_fraction
-from colossalai.zero.utils import StatefulTensorMgr
+from colossalai.gemini import StatefulTensorMgr
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import TensorState
 from colossalai.utils import free_port
@@ -14,7 +14,9 @@ from colossalai.testing import rerun_on_exception
 from torch.nn.parameter import Parameter
 from typing import List
 from functools import partial
-from colossalai.zero.utils.tensor_placement_policy import AutoTensorPlacementPolicy
+
+from colossalai.gemini import StatefulTensorMgr
+from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
 
 
 class Net(torch.nn.Module):
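
Note: the net effect of this patch is that StatefulTensorMgr and TensorPlacementPolicyFactory move out of colossalai.zero.utils into the new colossalai.gemini package, while ZeroHook stays behind. A minimal sketch of how downstream code would update its imports after this change (only the import paths shown here are taken from the diff; nothing else is implied about the APIs):

# Old import path, removed by this patch:
# from colossalai.zero.utils import StatefulTensorMgr, TensorPlacementPolicyFactory

# New import path exposed by colossalai/gemini/__init__.py in this patch:
from colossalai.gemini import StatefulTensorMgr, TensorPlacementPolicyFactory
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy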