ColossalAI/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py

28 lines
1.0 KiB
Python
Raw Normal View History

import operator
from typing import List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
__all__ = ["non_spmd_meta_info"]
@meta_register.register(torch.Size)
@meta_register.register(torch.Tensor.size)
@meta_register.register(torch.finfo)
@meta_register.register(operator.le)
def non_spmd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""Non-SPMD node meta information generator
Those nodes will not be handled by SPMD solver, so we just return all zero meta information for it
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
compute_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
memory_cost = TrainCycleItem(fwd=MemoryCost(), bwd=MemoryCost(), total=MemoryCost())
fwd_in, fwd_buffer, fwd_out = [], [], []
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out