mirror of https://github.com/hpcaitech/ColossalAI
remove flow tracer
parent
4d89525fc2
commit
4f5e105af3
171
chunk_codegen.py
171
chunk_codegen.py
|
@ -67,7 +67,7 @@ def _is_non_compute_node_except_placeholder_output(node):
|
|||
class IndexTracer(object):
|
||||
def __init__(self, gm) -> None:
|
||||
self.gm = gm
|
||||
self.nodes_list = list(gm.graph.nodes)
|
||||
self.node_list = list(gm.graph.nodes)
|
||||
self.idx_trace_list = self._init_idx_trace_list()
|
||||
self.idx_trace_equal = []
|
||||
self.idx_view_list = []
|
||||
|
@ -75,7 +75,7 @@ class IndexTracer(object):
|
|||
|
||||
def _init_idx_trace_list(self):
|
||||
idx_trace_list = []
|
||||
for n in self.nodes_list:
|
||||
for n in self.node_list:
|
||||
if _get_node_shape(n) != None:
|
||||
cur_trace = {
|
||||
"idx": [None for _ in range(len(_get_node_shape(n)))],
|
||||
|
@ -136,7 +136,7 @@ class IndexTracer(object):
|
|||
node_from_trace = self._find_trace_from_node(node_from)
|
||||
node_to_dim = self._transform_index(node_to, node_to_dim)
|
||||
node_to_trace = self._find_trace_from_node(node_to)
|
||||
node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list)
|
||||
node_from_idx = _find_idx_by_name(node_from.name, self.node_list)
|
||||
if init:
|
||||
node_to_trace["source"][node_to_dim] = {}
|
||||
# add dim to cur new source
|
||||
|
@ -210,7 +210,7 @@ class IndexTracer(object):
|
|||
idx (list): idx of the node
|
||||
compute (list): computed idx of the node.
|
||||
"""
|
||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||
node_idx = _find_idx_by_name(node.name, self.node_list)
|
||||
node_dict = self.idx_trace_list[node_idx]
|
||||
return node_dict
|
||||
|
||||
|
@ -224,7 +224,7 @@ class IndexTracer(object):
|
|||
idx (list): idx of the node
|
||||
compute (list): computed idx of the node.
|
||||
"""
|
||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||
node_idx = _find_idx_by_name(node.name, self.node_list)
|
||||
node_dict = self.idx_trace_list[node_idx]
|
||||
return node_dict["source"]
|
||||
|
||||
|
@ -237,7 +237,7 @@ class IndexTracer(object):
|
|||
Returns:
|
||||
idx (list): idx of the node
|
||||
"""
|
||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||
node_idx = _find_idx_by_name(node.name, self.node_list)
|
||||
return self.idx_trace_list[node_idx]["idx"]
|
||||
|
||||
def _find_compute_trace_from_node(self, node):
|
||||
|
@ -249,7 +249,7 @@ class IndexTracer(object):
|
|||
Returns:
|
||||
compute (list): computed idx of the node.
|
||||
"""
|
||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||
node_idx = _find_idx_by_name(node.name, self.node_list)
|
||||
return self.idx_trace_list[node_idx]["compute"]
|
||||
|
||||
def _assign_index_as_input(self, node, node_idx, input_node=None):
|
||||
|
@ -262,7 +262,7 @@ class IndexTracer(object):
|
|||
"""
|
||||
if input_node == None:
|
||||
input_node = node.args[0]
|
||||
input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list)
|
||||
input_node_idx = _find_idx_by_name(input_node.name, self.node_list)
|
||||
input_node_idx_trace = self.idx_trace_list[input_node_idx]["idx"]
|
||||
|
||||
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
||||
|
@ -591,7 +591,7 @@ class IndexTracer(object):
|
|||
]
|
||||
|
||||
def trace_index(self):
|
||||
for idx, node in enumerate(self.nodes_list):
|
||||
for idx, node in enumerate(self.node_list):
|
||||
if node.op == "placeholder":
|
||||
self._assign_all_index(node, idx)
|
||||
elif node.op == "call_method":
|
||||
|
@ -655,7 +655,7 @@ class IndexTracer(object):
|
|||
Returns:
|
||||
bool: True if check pass
|
||||
"""
|
||||
start_node_idx = _find_idx_by_name(start_node.name, self.nodes_list)
|
||||
start_node_idx = _find_idx_by_name(start_node.name, self.node_list)
|
||||
end_node_trace = self._find_trace_from_node(end_node)
|
||||
end_node_trace_source = end_node_trace["source"][end_dim]
|
||||
sorted_source = sorted(
|
||||
|
@ -690,14 +690,14 @@ class IndexTracer(object):
|
|||
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
||||
node_from_source = self._find_source_trace_from_node(node_from)
|
||||
dim_source = node_from_source[node_from_dim]
|
||||
node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list)
|
||||
node_to_idx = _find_idx_by_name(node_to.name, self.node_list)
|
||||
for k, v in dim_source.items():
|
||||
if k == node_to_idx:
|
||||
return v
|
||||
return None
|
||||
|
||||
def _find_inherit_dim(self, input_node, input_dim, node):
|
||||
input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list)
|
||||
input_node_idx = _find_idx_by_name(input_node.name, self.node_list)
|
||||
node_trace_source = self._find_source_trace_from_node(node)
|
||||
for node_dim in range(len(_get_node_shape(node))):
|
||||
if (
|
||||
|
@ -711,11 +711,11 @@ class IndexTracer(object):
|
|||
input_dim_after_node = {}
|
||||
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
|
||||
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
|
||||
inherit_dim = self._find_inherit_dim(input_node, v, self.nodes_list[k])
|
||||
inherit_dim = self._find_inherit_dim(input_node, v, self.node_list[k])
|
||||
if inherit_dim:
|
||||
input_dim_after_node[k] = inherit_dim
|
||||
|
||||
for node in self.nodes_list[
|
||||
for node in self.node_list[
|
||||
chunk_infos["region"][0] : chunk_infos["region"][1] + 1
|
||||
]:
|
||||
if _is_non_compute_node_except_placeholder(node):
|
||||
|
@ -746,124 +746,11 @@ class IndexTracer(object):
|
|||
else:
|
||||
return True
|
||||
|
||||
|
||||
class FlowTracer(object):
|
||||
def __init__(self, gm) -> None:
|
||||
self.gm = gm
|
||||
self.node_list = list(gm.graph.nodes)
|
||||
self.flow_trace = {}
|
||||
|
||||
def _add_trace(self, name):
|
||||
self.flow_trace[name] = []
|
||||
|
||||
def _add_node(self, trace_name, node):
|
||||
self.flow_trace[trace_name].append(
|
||||
{"node": node, "inside_depend": [], "outside_depend": []}
|
||||
)
|
||||
|
||||
def _add_inside_depend(self, flow_name, node, inside_depend_node):
|
||||
for i in self.flow_trace[flow_name]:
|
||||
if i["node"] == node:
|
||||
i["inside_depend"].append(inside_depend_node)
|
||||
return
|
||||
raise RuntimeError("node not found")
|
||||
|
||||
def _add_outside_depend(
|
||||
self, flow_name, node, outside_depend_node, outside_depend_trace
|
||||
):
|
||||
for i in self.flow_trace[flow_name]:
|
||||
if i["node"] == node:
|
||||
i["outside_depend"].append({outside_depend_trace: outside_depend_node})
|
||||
return
|
||||
raise RuntimeError("node not found")
|
||||
|
||||
def _init_trace(self):
|
||||
for i in self.node_list:
|
||||
if i.op == "placeholder":
|
||||
self._add_trace(i.name)
|
||||
self._add_node(i.name, i)
|
||||
|
||||
def _find_flow_for_node(self, node):
|
||||
if type(self.node_list[0]) != type(node):
|
||||
return None
|
||||
if _is_non_compute_node_except_placeholder(node):
|
||||
return None
|
||||
for name, trace in self.flow_trace.items():
|
||||
for i in trace:
|
||||
if node == i["node"]:
|
||||
return name
|
||||
if any(i in node.name for i in ["ones_like"]):
|
||||
self._add_trace(node.name)
|
||||
self._add_node(node.name, node)
|
||||
return node.name
|
||||
raise RuntimeError("node not found")
|
||||
|
||||
def _find_first_valid_flow(self, flow):
|
||||
for i in flow:
|
||||
if i is not None:
|
||||
return i
|
||||
raise RuntimeError("invalid flow")
|
||||
|
||||
def find_node_flow(self, node):
|
||||
for name, trace in self.flow_trace.items():
|
||||
for i in trace:
|
||||
if node == i["node"]:
|
||||
return name, i
|
||||
raise RuntimeError("invalid node")
|
||||
|
||||
def _get_flow_mix_node(self, node):
|
||||
if _is_non_compute_node(node):
|
||||
return None
|
||||
_, node_trace = self.find_node_flow(node)
|
||||
if len(node_trace["outside_depend"]) == 0:
|
||||
return None
|
||||
elif len(node_trace["outside_depend"]) > 1:
|
||||
raise NotImplementedError
|
||||
vars = list(node_trace["outside_depend"][0].values())[0]
|
||||
return vars
|
||||
|
||||
def _get_same_flow_node(self, node_list, node):
|
||||
name, _ = self.find_node_flow(node)
|
||||
result = []
|
||||
for i in self.flow_trace[name]:
|
||||
if i["node"] in node_list:
|
||||
result.append(i["node"])
|
||||
return result
|
||||
|
||||
def trace_flow(self):
|
||||
# init trace
|
||||
self._init_trace()
|
||||
|
||||
for node in self.node_list:
|
||||
# skip if non compute node
|
||||
if all(
|
||||
type(arg) != type(node) or _is_non_compute_node_except_placeholder(arg)
|
||||
for arg in node.args
|
||||
) or _is_non_compute_node(node):
|
||||
continue
|
||||
|
||||
node_input_flows = [self._find_flow_for_node(arg) for arg in node.args]
|
||||
|
||||
node_domin_flow = self._find_first_valid_flow(node_input_flows)
|
||||
self._add_node(node_domin_flow, node)
|
||||
for node_input_flow, arg in zip(node_input_flows, node.args):
|
||||
if node_input_flow is None:
|
||||
continue
|
||||
elif node_input_flow == node_domin_flow:
|
||||
self._add_inside_depend(node_domin_flow, node, arg)
|
||||
else:
|
||||
self._add_outside_depend(
|
||||
node_domin_flow, node, arg, node_input_flow
|
||||
)
|
||||
return self.flow_trace
|
||||
|
||||
def _assgin_single_node_flow(
|
||||
self,
|
||||
arg_node,
|
||||
start_idx,
|
||||
end_idx,
|
||||
inputs,
|
||||
index_tracer,
|
||||
cur_node_dim,
|
||||
cur_node_compute,
|
||||
cur_node_source,
|
||||
|
@ -871,7 +758,7 @@ class FlowTracer(object):
|
|||
all_node_info,
|
||||
next_node_list,
|
||||
):
|
||||
arg_idx = _find_idx_by_name(arg_node.name, index_tracer.nodes_list)
|
||||
arg_idx = _find_idx_by_name(arg_node.name, self.node_list)
|
||||
# arg in chunk range or be inputs
|
||||
if not (start_idx <= arg_idx < end_idx):
|
||||
return True
|
||||
|
@ -911,7 +798,7 @@ class FlowTracer(object):
|
|||
return True
|
||||
|
||||
def flow_search(
|
||||
self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer
|
||||
self, start_idx, start_dim, end_idx, end_dim
|
||||
):
|
||||
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
|
||||
self.node_list[start_idx : end_idx + 1]
|
||||
|
@ -920,7 +807,7 @@ class FlowTracer(object):
|
|||
if len(outputs) > 1:
|
||||
return None
|
||||
|
||||
cur_node_list = [index_tracer.nodes_list[end_idx]] # start from the last node
|
||||
cur_node_list = [self.node_list[end_idx]] # start from the last node
|
||||
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
||||
|
||||
while len(cur_node_list) > 0:
|
||||
|
@ -930,12 +817,12 @@ class FlowTracer(object):
|
|||
# get cur node info
|
||||
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
|
||||
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
||||
cur_node_idx = _find_idx_by_name(cur_node.name, index_tracer.nodes_list)
|
||||
cur_node_idx = _find_idx_by_name(cur_node.name, self.node_list)
|
||||
if cur_node_chunk_dim:
|
||||
cur_node_compute = index_tracer._find_compute_trace_from_node(
|
||||
cur_node_compute = self._find_compute_trace_from_node(
|
||||
cur_node
|
||||
)
|
||||
cur_node_source = index_tracer._find_source_trace_from_node(
|
||||
cur_node_source = self._find_source_trace_from_node(
|
||||
cur_node
|
||||
)
|
||||
else:
|
||||
|
@ -953,8 +840,6 @@ class FlowTracer(object):
|
|||
arg,
|
||||
start_idx,
|
||||
end_idx,
|
||||
inputs,
|
||||
index_tracer,
|
||||
cur_node_chunk_dim,
|
||||
cur_node_compute,
|
||||
cur_node_source,
|
||||
|
@ -970,7 +855,7 @@ class FlowTracer(object):
|
|||
for arg in arg_list:
|
||||
if not (
|
||||
start_idx
|
||||
<= _find_idx_by_name(arg.name, index_tracer.nodes_list)
|
||||
<= _find_idx_by_name(arg.name, self.node_list)
|
||||
< end_idx
|
||||
):
|
||||
continue
|
||||
|
@ -1029,7 +914,7 @@ class FlowTracer(object):
|
|||
if node_info["chunk_dim"] is None:
|
||||
maybe_prepose_nodes.append(node)
|
||||
maybe_prepose_nodes.sort(
|
||||
key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list),
|
||||
key=lambda x: _find_idx_by_name(x.name, self.node_list),
|
||||
reverse=True,
|
||||
) # from last node to first node
|
||||
prepose_nodes = []
|
||||
|
@ -1081,7 +966,7 @@ class FlowTracer(object):
|
|||
maybe_prepose_nodes.remove(n)
|
||||
# sort by index
|
||||
prepose_nodes.sort(
|
||||
key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list)
|
||||
key=lambda x: _find_idx_by_name(x.name, self.node_list)
|
||||
)
|
||||
chunk_info["args"]["prepose_nodes"] = prepose_nodes
|
||||
|
||||
|
@ -1226,9 +1111,9 @@ class MemoryEstimator(object):
|
|||
for k, v in input_node_dim.items():
|
||||
# TODO: inherit dim should be list too, int now
|
||||
inherit_dim = self.index_tracer._find_inherit_dim(
|
||||
input_node, v, self.index_tracer.nodes_list[k]
|
||||
input_node, v, self.index_tracer.node_list[k]
|
||||
)
|
||||
if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list):
|
||||
if k == _find_idx_by_name(node.name, self.index_tracer.node_list):
|
||||
chunk_ratio = float(chunk_size) / node_shape[inherit_dim]
|
||||
return chunk_ratio
|
||||
for dim, source in enumerate(node_source):
|
||||
|
@ -1412,8 +1297,6 @@ class ChunkRegionSearch(object):
|
|||
self.node_list = list(gm.graph.nodes)
|
||||
self.index_tracer = IndexTracer(gm)
|
||||
self.index_tracer.trace_index()
|
||||
self.flow_tracer = FlowTracer(gm)
|
||||
self.flow_tracer.trace_flow()
|
||||
self.memory_estimator = MemoryEstimator(self.index_tracer)
|
||||
|
||||
def _find_peak_node(self, mem_peak):
|
||||
|
@ -1517,8 +1400,8 @@ class ChunkRegionSearch(object):
|
|||
):
|
||||
continue
|
||||
# flow search
|
||||
chunk_info = self.flow_tracer.flow_search(
|
||||
start_idx, start_dim, end_idx, end_dim, self.index_tracer
|
||||
chunk_info = self.index_tracer.flow_search(
|
||||
start_idx, start_dim, end_idx, end_dim
|
||||
)
|
||||
if chunk_info is None:
|
||||
continue
|
||||
|
|
Loading…
Reference in New Issue