from typing import Callable, List

import torch

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    MemoryCost,
    OperationData,
    OperationDataType,
    ShardingStrategy,
    StrategiesVector,
    TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec

from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register

__all__ = ['MetaInfo']


class MetaInfo:
    """MetaInfo class
    This class is used to store meta info based on sharding strategy and the given target function.

    Meta info is (re)computed automatically whenever both ``strategy`` and ``target`` are
    set, either in the constructor or through the corresponding property setters. Until
    then, ``compute_cost``, ``memory_cost``, ``fwd_in``, ``fwd_buffer`` and ``fwd_out``
    are only declared, not assigned.
    """

    def __init__(self, strategy: ShardingStrategy = None, target: Callable = None) -> None:
        # compute cost of forward and backward computation
        self.compute_cost: TrainCycleItem
        # compute memory cost of forward and backward phase
        self.memory_cost: TrainCycleItem
        # list of input tensors
        self.fwd_in: List[torch.Tensor]
        # list of buffer tensors
        self.fwd_buffer: List[torch.Tensor]
        # list of output tensors
        self.fwd_out: List[torch.Tensor]

        # sharding strategy
        self._strategy = strategy
        # target function (or module instance)
        self._target = target

        # compute metainfo if possible
        self._maybe_compute_metainfo()

    def _maybe_compute_metainfo(self) -> None:
        """Recompute meta info only when both the strategy and the target are available."""
        if self._strategy is not None and self._target is not None:
            self.compute_metainfo()

    @property
    def strategy(self) -> ShardingStrategy:
        return self._strategy

    @strategy.setter
    def strategy(self, strategy: ShardingStrategy) -> None:
        self._strategy = strategy
        self._maybe_compute_metainfo()

    @property
    def target(self) -> Callable:
        return self._target

    @target.setter
    def target(self, target: Callable) -> None:
        self._target = target
        self._maybe_compute_metainfo()

    def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> OperationData:
        """
        Compute sharded opdata based on the given data and sharding spec.

        Args:
            operation_data: the operation data to shard.
            sharding_spec: the sharding spec describing how ``operation_data`` is sharded.

        Returns:
            OperationData: a new ``OperationData`` with the same name, type and logical shape,
            whose ``data`` is a zero tensor on the ``meta`` device (no real storage) shaped
            like one device's shard, as reported by ``sharding_spec.get_sharded_shape_per_device()``.
        """
        return OperationData(name=operation_data.name,
                             data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
                             type=operation_data.type,
                             logical_shape=operation_data.logical_shape)

    def compute_metainfo(self):
        """
        Compute meta info based on sharding strategy and the given target function.

        The registered meta function is looked up by the target's class (module targets)
        or by the target itself (function targets). It is called with one sharded meta
        ``OperationData`` per entry of ``self._strategy.sharding_specs`` plus an ``inplace``
        keyword. Its five results are stored in ``compute_cost``, ``memory_cost``,
        ``fwd_in``, ``fwd_buffer`` and ``fwd_out``.

        Raises:
            AssertionError: if no meta function is registered for the target.
        """
        target = self._target
        is_module = meta_register.has(target.__class__)
        assert is_module or meta_register.has(target), \
            f"Meta info for {target} is not registered."

        if is_module:
            # module: the meta function and the no-save list are keyed by the module class
            meta_func = meta_register.get(target.__class__)

            # check whether the target in the list that we don't need to save activation
            save_fwd_in = target.__class__ not in NO_SAVE_ACTIVATION
        else:
            # function: the meta function is keyed by the function itself, so the no-save
            # list must be checked against the function too (its __class__ is just the
            # generic function type)
            meta_func = meta_register.get(target)

            # check whether the target in the list that we don't need to save activation
            save_fwd_in = target not in NO_SAVE_ACTIVATION

        # construct args for meta_func: one sharded meta OperationData per operand
        args = [
            self.compute_sharded_opdata(op_data, spec) for op_data, spec in self._strategy.sharding_specs.items()
        ]

        # construct kwargs
        # Module targets are instances (the lookups above use target.__class__), so match
        # the instance's class as well as the instance itself; the module's own `inplace`
        # flag is then forwarded.
        if target in INPLACE_MODULE or target.__class__ in INPLACE_MODULE:
            kwargs = {'inplace': target.inplace}
        elif target in INPLACE_OPS:
            kwargs = {'inplace': True}
        else:
            kwargs = {'inplace': False}

        # compute metainfo with meta_func
        self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)

        # process corner case for NO_SAVE_ACTIVATION
        if not save_fwd_in:
            self.fwd_in = []