From d373e67b993ebee3a4f04df05986f8f7abf7229d Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Wed, 19 Oct 2022 11:33:43 +0800 Subject: [PATCH] [hotfix] resharding cost issue (#1742) --- .../deprecated/op_handler/where_handler.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py index dddd91786..bd97e2736 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py @@ -1,14 +1,21 @@ import operator -from functools import reduce import warnings +from copy import deepcopy +from functools import reduce +from typing import Dict, List + import torch + +from colossalai.auto_parallel.tensor_shard.deprecated._utils import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + exception_handler, +) from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector -from .operator_handler import OperatorHandler from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec -from copy import deepcopy -from typing import Dict, List -from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding + +from .operator_handler import OperatorHandler __all__ = ['WhereHandler'] @@ -94,7 +101,7 @@ class WhereHandler(OperatorHandler): # compute the resharding cost _, _, total_resharding_cost = shape_consistency_manager.shape_consistency( input_sharding_spec, input_spec) - + total_resharding_cost = total_resharding_cost['total'] # we need multiply the size of elem dtype to get correct communication cost resharding_cost = total_resharding_cost * size_per_elem_bytes resharding_costs[input_node].append(resharding_cost)