diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py
index d005ac813..4d8b656e1 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py
@@ -3,6 +3,7 @@ from .binary_elementwise_ops import *
 from .conv import *
 from .embedding import *
 from .linear import *
+from .non_spmd import *
 from .norm import *
 from .pooling import *
 from .tensor import *
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
new file mode 100644
index 000000000..4634d3ccd
--- /dev/null
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
@@ -0,0 +1,29 @@
+import operator
+from typing import List, Tuple
+
+import torch
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
+from colossalai.fx.profiler.memory_utils import activation_size
+from colossalai.fx.profiler.opcount import flop_mapping
+
+from ..registry import meta_register
+
+__all__ = ["non_spmd_meta_info"]
+
+
+@meta_register.register(torch.Size)
+@meta_register.register(torch.Tensor.size)
+@meta_register.register(torch.finfo)
+@meta_register.register(operator.le)
+def non_spmd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
+    """Non-SPMD node meta information generator
+    These nodes will not be handled by the SPMD solver, so we simply return all-zero meta information for them.
+
+    Returns:
+        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: compute cost, memory cost, forward inputs, forward buffers and forward outputs
+    """
+    compute_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
+    memory_cost = TrainCycleItem(fwd=MemoryCost(), bwd=MemoryCost(), total=MemoryCost())
+    fwd_in, fwd_buffer, fwd_out = [], [], []
+    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
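For context, here is a minimal usage sketch, not part of the diff above: it calls the registered generator directly and checks the all-zero costs it returns. The direct module import path and the attribute names checked in the assertions are assumptions inferred only from the code added in this PR.

```python
# Hypothetical usage sketch (not part of the PR): invoke the non-SPMD meta
# information generator directly and verify that it reports zero cost.
from colossalai.auto_parallel.meta_profiler.meta_registry.non_spmd import non_spmd_meta_info

# The generator ignores its arguments, so every target it is registered for
# (torch.Size, torch.Tensor.size, torch.finfo, operator.le) yields the same result.
compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = non_spmd_meta_info()

assert compute_cost.fwd == 0 and compute_cost.bwd == 0 and compute_cost.total == 0
assert fwd_in == [] and fwd_buffer == [] and fwd_out == []
```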