mirror of https://github.com/hpcaitech/ColossalAI
[zero] use factory pattern for tensor_placement_policy (#752)
parent 4b048a8728
commit 3d7dc46d33
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -23,7 +23,7 @@ from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.zero.utils.tensor_placement_policy import TENSOR_PLACEMENT_POLICIES, TensorPlacementPolicy
+from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy

 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
@@ -105,8 +105,6 @@ class ShardedModelV2(nn.Module):
         self.rank = dist.get_rank(self.process_group)
         self.shard_strategy = shard_strategy
-
-        assert tensor_placement_policy in TENSOR_PLACEMENT_POLICIES, f'Invalid tensor_placement_policy, got {tensor_placement_policy}'
         # Init Memory Statistics Collector
         self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
             GLOBAL_MODEL_DATA_TRACER.register_model(self)
@@ -115,8 +113,8 @@ class ShardedModelV2(nn.Module):
             self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else:
             self._memstats_collector = None
-        self._tensor_placement_policy: TensorPlacementPolicy = TENSOR_PLACEMENT_POLICIES[tensor_placement_policy](
-            mem_stats_collector=self._memstats_collector)
+        self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
+            tensor_placement_policy)(mem_stats_collector=self._memstats_collector)
         self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
         for param in module.parameters():
             if hasattr(param, 'colo_attr'):
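Note on the hunk above: the deleted assert is not lost validation, since TensorPlacementPolicyFactory.create raises TypeError for unknown names (see the last hunk of this commit). A minimal before/after sketch of the call shape; _AutoPolicy is a hypothetical stand-in so the snippet runs on its own:

from typing import Dict, Type


class _AutoPolicy:    # hypothetical stand-in for AutoTensorPlacementPolicy
    def __init__(self, mem_stats_collector=None) -> None:
        self.mem_stats_collector = mem_stats_collector


# Before: module-level registry dict, validated by an assert at the call site.
TENSOR_PLACEMENT_POLICIES: Dict[str, Type[_AutoPolicy]] = {'auto': _AutoPolicy}
name = 'auto'
assert name in TENSOR_PLACEMENT_POLICIES, f'Invalid tensor_placement_policy, got {name}'
policy = TENSOR_PLACEMENT_POLICIES[name](mem_stats_collector=None)


# After: the factory owns both the lookup and the error handling.
class TensorPlacementPolicyFactory:

    @staticmethod
    def create(policy_name: str) -> Type[_AutoPolicy]:
        if policy_name == 'auto':
            return _AutoPolicy
        raise TypeError(f'Unknown tensor placement policy {policy_name}')


policy = TensorPlacementPolicyFactory.create(name)(mem_stats_collector=None)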
--- a/colossalai/zero/utils/__init__.py
+++ b/colossalai/zero/utils/__init__.py
@@ -1,4 +1,5 @@
 from .stateful_tensor_mgr import StatefulTensorMgr
+from .tensor_placement_policy import TensorPlacementPolicyFactory
 from .zero_hook import ZeroHook

-__all__ = ['StatefulTensorMgr', 'ZeroHook']
+__all__ = ['StatefulTensorMgr', 'ZeroHook', 'TensorPlacementPolicyFactory']
--- a/colossalai/zero/utils/tensor_placement_policy.py
+++ b/colossalai/zero/utils/tensor_placement_policy.py
@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
 from typing import List, Optional, Dict
 import torch
 from colossalai.utils import get_current_device
@@ -6,16 +7,16 @@ from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 from colossalai.utils.memory_tracer import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

-__all__ = ['TENSOR_PLACEMENT_POLICIES']
+from typing import Type


-class TensorPlacementPolicy:
+class TensorPlacementPolicy(ABC):

     def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
         self.device: Optional[torch.device] = device
         self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector

+    @abstractmethod
     def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
         raise NotImplementedError
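A side effect of the ABC change worth noting: with a plain base class, a subclass that forgets to override evict_tensors only fails when the method is first called; with @abstractmethod, it fails at construction time. A small stand-alone illustration (all class names here are hypothetical):

from abc import ABC, abstractmethod


class Base(ABC):    # hypothetical stand-in for TensorPlacementPolicy

    @abstractmethod
    def evict_tensors(self, hold_cuda_tensor_list, **kwargs) -> None:
        raise NotImplementedError


class Incomplete(Base):    # forgot to override evict_tensors
    pass


class Complete(Base):

    def evict_tensors(self, hold_cuda_tensor_list, **kwargs) -> None:
        print(f'evicting {len(hold_cuda_tensor_list)} tensors')


Complete().evict_tensors([])    # fine: prints 'evicting 0 tensors'
try:
    Incomplete()                # fails here, not at the first evict_tensors call
except TypeError as err:
    print(err)                  # "Can't instantiate abstract class Incomplete ..."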
@@ -87,8 +88,15 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
         )


-TENSOR_PLACEMENT_POLICIES = {
-    'cpu': CPUTensorPlacementPolicy,
-    'cuda': CUDATensorPlacementPolicy,
-    'auto': AutoTensorPlacementPolicy
-}
+class TensorPlacementPolicyFactory:
+
+    @staticmethod
+    def create(policy_name: str) -> Type[TensorPlacementPolicy]:
+        if policy_name == 'cpu':
+            return CPUTensorPlacementPolicy
+        elif policy_name == 'cuda':
+            return CUDATensorPlacementPolicy
+        elif policy_name == 'auto':
+            return AutoTensorPlacementPolicy
+        else:
+            raise TypeError(f"Unknown tensor placement policy {policy_name}")
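Taken together, the commit replaces the TENSOR_PLACEMENT_POLICIES dict with a factory over an abstract base class. A condensed, runnable sketch of the resulting shape; the real constructors take richer arguments (torch devices, a MemStatsCollector), simplified here to keep the example self-contained:

from abc import ABC, abstractmethod
from typing import List, Optional, Type


class TensorPlacementPolicy(ABC):

    def __init__(self, device: Optional[str], mem_stats_collector=None) -> None:
        self.device = device
        self.mem_stats_collector = mem_stats_collector

    @abstractmethod
    def evict_tensors(self, hold_cuda_tensor_list: List, **kwargs) -> None:
        raise NotImplementedError


class CPUTensorPlacementPolicy(TensorPlacementPolicy):

    def __init__(self, mem_stats_collector=None) -> None:
        super().__init__(device='cpu', mem_stats_collector=mem_stats_collector)

    def evict_tensors(self, hold_cuda_tensor_list: List, **kwargs) -> None:
        pass    # tensors already live on CPU; nothing to evict


class TensorPlacementPolicyFactory:

    @staticmethod
    def create(policy_name: str) -> Type[TensorPlacementPolicy]:
        if policy_name == 'cpu':
            return CPUTensorPlacementPolicy
        raise TypeError(f'Unknown tensor placement policy {policy_name}')


# Usage mirrors the new ShardedModelV2 call site in the first hunk:
policy = TensorPlacementPolicyFactory.create('cpu')(mem_stats_collector=None)
policy.evict_tensors([])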