
[hotfix] resharding cost issue (#1742)

YuliangLiu0306 authored 2 years ago, committed by GitHub
commit d373e67b99
1 changed file: colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py (19 lines changed)
@@ -1,14 +1,21 @@
 import operator
-from functools import reduce
 import warnings
+from copy import deepcopy
+from functools import reduce
+from typing import Dict, List
+
 import torch
+
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+    exception_handler,
+)
 from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-from .operator_handler import OperatorHandler
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
-from copy import deepcopy
-from typing import Dict, List
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
-from .operator_handler import OperatorHandler
+
+from .operator_handler import OperatorHandler

 __all__ = ['WhereHandler']
@@ -94,7 +101,7 @@ class WhereHandler(OperatorHandler):
             # compute the resharding cost
             _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
                 input_sharding_spec, input_spec)
+            total_resharding_cost = total_resharding_cost['total']
             # we need multiply the size of elem dtype to get correct communication cost
             resharding_cost = total_resharding_cost * size_per_elem_bytes
             resharding_costs[input_node].append(resharding_cost)
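For context, here is a minimal sketch of what the fixed lines compute, assuming `shape_consistency()` returns its resharding cost as a dict with a `'total'` entry (the `'forward'`/`'backward'` keys and the `resharding_cost_in_bytes` helper below are illustrative assumptions, not code from this commit): the hotfix picks out the aggregate `'total'` entry instead of treating the returned dict as a number, then scales it by the element size in bytes so the cost reflects data volume moved rather than element count.

```python
# Sketch only: mirrors the two fixed lines of the diff, not the real handler.
# Assumption: shape_consistency() returns (transform_path, comm_actions,
# cost_dict), where cost_dict carries at least a 'total' entry.
import torch

def resharding_cost_in_bytes(cost_dict: dict, dtype: torch.dtype) -> float:
    # the hotfix: select the aggregate 'total' cost from the dict
    total_resharding_cost = cost_dict['total']
    # scale by bytes per element, since the raw cost counts elements moved
    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
    return total_resharding_cost * size_per_elem_bytes

# usage with a cost dict shaped like the one described above:
cost = resharding_cost_in_bytes({'forward': 10.0, 'backward': 10.0, 'total': 20.0},
                                torch.float32)
print(cost)    # 20.0 elements * 4 bytes/element = 80.0
```

Without the `['total']` lookup, the subsequent `total_resharding_cost * size_per_elem_bytes` would multiply a dict by an int, which is the resharding cost issue the commit title refers to.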
