mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] solver bug caused by dict type comm cost (#1686)
parent
3dd6994427
commit
6878e42248
|
@ -16,7 +16,6 @@ ELEMENTWISE_FUNC_OP = [
|
|||
torch.multiply,
|
||||
torch.nn.functional.relu,
|
||||
torch.nn.functional.dropout,
|
||||
torch.flatten,
|
||||
# softmax should not be here
|
||||
torch.nn.functional.softmax
|
||||
]
|
||||
|
|
|
@ -69,6 +69,7 @@ class ReshapeHandler(OperatorHandler):
|
|||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
_, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec,
|
||||
replicate_input_sharding_spec)
|
||||
communication_cost = communication_cost["total"]
|
||||
|
||||
# generate resharding cost
|
||||
resharding_costs = self._generate_resharding_costs([input_sharding_spec])
|
||||
|
|
|
@ -319,6 +319,8 @@ class Solver:
|
|||
obj = 0
|
||||
for i in range(node_nums):
|
||||
assert len(s[i]) == len(c[i])
|
||||
assert len(s[i]) == len(d[i])
|
||||
|
||||
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
|
||||
|
||||
#############################################
|
||||
|
|
Loading…
Reference in New Issue