finish node reorder

pull/2364/head
oahzxl 2022-12-23 17:25:36 +08:00
parent 884a228ea6
commit 51ef8384c1
1 changed file with 15 additions and 16 deletions

View File

@@ -1238,7 +1238,7 @@ class MemoryEstimator(object):
def estimate_chunk_inference_mem(
self,
gm: torch.fx.GraphModule,
node_list,
chunk_infos=None,
):
act_memory = 0.0
@@ -1247,7 +1247,6 @@ class MemoryEstimator(object):
active_node_list = []
active_node_list_log = []
not_contiguous_list = []
node_list = list(gm.graph.nodes)
user_to_last_uses = self._get_last_usr(node_list)
user_to_last_uses_no_free_var = self._get_last_usr(node_list)
_delete_free_var_from_last_use(user_to_last_uses_no_free_var)
@@ -1281,7 +1280,6 @@ class MemoryEstimator(object):
) / (1024**2)
# determine chunk ratio for current node
# TODO: adapt to prepose node memory
if chunk_within:
chunk_ratio = self._get_chunk_ratio(
node,
@@ -1371,10 +1369,7 @@ class MemoryEstimator(object):
class ChunkRegionSearch(object):
def __init__(self, gm) -> None:
self.gm = gm
self.node_list = list(gm.graph.nodes)
self.index_tracer = IndexTracer(
self.node_list
) # node list shared in index tracer
self.index_tracer = IndexTracer(list(gm.graph.nodes))
self.index_tracer.trace_index()
self.memory_estimator = MemoryEstimator(self.index_tracer)
@@ -1385,7 +1380,7 @@ class ChunkRegionSearch(object):
def _get_free_var(self):
free_var_idx = []
for idx, n in enumerate(self.node_list):
for idx, n in enumerate(self.index_tracer.node_list):
if n.op == "placeholder":
free_var_idx.append(idx)
return free_var_idx
@@ -1455,13 +1450,13 @@ class ChunkRegionSearch(object):
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
start_traces = input_trace[start_idx]
end_trace = output_trace[end_idx]
end_node = self.node_list[end_idx]
end_node = self.index_tracer.node_list[end_idx]
chunk_infos = []
for end_dim, end_trace_idx in enumerate(end_trace["idx"]):
for end_dim, _ in enumerate(end_trace["idx"]):
if len(start_traces) > 1:
continue
for start_node, start_trace in start_traces.items():
for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
for start_dim, _ in enumerate(start_trace["idx"]):
# dim size cannot be 1
if (
_get_node_shape(end_node)[end_dim] == 1
@@ -1494,7 +1489,7 @@ class ChunkRegionSearch(object):
possible_chunk_region = []
output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.node_list):
for _, n in enumerate(self.index_tracer.node_list):
cur_trace = {}
for arg in n.args:
if type(arg) == type(n) and not _is_non_compute_node_except_placeholder(
@@ -1507,8 +1502,8 @@ class ChunkRegionSearch(object):
for end_idx in range(peak_node, max_chunk_region[1] + 1):
# skip non compute nodes
if _is_non_compute_node(
self.node_list[start_idx]
) or _is_non_compute_node(self.node_list[end_idx]):
self.index_tracer.node_list[start_idx]
) or _is_non_compute_node(self.index_tracer.node_list[end_idx]):
continue
# select free dim
@@ -1577,7 +1572,9 @@ class ChunkRegionSearch(object):
init_mem_peak,
_,
active_node,
) = self.memory_estimator.estimate_chunk_inference_mem(self.gm)
) = self.memory_estimator.estimate_chunk_inference_mem(
self.index_tracer.node_list
)
mem_peak = init_mem_peak
while True:
@@ -1590,7 +1587,9 @@ class ChunkRegionSearch(object):
mem_peak,
_,
active_node,
) = self.memory_estimator.estimate_chunk_inference_mem(self.gm, chunk_infos)
) = self.memory_estimator.estimate_chunk_inference_mem(
self.index_tracer.node_list, chunk_infos
)
if self._stop_search(init_mem_peak, mem_peak):
break
return chunk_infos