[hotfix] solver bug caused by dict type comm cost (#1686)

pull/1690/head
YuliangLiu0306 2022-10-11 17:57:03 +08:00 committed by GitHub
parent 3dd6994427
commit 6878e42248
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 3 additions and 1 deletions

View File

@ -16,7 +16,6 @@ ELEMENTWISE_FUNC_OP = [
torch.multiply,
torch.nn.functional.relu,
torch.nn.functional.dropout,
torch.flatten,
# softmax should not be here
torch.nn.functional.softmax
]

View File

@ -69,6 +69,7 @@ class ReshapeHandler(OperatorHandler):
shape_consistency_manager = ShapeConsistencyManager()
_, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec,
replicate_input_sharding_spec)
communication_cost = communication_cost["total"]
# generate resharding cost
resharding_costs = self._generate_resharding_costs([input_sharding_spec])

View File

@ -319,6 +319,8 @@ class Solver:
obj = 0
for i in range(node_nums):
assert len(s[i]) == len(c[i])
assert len(s[i]) == len(d[i])
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
#############################################