mirror of https://github.com/hpcaitech/ColossalAI
[gemini] init genimi individual directory (#754)
parent
dcca614eee
commit
10ef8afdd2
|
@ -0,0 +1,4 @@
|
|||
from .stateful_tensor_mgr import StatefulTensorMgr
|
||||
from .tensor_placement_policy import TensorPlacementPolicyFactory
|
||||
|
||||
__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory']
|
|
@ -5,7 +5,7 @@ from colossalai.utils.cuda import get_current_device
|
|||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
|
||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicy
|
||||
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
|
||||
from typing import List
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
|
@ -22,8 +22,8 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
|
|||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
|
||||
from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
|
||||
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
||||
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
|
||||
|
||||
from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
|
||||
get_gradient_predivide_factor)
|
||||
|
|
|
@ -21,7 +21,7 @@ from torch import Tensor
|
|||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.optim import Optimizer
|
||||
from colossalai.zero.utils.tensor_placement_policy import AutoTensorPlacementPolicy
|
||||
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
|
||||
|
||||
|
||||
class OptimState(Enum):
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
from .stateful_tensor_mgr import StatefulTensorMgr
|
||||
from .tensor_placement_policy import TensorPlacementPolicyFactory
|
||||
from .zero_hook import ZeroHook
|
||||
|
||||
__all__ = ['StatefulTensorMgr', 'ZeroHook', 'TensorPlacementPolicyFactory']
|
||||
__all__ = ['ZeroHook']
|
|
@ -9,8 +9,7 @@ from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
|
|||
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
||||
from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
|
||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
|
||||
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
||||
|
||||
from colossalai.engine.ophooks import BaseOpHook
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from colossalai.utils.cuda import get_current_device
|
|||
from colossalai.utils.memory_tracer import MemStatsCollector
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.utils.memory import colo_set_process_memory_fraction
|
||||
from colossalai.zero.utils import StatefulTensorMgr
|
||||
from colossalai.gemini import StatefulTensorMgr
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
||||
from colossalai.utils import free_port
|
||||
|
@ -14,7 +14,9 @@ from colossalai.testing import rerun_on_exception
|
|||
from torch.nn.parameter import Parameter
|
||||
from typing import List
|
||||
from functools import partial
|
||||
from colossalai.zero.utils.tensor_placement_policy import AutoTensorPlacementPolicy
|
||||
|
||||
from colossalai.gemini import StatefulTensorMgr
|
||||
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
|
||||
|
||||
|
||||
class Net(torch.nn.Module):
|
||||
|
|
Loading…
Reference in New Issue