|
|
|
@ -1,14 +1,21 @@
|
|
|
|
|
import operator |
|
|
|
|
from functools import reduce |
|
|
|
|
import warnings |
|
|
|
|
from copy import deepcopy |
|
|
|
|
from functools import reduce |
|
|
|
|
from typing import Dict, List |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ( |
|
|
|
|
enumerate_all_possible_1d_sharding, |
|
|
|
|
enumerate_all_possible_2d_sharding, |
|
|
|
|
exception_handler, |
|
|
|
|
) |
|
|
|
|
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector |
|
|
|
|
from .operator_handler import OperatorHandler |
|
|
|
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager |
|
|
|
|
from colossalai.tensor.sharding_spec import ShardingSpec |
|
|
|
|
from copy import deepcopy |
|
|
|
|
from typing import Dict, List |
|
|
|
|
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding |
|
|
|
|
|
|
|
|
|
from .operator_handler import OperatorHandler |
|
|
|
|
|
|
|
|
|
__all__ = ['WhereHandler'] |
|
|
|
|
|
|
|
|
@ -94,7 +101,7 @@ class WhereHandler(OperatorHandler):
|
|
|
|
|
# compute the resharding cost |
|
|
|
|
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency( |
|
|
|
|
input_sharding_spec, input_spec) |
|
|
|
|
|
|
|
|
|
total_resharding_cost = total_resharding_cost['total'] |
|
|
|
|
# we need to multiply by the size of the element dtype to get the correct communication cost |
|
|
|
|
resharding_cost = total_resharding_cost * size_per_elem_bytes |
|
|
|
|
resharding_costs[input_node].append(resharding_cost) |
|
|
|
|