[gemini] init gemini individual directory (#754)

pull/764/head
Jiarui Fang 2022-04-14 16:40:26 +08:00 committed by GitHub
parent dcca614eee
commit 10ef8afdd2
8 changed files with 14 additions and 11 deletions

View File

@@ -0,0 +1,4 @@
+from .stateful_tensor_mgr import StatefulTensorMgr
+from .tensor_placement_policy import TensorPlacementPolicyFactory
+__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory']
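Note: because the new gemini package __init__ shown above re-exports both classes, downstream code can import them either from the package root or from the individual modules. A minimal sketch of the two equivalent forms, using only import paths that appear in the hunks of this commit:

# After this commit, both spellings resolve to the same classes,
# since the gemini package __init__ re-exports them.
from colossalai.gemini import StatefulTensorMgr, TensorPlacementPolicyFactory
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory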

View File

@@ -5,7 +5,7 @@ from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
-from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicy
+from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
from typing import List
from colossalai.logging import get_dist_logger

View File

@@ -22,8 +22,8 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param.tensorful_state import TensorState
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
-from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
+from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                     get_gradient_predivide_factor)

View File

@@ -21,7 +21,7 @@ from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
-from colossalai.zero.utils.tensor_placement_policy import AutoTensorPlacementPolicy
+from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
class OptimState(Enum):

View File

@@ -1,5 +1,3 @@
-from .stateful_tensor_mgr import StatefulTensorMgr
-from .tensor_placement_policy import TensorPlacementPolicyFactory
from .zero_hook import ZeroHook
-__all__ = ['StatefulTensorMgr', 'ZeroHook', 'TensorPlacementPolicyFactory']
+__all__ = ['ZeroHook']

View File

@@ -9,8 +9,7 @@ from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.tensorful_state import TensorState
-from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
+from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.engine.ophooks import BaseOpHook

View File

@@ -6,7 +6,7 @@ from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_set_process_memory_fraction
-from colossalai.zero.utils import StatefulTensorMgr
+from colossalai.gemini import StatefulTensorMgr
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.utils import free_port
@@ -14,7 +14,9 @@ from colossalai.testing import rerun_on_exception
from torch.nn.parameter import Parameter
from typing import List
from functools import partial
-from colossalai.zero.utils.tensor_placement_policy import AutoTensorPlacementPolicy
+from colossalai.gemini import StatefulTensorMgr
+from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
class Net(torch.nn.Module):
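For code outside this repository that used the old locations, a before/after sketch of the import-path migration, assembled only from the hunks above (ZeroHook is not moved and stays under colossalai.zero.utils):

# Old paths, removed by this commit:
#   from colossalai.zero.utils import StatefulTensorMgr
#   from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
#   from colossalai.zero.utils.tensor_placement_policy import (TensorPlacementPolicy,
#                                                               TensorPlacementPolicyFactory,
#                                                               AutoTensorPlacementPolicy)

# New paths under the gemini package:
from colossalai.gemini import StatefulTensorMgr
from colossalai.gemini.tensor_placement_policy import (TensorPlacementPolicy, TensorPlacementPolicyFactory,
                                                        AutoTensorPlacementPolicy)

# ZeroHook keeps its old location:
from colossalai.zero.utils import ZeroHook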