|
|
|
@ -32,7 +32,7 @@ class Solver:
|
|
|
|
|
graph: Graph, |
|
|
|
|
strategies_constructor: StrategiesConstructor, |
|
|
|
|
cost_graph: CostGraph, |
|
|
|
|
graph_analyser: GraphAnalyser, |
|
|
|
|
graph_analyser: GraphAnalyser = None, |
|
|
|
|
memory_budget: float = -1.0, |
|
|
|
|
solution_numbers: int = 1, |
|
|
|
|
forward_only: bool = False, |
|
|
|
@ -63,7 +63,10 @@ class Solver:
|
|
|
|
|
self.memory_increasing_coefficient = memory_increasing_coefficient |
|
|
|
|
else: |
|
|
|
|
self.memory_increasing_coefficient = 1 |
|
|
|
|
self.liveness_list = self.graph_analyser.liveness_analysis() |
|
|
|
|
# temporarily we use all nodes as liveness list, we count the backward memory cost together with |
|
|
|
|
# forward memory cost into the node memory cost, and no activation checkpoint is used in this phase. |
|
|
|
|
# self.liveness_list = self.graph_analyser.liveness_analysis() |
|
|
|
|
self.liveness_list = self.nodes |
|
|
|
|
self.node_index_dict = self._generate_node_index_dict() |
|
|
|
|
# The last solution vector of auto sharding. |
|
|
|
|
self.last_s_val = None |
|
|
|
@ -140,7 +143,7 @@ class Solver:
|
|
|
|
|
liveness_set = self.liveness_list |
|
|
|
|
|
|
|
|
|
# omit alias_set now |
|
|
|
|
alias_set = None |
|
|
|
|
alias_set = self.strategies_constructor.alias_set |
|
|
|
|
alias_convert_costs = None |
|
|
|
|
|
|
|
|
|
# prepare compute_costs, communication_costs and memory_costs |
|
|
|
@ -230,6 +233,7 @@ class Solver:
|
|
|
|
|
|
|
|
|
|
# 0. Unpack flatten numpy arrays |
|
|
|
|
s_follow = following_nodes |
|
|
|
|
s_alias = alias_set |
|
|
|
|
|
|
|
|
|
E = edge_pairs.reshape((-1, 2)) # noqa |
|
|
|
|
r = [] |
|
|
|
@ -294,8 +298,11 @@ class Solver:
|
|
|
|
|
if strategies_len[i] == 1: |
|
|
|
|
s.append([1]) |
|
|
|
|
else: |
|
|
|
|
num_nodes += 1 |
|
|
|
|
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary")) |
|
|
|
|
if i not in s_alias: |
|
|
|
|
num_nodes += 1 |
|
|
|
|
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary")) |
|
|
|
|
else: |
|
|
|
|
s.append(s[s_alias[i]]) |
|
|
|
|
else: |
|
|
|
|
if s_follow[i] < len(s): |
|
|
|
|
s.append(s[s_follow[i]]) |
|
|
|
@ -311,15 +318,20 @@ class Solver:
|
|
|
|
|
############################# |
|
|
|
|
e = [] |
|
|
|
|
num_edges = 0 |
|
|
|
|
map_edge_to_idx = {} |
|
|
|
|
for (idx, (i, j)) in enumerate(E): |
|
|
|
|
if len(s[i]) == 1: |
|
|
|
|
e.append(s[j]) |
|
|
|
|
elif len(s[j]) == 1: |
|
|
|
|
e.append(s[i]) |
|
|
|
|
else: |
|
|
|
|
num_edges += 1 |
|
|
|
|
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary")) |
|
|
|
|
if i in s_alias and j in s_alias and (s_alias[i], s_alias[j]) in map_edge_to_idx: |
|
|
|
|
e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]]) |
|
|
|
|
else: |
|
|
|
|
num_edges += 1 |
|
|
|
|
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary")) |
|
|
|
|
assert len(e[idx]) == len(r[idx]) |
|
|
|
|
map_edge_to_idx[(i, j)] = idx |
|
|
|
|
for element in s: |
|
|
|
|
assert len(element) > 0 |
|
|
|
|
# 2. Set initial value |
|
|
|
@ -371,13 +383,12 @@ class Solver:
|
|
|
|
|
# compute memory consumption with liveness set # |
|
|
|
|
################################################# |
|
|
|
|
if memory_budget > 0: |
|
|
|
|
for liveness_stage in liveness_set: |
|
|
|
|
mem = 0 |
|
|
|
|
for live_variable in liveness_stage.unique_live_vars: |
|
|
|
|
if live_variable.node not in self.node_index_dict: |
|
|
|
|
continue |
|
|
|
|
node_index = self.node_index_dict[live_variable.node] |
|
|
|
|
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index]))) |
|
|
|
|
mem = 0 |
|
|
|
|
for node in liveness_set: |
|
|
|
|
if node not in self.node_index_dict: |
|
|
|
|
continue |
|
|
|
|
node_index = self.node_index_dict[node] |
|
|
|
|
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index]))) |
|
|
|
|
prob += mem <= memory_budget |
|
|
|
|
|
|
|
|
|
# (d). specified by `cat="Binary"` |
|
|
|
|