mirror of https://github.com/hpcaitech/ColossalAI
code style
parent
4f5e105af3
commit
fa5e6fbf96
|
@ -65,9 +65,8 @@ def _is_non_compute_node_except_placeholder_output(node):
|
|||
|
||||
|
||||
class IndexTracer(object):
|
||||
def __init__(self, gm) -> None:
|
||||
self.gm = gm
|
||||
self.node_list = list(gm.graph.nodes)
|
||||
def __init__(self, node_list) -> None:
|
||||
self.node_list = node_list
|
||||
self.idx_trace_list = self._init_idx_trace_list()
|
||||
self.idx_trace_equal = []
|
||||
self.idx_view_list = []
|
||||
|
@ -797,9 +796,7 @@ class IndexTracer(object):
|
|||
next_node_list.append(arg_node)
|
||||
return True
|
||||
|
||||
def flow_search(
|
||||
self, start_idx, start_dim, end_idx, end_dim
|
||||
):
|
||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
|
||||
self.node_list[start_idx : end_idx + 1]
|
||||
)
|
||||
|
@ -819,12 +816,8 @@ class IndexTracer(object):
|
|||
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
||||
cur_node_idx = _find_idx_by_name(cur_node.name, self.node_list)
|
||||
if cur_node_chunk_dim:
|
||||
cur_node_compute = self._find_compute_trace_from_node(
|
||||
cur_node
|
||||
)
|
||||
cur_node_source = self._find_source_trace_from_node(
|
||||
cur_node
|
||||
)
|
||||
cur_node_compute = self._find_compute_trace_from_node(cur_node)
|
||||
cur_node_source = self._find_source_trace_from_node(cur_node)
|
||||
else:
|
||||
cur_node_compute = cur_node_source = None
|
||||
|
||||
|
@ -965,9 +958,7 @@ class IndexTracer(object):
|
|||
if n in maybe_prepose_nodes:
|
||||
maybe_prepose_nodes.remove(n)
|
||||
# sort by index
|
||||
prepose_nodes.sort(
|
||||
key=lambda x: _find_idx_by_name(x.name, self.node_list)
|
||||
)
|
||||
prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, self.node_list))
|
||||
chunk_info["args"]["prepose_nodes"] = prepose_nodes
|
||||
|
||||
# we need to log input nodes to avoid deleteing them in the loop
|
||||
|
@ -1295,7 +1286,9 @@ class ChunkRegionSearch(object):
|
|||
def __init__(self, gm) -> None:
|
||||
self.gm = gm
|
||||
self.node_list = list(gm.graph.nodes)
|
||||
self.index_tracer = IndexTracer(gm)
|
||||
self.index_tracer = IndexTracer(
|
||||
self.node_list
|
||||
) # node list shared in index tracer
|
||||
self.index_tracer.trace_index()
|
||||
self.memory_estimator = MemoryEstimator(self.index_tracer)
|
||||
|
||||
|
|
Loading…
Reference in New Issue