mirror of https://github.com/hpcaitech/ColossalAI
parent
24e84eba60
commit
d373e67b99
|
@ -1,14 +1,21 @@
|
||||||
import operator
|
import operator
|
||||||
from functools import reduce
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from copy import deepcopy
|
||||||
|
from functools import reduce
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
|
||||||
|
enumerate_all_possible_1d_sharding,
|
||||||
|
enumerate_all_possible_2d_sharding,
|
||||||
|
exception_handler,
|
||||||
|
)
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||||
from .operator_handler import OperatorHandler
|
|
||||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
from copy import deepcopy
|
|
||||||
from typing import Dict, List
|
from .operator_handler import OperatorHandler
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
|
|
||||||
|
|
||||||
__all__ = ['WhereHandler']
|
__all__ = ['WhereHandler']
|
||||||
|
|
||||||
|
@ -94,7 +101,7 @@ class WhereHandler(OperatorHandler):
|
||||||
# compute the resharding cost
|
# compute the resharding cost
|
||||||
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
|
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
|
||||||
input_sharding_spec, input_spec)
|
input_sharding_spec, input_spec)
|
||||||
|
total_resharding_cost = total_resharding_cost['total']
|
||||||
# we need multiply the size of elem dtype to get correct communication cost
|
# we need multiply the size of elem dtype to get correct communication cost
|
||||||
resharding_cost = total_resharding_cost * size_per_elem_bytes
|
resharding_cost = total_resharding_cost * size_per_elem_bytes
|
||||||
resharding_costs[input_node].append(resharding_cost)
|
resharding_costs[input_node].append(resharding_cost)
|
||||||
|
|
Loading…
Reference in New Issue