[zero] use factory pattern for tensor_placement_policy (#752)

Jiarui Fang 2022-04-14 11:07:29 +08:00 committed by GitHub
parent 4b048a8728
commit 3d7dc46d33
3 changed files with 21 additions and 14 deletions

colossalai/zero/sharded_model/sharded_model_v2.py

@@ -23,7 +23,7 @@ from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.zero.utils.tensor_placement_policy import TENSOR_PLACEMENT_POLICIES, TensorPlacementPolicy
+from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
@@ -105,8 +105,6 @@ class ShardedModelV2(nn.Module):
         self.rank = dist.get_rank(self.process_group)
         self.shard_strategy = shard_strategy
-        assert tensor_placement_policy in TENSOR_PLACEMENT_POLICIES, f'Invalid tensor_placement_policy, got {tensor_placement_policy}'
         # Init Memory Statistics Collector
         self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
             GLOBAL_MODEL_DATA_TRACER.register_model(self)
@@ -115,8 +113,8 @@ class ShardedModelV2(nn.Module):
             self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else:
             self._memstats_collector = None
-        self._tensor_placement_policy: TensorPlacementPolicy = TENSOR_PLACEMENT_POLICIES[tensor_placement_policy](
-            mem_stats_collector=self._memstats_collector)
+        self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
+            tensor_placement_policy)(mem_stats_collector=self._memstats_collector)
         self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
         for param in module.parameters():
             if hasattr(param, 'colo_attr'):
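
The call-site change in ShardedModelV2.__init__ reduces to the shape below. This is a minimal sketch with stand-in classes (Policy, CPUPolicy and PolicyFactory are illustrative names, not the real colossalai types): the factory returns a policy class that is then instantiated with the memory-stats collector, and the explicit assert is no longer needed because the factory itself rejects unknown names.

    # Minimal sketch of the call-site change (stand-in classes, not colossalai's).
    from typing import Dict, Optional, Type

    class Policy:                      # stands in for TensorPlacementPolicy
        def __init__(self, mem_stats_collector: Optional[object] = None) -> None:
            self.mem_stats_collector = mem_stats_collector

    class CPUPolicy(Policy):           # stands in for CPUTensorPlacementPolicy
        pass

    # Old shape: module-level dict plus an assert at the call site.
    POLICIES: Dict[str, Type[Policy]] = {'cpu': CPUPolicy}
    name = 'cpu'
    assert name in POLICIES, f'Invalid tensor_placement_policy, got {name}'
    old_policy = POLICIES[name](mem_stats_collector=None)

    # New shape: the factory validates the name and returns the class to instantiate.
    class PolicyFactory:
        @staticmethod
        def create(policy_name: str) -> Type[Policy]:
            if policy_name == 'cpu':
                return CPUPolicy
            raise TypeError(f"Unknown tensor placement policy {policy_name}")

    new_policy = PolicyFactory.create(name)(mem_stats_collector=None)

Centralizing the name check in the factory means every caller gets the same error for an unknown policy name instead of maintaining its own assert.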

colossalai/zero/utils/__init__.py

@@ -1,4 +1,5 @@
 from .stateful_tensor_mgr import StatefulTensorMgr
+from .tensor_placement_policy import TensorPlacementPolicyFactory
 from .zero_hook import ZeroHook

-__all__ = ['StatefulTensorMgr', 'ZeroHook']
+__all__ = ['StatefulTensorMgr', 'ZeroHook', 'TensorPlacementPolicyFactory']
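
With the new re-export, the factory is reachable from colossalai.zero.utils. A short usage sketch, assuming a colossalai build that includes this commit:

    # Usage sketch: resolve a policy class by name via the new factory export.
    from colossalai.zero.utils import TensorPlacementPolicyFactory

    policy_cls = TensorPlacementPolicyFactory.create('auto')   # -> AutoTensorPlacementPolicy
    print(policy_cls.__name__)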

colossalai/zero/utils/tensor_placement_policy.py

@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
 from typing import List, Optional, Dict
 import torch
 from colossalai.utils import get_current_device
@@ -6,16 +7,16 @@ from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 from colossalai.utils.memory_tracer import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-__all__ = ['TENSOR_PLACEMENT_POLICIES']
+from typing import Type

-class TensorPlacementPolicy:
+class TensorPlacementPolicy(ABC):

     def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
         self.device: Optional[torch.device] = device
         self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector

+    @abstractmethod
     def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
         raise NotImplementedError
@@ -87,8 +88,15 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
         )

-TENSOR_PLACEMENT_POLICIES = {
-    'cpu': CPUTensorPlacementPolicy,
-    'cuda': CUDATensorPlacementPolicy,
-    'auto': AutoTensorPlacementPolicy
-}
+class TensorPlacementPolicyFactory:
+
+    @staticmethod
+    def create(policy_name: str) -> Type[TensorPlacementPolicy]:
+        if policy_name == 'cpu':
+            return CPUTensorPlacementPolicy
+        elif policy_name == 'cuda':
+            return CUDATensorPlacementPolicy
+        elif policy_name == 'auto':
+            return AutoTensorPlacementPolicy
+        else:
+            raise TypeError(f"Unknown tensor placement policy {policy_name}")
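
Deriving TensorPlacementPolicy from ABC and marking evict_tensors as @abstractmethod moves the failure for an incomplete policy from call time to instantiation time. A toy sketch, with stand-in names and simplified signatures rather than the real colossalai classes:

    # Toy sketch of the ABC change: instantiation fails fast for incomplete policies.
    from abc import ABC, abstractmethod
    from typing import List

    class Policy(ABC):                         # stands in for TensorPlacementPolicy(ABC)
        @abstractmethod
        def evict_tensors(self, hold_cuda_tensor_list: List[object], **kwargs) -> None:
            raise NotImplementedError

    class GoodPolicy(Policy):
        def evict_tensors(self, hold_cuda_tensor_list: List[object], **kwargs) -> None:
            hold_cuda_tensor_list.clear()      # trivial eviction for the sketch

    class BadPolicy(Policy):                   # forgets to override evict_tensors
        pass

    GoodPolicy()                               # fine
    try:
        BadPolicy()                            # TypeError: can't instantiate abstract class
    except TypeError as e:
        print(e)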