mirror of https://github.com/hpcaitech/ColossalAI

[zero] use factory pattern for tensor_placement_policy (#752)

parent 4b048a8728
commit 3d7dc46d33
ShardedModelV2:

@@ -23,7 +23,7 @@ from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.zero.utils.tensor_placement_policy import TENSOR_PLACEMENT_POLICIES, TensorPlacementPolicy
+from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy

 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)

@@ -105,8 +105,6 @@ class ShardedModelV2(nn.Module):
         self.rank = dist.get_rank(self.process_group)
         self.shard_strategy = shard_strategy

-        assert tensor_placement_policy in TENSOR_PLACEMENT_POLICIES, f'Invalid tensor_placement_policy, got {tensor_placement_policy}'
-        # Init Memory Statistics Collector
         self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
             GLOBAL_MODEL_DATA_TRACER.register_model(self)

@@ -115,8 +113,8 @@ class ShardedModelV2(nn.Module):
             self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else:
             self._memstats_collector = None
-        self._tensor_placement_policy: TensorPlacementPolicy = TENSOR_PLACEMENT_POLICIES[tensor_placement_policy](
-            mem_stats_collector=self._memstats_collector)
+        self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
+            tensor_placement_policy)(mem_stats_collector=self._memstats_collector)
         self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
         for param in module.parameters():
             if hasattr(param, 'colo_attr'):
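The call site in ShardedModelV2 now reads TensorPlacementPolicyFactory.create(tensor_placement_policy)(mem_stats_collector=self._memstats_collector): the factory returns a policy class, and the trailing call is that class's constructor. The removed assert is no longer needed because an unknown name now fails inside create(). A minimal sketch of this two-step shape, using hypothetical FakePolicy/FakeFactory stand-ins so it runs without ColossalAI:

    from typing import Optional, Type


    class FakePolicy:
        # Stand-in for a concrete TensorPlacementPolicy; only keeps the collector.
        def __init__(self, mem_stats_collector: Optional[object] = None) -> None:
            self.mem_stats_collector = mem_stats_collector


    class FakeFactory:
        # Stand-in for TensorPlacementPolicyFactory: returns a class, not an instance.
        @staticmethod
        def create(policy_name: str) -> Type[FakePolicy]:
            if policy_name in ('cpu', 'cuda', 'auto'):
                return FakePolicy
            raise TypeError(f"Unknown tensor placement policy {policy_name}")


    # Old style (before this commit): dict lookup guarded by an assert.
    #   policy = TENSOR_PLACEMENT_POLICIES[name](mem_stats_collector=collector)
    # New style (this commit): factory lookup, then instantiate the returned class.
    policy = FakeFactory.create('auto')(mem_stats_collector=None)
    print(type(policy).__name__, policy.mem_stats_collector)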
colossalai/zero/utils/__init__.py:

@@ -1,4 +1,5 @@
 from .stateful_tensor_mgr import StatefulTensorMgr
+from .tensor_placement_policy import TensorPlacementPolicyFactory
 from .zero_hook import ZeroHook

-__all__ = ['StatefulTensorMgr', 'ZeroHook']
+__all__ = ['StatefulTensorMgr', 'ZeroHook', 'TensorPlacementPolicyFactory']
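With the new entry in __all__, the factory can be imported directly from the package. A short usage sketch, assuming a ColossalAI installation that includes this commit and relying on the concrete policies accepting mem_stats_collector as shown at the ShardedModelV2 call site above:

    from colossalai.zero.utils import TensorPlacementPolicyFactory

    policy_cls = TensorPlacementPolicyFactory.create('cpu')    # returns CPUTensorPlacementPolicy
    policy = policy_cls(mem_stats_collector=None)              # lookup and construction stay separate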
colossalai/zero/utils/tensor_placement_policy.py:

@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
 from typing import List, Optional, Dict
 import torch
 from colossalai.utils import get_current_device

@@ -6,16 +7,16 @@ from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 from colossalai.utils.memory_tracer import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-__all__ = ['TENSOR_PLACEMENT_POLICIES']
+from typing import Type


-class TensorPlacementPolicy:
+class TensorPlacementPolicy(ABC):

     def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
         self.device: Optional[torch.device] = device
         self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector

+    @abstractmethod
     def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
         raise NotImplementedError

@@ -87,8 +88,15 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
             )


-TENSOR_PLACEMENT_POLICIES = {
-    'cpu': CPUTensorPlacementPolicy,
-    'cuda': CUDATensorPlacementPolicy,
-    'auto': AutoTensorPlacementPolicy
-}
+class TensorPlacementPolicyFactory:
+
+    @staticmethod
+    def create(policy_name: str) -> Type[TensorPlacementPolicy]:
+        if policy_name == 'cpu':
+            return CPUTensorPlacementPolicy
+        elif policy_name == 'cuda':
+            return CUDATensorPlacementPolicy
+        elif policy_name == 'auto':
+            return AutoTensorPlacementPolicy
+        else:
+            raise TypeError(f"Unknown tensor placement policy {policy_name}")
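Taken together, the two changes (abstract base class plus factory) fit like this. The sketch below is self-contained and runnable without ColossalAI; the policy bodies are simplified stand-ins for the real eviction logic, and the 'auto' branch is omitted:

    from abc import ABC, abstractmethod
    from typing import List, Type


    class TensorPlacementPolicy(ABC):
        """Base class: concrete policies must implement evict_tensors()."""

        @abstractmethod
        def evict_tensors(self, hold_cuda_tensor_list: List, **kwargs) -> None:
            raise NotImplementedError


    class CPUTensorPlacementPolicy(TensorPlacementPolicy):

        def evict_tensors(self, hold_cuda_tensor_list: List, **kwargs) -> None:
            pass    # stand-in: the real policy moves held tensors to CPU


    class CUDATensorPlacementPolicy(TensorPlacementPolicy):

        def evict_tensors(self, hold_cuda_tensor_list: List, **kwargs) -> None:
            pass    # stand-in: the real policy keeps everything on CUDA


    class TensorPlacementPolicyFactory:
        """Maps a policy name to a policy class; the caller instantiates it."""

        @staticmethod
        def create(policy_name: str) -> Type[TensorPlacementPolicy]:
            if policy_name == 'cpu':
                return CPUTensorPlacementPolicy
            elif policy_name == 'cuda':
                return CUDATensorPlacementPolicy
            else:
                raise TypeError(f"Unknown tensor placement policy {policy_name}")


    if __name__ == '__main__':
        policy = TensorPlacementPolicyFactory.create('cpu')()    # look up the class, then construct it
        policy.evict_tensors([])

        try:
            TensorPlacementPolicy()                              # ABC: cannot be instantiated directly
        except TypeError as e:
            print(e)

        try:
            TensorPlacementPolicyFactory.create('gpu')           # unknown name -> TypeError from the factory
        except TypeError as e:
            print(e)

Returning the class rather than an instance keeps the constructor arguments (device, mem_stats_collector) at the call site, which is exactly how ShardedModelV2 uses the factory in this commit.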