[autoparallel] Patch meta information for nodes that will not be handled by SPMD solver (#2823)

* [autoparallel] non spmd meta information generator

* [autoparallel] patch meta information for non spmd nodes
pull/2863/head
Boyuan Yao 2023-02-22 10:28:56 +08:00 committed by GitHub
parent c7764d3f22
commit eae77c831d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 0 deletions

View File

@ -3,6 +3,7 @@ from .binary_elementwise_ops import *
from .conv import *
from .embedding import *
from .linear import *
from .non_spmd import *
from .norm import *
from .pooling import *
from .tensor import *

View File

@ -0,0 +1,29 @@
import operator
from typing import List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..registry import meta_register
__all__ = ["non_spmd_meta_info"]


@meta_register.register(torch.Size)
@meta_register.register(torch.Tensor.size)
@meta_register.register(torch.finfo)
@meta_register.register(operator.le)
def non_spmd_meta_info(
    *args, **kwargs
) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """Meta information generator for non-SPMD nodes.

    Nodes such as ``torch.Size``, ``torch.Tensor.size``, ``torch.finfo`` and
    ``operator.le`` are not handled by the SPMD solver, so an all-zero meta
    information record is returned for them. All positional and keyword
    arguments are accepted and ignored.

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
            zero compute cost, zero memory cost, and empty forward-input,
            forward-buffer and forward-output lists.
    """
    # Zero FLOPs for both forward and backward passes.
    compute_cost = TrainCycleItem(fwd=0, bwd=0, total=0)

    # Zero memory cost — NOTE(review): assumes MemoryCost() defaults to
    # all-zero fields; confirm against its definition.
    memory_cost = TrainCycleItem(fwd=MemoryCost(), bwd=MemoryCost(), total=MemoryCost())

    # No forward inputs, buffers, or outputs need to be tracked.
    fwd_in, fwd_buffer, fwd_out = [], [], []

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out