mirror of https://github.com/hpcaitech/ColossalAI

commit 51ef8384c1 (parent 884a228ea6)
finish node reorder
@@ -1238,7 +1238,7 @@ class MemoryEstimator(object):
 
     def estimate_chunk_inference_mem(
         self,
-        gm: torch.fx.GraphModule,
+        node_list,
         chunk_infos=None,
     ):
         act_memory = 0.0
@@ -1247,7 +1247,6 @@ class MemoryEstimator(object):
         active_node_list = []
         active_node_list_log = []
         not_contiguous_list = []
-        node_list = list(gm.graph.nodes)
         user_to_last_uses = self._get_last_usr(node_list)
         user_to_last_uses_no_free_var = self._get_last_usr(node_list)
         _delete_free_var_from_last_use(user_to_last_uses_no_free_var)
@@ -1281,7 +1280,6 @@ class MemoryEstimator(object):
             ) / (1024**2)
 
             # determine chunk ratio for current node
-            # TODO: adapt to prepose node memory
             if chunk_within:
                 chunk_ratio = self._get_chunk_ratio(
                     node,
@@ -1371,10 +1369,7 @@ class MemoryEstimator(object):
 class ChunkRegionSearch(object):
     def __init__(self, gm) -> None:
         self.gm = gm
-        self.node_list = list(gm.graph.nodes)
-        self.index_tracer = IndexTracer(
-            self.node_list
-        )  # node list shared in index tracer
+        self.index_tracer = IndexTracer(list(gm.graph.nodes))
         self.index_tracer.trace_index()
         self.memory_estimator = MemoryEstimator(self.index_tracer)
 
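Note: after this commit the node list is built exactly once, inside IndexTracer, and the other components read it through the tracer instead of keeping their own copies. A minimal, hypothetical sketch of that ownership pattern, with placeholder class bodies rather than the real ColossalAI implementations:

    import torch.fx

    class IndexTracer:
        def __init__(self, node_list):
            # the tracer owns the single shared node list
            self.node_list = node_list

    class MemoryEstimator:
        def __init__(self, index_tracer):
            # the estimator no longer receives a GraphModule; callers pass the
            # shared node list to estimate_chunk_inference_mem explicitly
            self.index_tracer = index_tracer

    class ChunkRegionSearch:
        def __init__(self, gm: torch.fx.GraphModule) -> None:
            self.gm = gm
            self.index_tracer = IndexTracer(list(gm.graph.nodes))
            self.memory_estimator = MemoryEstimator(self.index_tracer)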
@@ -1385,7 +1380,7 @@ class ChunkRegionSearch(object):
 
     def _get_free_var(self):
         free_var_idx = []
-        for idx, n in enumerate(self.node_list):
+        for idx, n in enumerate(self.index_tracer.node_list):
             if n.op == "placeholder":
                 free_var_idx.append(idx)
         return free_var_idx
@@ -1455,13 +1450,13 @@ class ChunkRegionSearch(object):
     def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
         start_traces = input_trace[start_idx]
         end_trace = output_trace[end_idx]
-        end_node = self.node_list[end_idx]
+        end_node = self.index_tracer.node_list[end_idx]
         chunk_infos = []
-        for end_dim, end_trace_idx in enumerate(end_trace["idx"]):
+        for end_dim, _ in enumerate(end_trace["idx"]):
             if len(start_traces) > 1:
                 continue
             for start_node, start_trace in start_traces.items():
-                for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
+                for start_dim, _ in enumerate(start_trace["idx"]):
                     # dim size cannot be 1
                     if (
                         _get_node_shape(end_node)[end_dim] == 1
@@ -1494,7 +1489,7 @@ class ChunkRegionSearch(object):
         possible_chunk_region = []
         output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
         input_trace = []  # trace of a node's input nodes
-        for _, n in enumerate(self.node_list):
+        for _, n in enumerate(self.index_tracer.node_list):
             cur_trace = {}
             for arg in n.args:
                 if type(arg) == type(n) and not _is_non_compute_node_except_placeholder(
@@ -1507,8 +1502,8 @@ class ChunkRegionSearch(object):
         for end_idx in range(peak_node, max_chunk_region[1] + 1):
             # skip non compute nodes
             if _is_non_compute_node(
-                self.node_list[start_idx]
-            ) or _is_non_compute_node(self.node_list[end_idx]):
+                self.index_tracer.node_list[start_idx]
+            ) or _is_non_compute_node(self.index_tracer.node_list[end_idx]):
                 continue
 
             # select free dim
@@ -1577,7 +1572,9 @@ class ChunkRegionSearch(object):
             init_mem_peak,
             _,
             active_node,
-        ) = self.memory_estimator.estimate_chunk_inference_mem(self.gm)
+        ) = self.memory_estimator.estimate_chunk_inference_mem(
+            self.index_tracer.node_list
+        )
         mem_peak = init_mem_peak
 
         while True:
@@ -1590,7 +1587,9 @@ class ChunkRegionSearch(object):
                 mem_peak,
                 _,
                 active_node,
-            ) = self.memory_estimator.estimate_chunk_inference_mem(self.gm, chunk_infos)
+            ) = self.memory_estimator.estimate_chunk_inference_mem(
+                self.index_tracer.node_list, chunk_infos
+            )
             if self._stop_search(init_mem_peak, mem_peak):
                 break
         return chunk_infos
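With the signature change above, callers pass the shared node list (and optionally chunk_infos) instead of the GraphModule. A hedged usage sketch, assuming the three-value unpacking shown in this diff and a ChunkRegionSearch built from an already-traced torch.fx GraphModule:

    chunk_search = ChunkRegionSearch(gm)  # gm: a traced torch.fx.GraphModule (assumed available)
    mem_peak, _, active_node = chunk_search.memory_estimator.estimate_chunk_inference_mem(
        chunk_search.index_tracer.node_list, chunk_infos=None
    )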