[hotfix] resharding cost issue (#1742)

pull/1728/head v0.1.11rc1
YuliangLiu0306 2 years ago committed by GitHub
parent 24e84eba60
commit d373e67b99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,14 +1,21 @@
import operator import operator
from functools import reduce
import warnings import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
exception_handler,
)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
__all__ = ['WhereHandler'] __all__ = ['WhereHandler']
@ -94,7 +101,7 @@ class WhereHandler(OperatorHandler):
# compute the resharding cost # compute the resharding cost
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency( _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
input_sharding_spec, input_spec) input_sharding_spec, input_spec)
total_resharding_cost = total_resharding_cost['total']
# we need to multiply by the size of the element dtype to get the correct communication cost # we need to multiply by the size of the element dtype to get the correct communication cost
resharding_cost = total_resharding_cost * size_per_elem_bytes resharding_cost = total_resharding_cost * size_per_elem_bytes
resharding_costs[input_node].append(resharding_cost) resharding_costs[input_node].append(resharding_cost)

Loading…
Cancel
Save