diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py
index dddd91786..bd97e2736 100644
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py
@@ -1,14 +1,21 @@
 import operator
-from functools import reduce
 import warnings
+from copy import deepcopy
+from functools import reduce
+from typing import Dict, List
+
 import torch
+
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+    exception_handler,
+)
 from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-from .operator_handler import OperatorHandler
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
-from copy import deepcopy
-from typing import Dict, List
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
+
+from .operator_handler import OperatorHandler
 
 __all__ = ['WhereHandler']
 
@@ -94,7 +101,7 @@ class WhereHandler(OperatorHandler):
                 # compute the resharding cost
                 _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
                     input_sharding_spec, input_spec)
-
+                total_resharding_cost = total_resharding_cost['total']
                 # we need multiply the size of elem dtype to get correct communication cost
                 resharding_cost = total_resharding_cost * size_per_elem_bytes
                 resharding_costs[input_node].append(resharding_cost)
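
For reference, below is a minimal sketch of the resharding-cost computation after this change. It assumes, as the second hunk implies, that `shape_consistency` now returns its cost as a dict with a `'total'` entry, and that the per-element byte size is derived from the node dtype as in the surrounding handler code; the helper name `compute_resharding_cost` is hypothetical and only used here for illustration.

```python
import torch


def compute_resharding_cost(shape_consistency_manager, input_sharding_spec, input_spec, dtype):
    """Hypothetical helper mirroring the per-strategy loop touched by this diff."""
    # shape_consistency returns (spec sequence, comm actions, cost); the cost is
    # now a dict rather than a scalar, so pick the aggregate 'total' entry.
    _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
        input_sharding_spec, input_spec)
    total_resharding_cost = total_resharding_cost['total']
    # Multiply by the element size so the cost reflects bytes communicated,
    # not element counts (assumed to match the handler's size_per_elem_bytes).
    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
    return total_resharding_cost * size_per_elem_bytes
```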