diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 04608184a..e5de4982f 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -23,7 +23,7 @@ from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.zero.utils.tensor_placement_policy import TENSOR_PLACEMENT_POLICIES, TensorPlacementPolicy
+from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
 
 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
@@ -105,8 +105,6 @@ class ShardedModelV2(nn.Module):
         self.rank = dist.get_rank(self.process_group)
         self.shard_strategy = shard_strategy
 
-        assert tensor_placement_policy in TENSOR_PLACEMENT_POLICIES, f'Invalid tensor_placement_policy, got {tensor_placement_policy}'
-        # Init Memory Statistics Collector
         self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
             GLOBAL_MODEL_DATA_TRACER.register_model(self)
@@ -115,8 +113,8 @@ class ShardedModelV2(nn.Module):
             self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else:
             self._memstats_collector = None
-        self._tensor_placement_policy: TensorPlacementPolicy = TENSOR_PLACEMENT_POLICIES[tensor_placement_policy](
-            mem_stats_collector=self._memstats_collector)
+        self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
+            tensor_placement_policy)(mem_stats_collector=self._memstats_collector)
         self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
         for param in module.parameters():
             if hasattr(param, 'colo_attr'):
diff --git a/colossalai/zero/utils/__init__.py b/colossalai/zero/utils/__init__.py
index 2153ebe34..02bc21873 100644
--- a/colossalai/zero/utils/__init__.py
+++ b/colossalai/zero/utils/__init__.py
@@ -1,4 +1,5 @@
 from .stateful_tensor_mgr import StatefulTensorMgr
+from .tensor_placement_policy import TensorPlacementPolicyFactory
 from .zero_hook import ZeroHook
 
-__all__ = ['StatefulTensorMgr', 'ZeroHook']
\ No newline at end of file
+__all__ = ['StatefulTensorMgr', 'ZeroHook', 'TensorPlacementPolicyFactory']
\ No newline at end of file
diff --git a/colossalai/zero/utils/tensor_placement_policy.py b/colossalai/zero/utils/tensor_placement_policy.py
index 953fd956c..d7a977188 100644
--- a/colossalai/zero/utils/tensor_placement_policy.py
+++ b/colossalai/zero/utils/tensor_placement_policy.py
@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
 from typing import List, Optional, Dict
 import torch
 from colossalai.utils import get_current_device
@@ -6,16 +7,16 @@ from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 from colossalai.utils.memory_tracer import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-
-__all__ = ['TENSOR_PLACEMENT_POLICIES']
+from typing import Type
 
 
-class TensorPlacementPolicy:
+class TensorPlacementPolicy(ABC):
 
     def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
         self.device: Optional[torch.device] = device
         self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector
 
+    @abstractmethod
     def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
         raise NotImplementedError
 
@@ -87,8 +88,15 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
         )
 
 
-TENSOR_PLACEMENT_POLICIES = {
-    'cpu': CPUTensorPlacementPolicy,
-    'cuda': CUDATensorPlacementPolicy,
-    'auto': AutoTensorPlacementPolicy
-}
+class TensorPlacementPolicyFactory:
+
+    @staticmethod
+    def create(policy_name: str) -> Type[TensorPlacementPolicy]:
+        if policy_name == 'cpu':
+            return CPUTensorPlacementPolicy
+        elif policy_name == 'cuda':
+            return CUDATensorPlacementPolicy
+        elif policy_name == 'auto':
+            return AutoTensorPlacementPolicy
+        else:
+            raise TypeError(f"Unknown tensor placement policy {policy_name}")
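
Note: a minimal usage sketch of the factory introduced above. This is a hypothetical stand-alone snippet, not part of the patch; it assumes a ColossalAI checkout with this change applied and passes mem_stats_collector=None, which the base-class signature permits.

    from colossalai.zero.utils import TensorPlacementPolicyFactory

    # Resolve the policy class by name, then instantiate it, mirroring the
    # updated call site in ShardedModelV2.__init__.
    policy_cls = TensorPlacementPolicyFactory.create('cpu')
    policy = policy_cls(mem_stats_collector=None)

    # An unknown name now raises TypeError inside the factory, replacing the
    # assert that was removed from ShardedModelV2.
    try:
        TensorPlacementPolicyFactory.create('gpu')
    except TypeError as exc:
        print(exc)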