code style

pull/2364/head
oahzxl 2023-01-06 17:31:59 +08:00
parent a6cdbf9161
commit c3a2bf48b4
5 changed files with 46 additions and 36 deletions

View File

@ -103,7 +103,7 @@ def emit_code_with_chunk(
nodes, nodes,
emit_node_func, emit_node_func,
delete_unused_value_func, delete_unused_value_func,
chunk_region_search: SearchChunk, search_chunk: SearchChunk,
chunk_infos, chunk_infos,
): ):
"""Emit code with nested activation checkpoint """Emit code with nested activation checkpoint
@ -133,7 +133,7 @@ def emit_code_with_chunk(
chunk_outputs = [i["outputs"][0] for i in chunk_infos] chunk_outputs = [i["outputs"][0] for i in chunk_infos]
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
node_list = chunk_region_search.reorder_graph.reorder_node_list(node_list) node_list = search_chunk.reorder_graph.reorder_node_list(node_list)
node_idx = 0 node_idx = 0
region_idx = 0 region_idx = 0
within_chunk_region = False within_chunk_region = False
@ -167,7 +167,7 @@ def emit_code_with_chunk(
) )
# ones like # ones like
if "ones_like" in node.name: if "ones_like" in node.name:
meta_node = chunk_region_search.trace_index.node_list[node_idx] meta_node = search_chunk.trace_index.node_list[node_idx]
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][ chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][
"chunk_dim" "chunk_dim"
] ]
@ -220,10 +220,8 @@ if CODEGEN_AVAILABLE:
self.max_memory = max_memory self.max_memory = max_memory
self.meta_node = list(meta_graph.graph.nodes) self.meta_node = list(meta_graph.graph.nodes)
# find the chunk regions # find the chunk regions
self.chunk_region_search = SearchChunk( self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem)
meta_graph, max_memory, print_mem self.chunk_infos = self.search_chunk.search_region()
)
self.chunk_infos = self.chunk_region_search.search_region()
def _gen_python_code( def _gen_python_code(
self, nodes, root_module: str, namespace: _Namespace self, nodes, root_module: str, namespace: _Namespace
@ -458,7 +456,7 @@ if CODEGEN_AVAILABLE:
nodes, nodes,
emit_node, emit_node,
delete_unused_values, delete_unused_values,
self.chunk_region_search, self.search_chunk,
self.chunk_infos, self.chunk_infos,
) )

View File

@ -3,28 +3,31 @@ from .utils import find_idx_by_name
class ReorderGraph(object): class ReorderGraph(object):
def __init__(self, index_tracer: TraceIndex) -> None: def __init__(self, trace_index: TraceIndex) -> None:
self.index_tracer = index_tracer self.trace_index = trace_index
self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))} self.all_reorder_map = {
i: i for i in range(len(self.trace_index.idx_trace_list))
}
def _get_reorder_map(self, chunk_info): def _get_reorder_map(self, chunk_info):
reorder_map = {i: i for i in range(len(self.index_tracer.node_list))} reorder_map = {i: i for i in range(len(self.trace_index.node_list))}
chunk_region_start = chunk_info["region"][0] chunk_region_start = chunk_info["region"][0]
chunk_region_end = chunk_info["region"][1] chunk_region_end = chunk_info["region"][1]
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
chunk_prepose_nodes_idx = [ chunk_prepose_nodes_idx = [
find_idx_by_name(i.name, self.index_tracer.node_list) for i in chunk_prepose_nodes find_idx_by_name(i.name, self.trace_index.node_list)
for i in chunk_prepose_nodes
] ]
# put prepose nodes ahead # put prepose nodes ahead
for idx, n in enumerate(chunk_prepose_nodes): for idx, n in enumerate(chunk_prepose_nodes):
n_idx = chunk_prepose_nodes_idx[idx] n_idx = chunk_prepose_nodes_idx[idx]
reorder_map[n_idx] = chunk_region_start + idx reorder_map[n_idx] = chunk_region_start + idx
# put other nodes after prepose nodes # put other nodes after prepose nodes
for n in self.index_tracer.node_list[chunk_region_start : chunk_region_end + 1]: for n in self.trace_index.node_list[chunk_region_start : chunk_region_end + 1]:
if n in chunk_prepose_nodes: if n in chunk_prepose_nodes:
continue continue
n_idx = find_idx_by_name(n.name, self.index_tracer.node_list) n_idx = find_idx_by_name(n.name, self.trace_index.node_list)
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
reorder_map[n_idx] = n_idx + pos reorder_map[n_idx] = n_idx + pos
@ -50,25 +53,25 @@ class ReorderGraph(object):
self.all_reorder_map[origin_idx] = reorder_map[map_idx] self.all_reorder_map[origin_idx] = reorder_map[map_idx]
def _reorder_self_node_list(self, reorder_map): def _reorder_self_node_list(self, reorder_map):
new_node_list = [None for _ in range(len(self.index_tracer.node_list))] new_node_list = [None for _ in range(len(self.trace_index.node_list))]
for old_idx, new_idx in reorder_map.items(): for old_idx, new_idx in reorder_map.items():
new_node_list[new_idx] = self.index_tracer.node_list[old_idx] new_node_list[new_idx] = self.trace_index.node_list[old_idx]
self.index_tracer.node_list = new_node_list self.trace_index.node_list = new_node_list
def _reorder_idx_trace(self, reorder_map): def _reorder_idx_trace(self, reorder_map):
# reorder list # reorder list
new_idx_trace_list = [None for _ in range(len(self.index_tracer.idx_trace_list))] new_idx_trace_list = [None for _ in range(len(self.trace_index.idx_trace_list))]
for old_idx, new_idx in reorder_map.items(): for old_idx, new_idx in reorder_map.items():
new_idx_trace_list[new_idx] = self.index_tracer.idx_trace_list[old_idx] new_idx_trace_list[new_idx] = self.trace_index.idx_trace_list[old_idx]
self.index_tracer.idx_trace_list = new_idx_trace_list self.trace_index.idx_trace_list = new_idx_trace_list
# update compute # update compute
for idx_trace in self.index_tracer.idx_trace_list: for idx_trace in self.trace_index.idx_trace_list:
compute = idx_trace["compute"] compute = idx_trace["compute"]
for dim_compute in compute: for dim_compute in compute:
for idx, i in enumerate(dim_compute): for idx, i in enumerate(dim_compute):
dim_compute[idx] = reorder_map[i] dim_compute[idx] = reorder_map[i]
# update source # update source
for idx_trace in self.index_tracer.idx_trace_list: for idx_trace in self.trace_index.idx_trace_list:
source = idx_trace["source"] source = idx_trace["source"]
for dim_idx, dim_source in enumerate(source): for dim_idx, dim_source in enumerate(source):
new_dim_source = {} new_dim_source = {}

View File

@ -1,10 +1,10 @@
import copy import copy
from .select_chunk import SelectChunk
from .trace_index import TraceIndex
from .reorder_graph import ReorderGraph
from .estiamte_memory import EstimateMemory from .estiamte_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow from .trace_flow import TraceFlow
from .trace_index import TraceIndex
from .utils import ( from .utils import (
get_node_shape, get_node_shape,
is_non_compute_node, is_non_compute_node,
@ -22,7 +22,10 @@ class SearchChunk(object):
self.reorder_graph = ReorderGraph(self.trace_index) self.reorder_graph = ReorderGraph(self.trace_index)
self.estimate_memory = EstimateMemory() self.estimate_memory = EstimateMemory()
self.select_chunk = SelectChunk( self.select_chunk = SelectChunk(
self.trace_index, self.estimate_memory, self.reorder_graph, max_memory=max_memory self.trace_index,
self.estimate_memory,
self.reorder_graph,
max_memory=max_memory,
) )
def _find_peak_node(self, mem_peak): def _find_peak_node(self, mem_peak):

View File

@ -1,19 +1,19 @@
from .trace_index import TraceIndex
from .reorder_graph import ReorderGraph
from .estiamte_memory import EstimateMemory from .estiamte_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .trace_index import TraceIndex
from .utils import is_non_compute_node from .utils import is_non_compute_node
class SelectChunk(object): class SelectChunk(object):
def __init__( def __init__(
self, self,
index_tracer: TraceIndex, trace_index: TraceIndex,
memory_estimator: EstimateMemory, estimate_memory: EstimateMemory,
reorder_graph: ReorderGraph, reorder_graph: ReorderGraph,
max_memory=None, max_memory=None,
): ):
self.index_tracer = index_tracer self.index_tracer = trace_index
self.memory_estimator = memory_estimator self.memory_estimator = estimate_memory
self.reorder_graph = reorder_graph self.reorder_graph = reorder_graph
if max_memory is not None: if max_memory is not None:
self.stratge = "fit_memory" self.stratge = "fit_memory"

View File

@ -81,7 +81,9 @@ class TraceFlow(object):
input_dim_after_node = {} input_dim_after_node = {}
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
inherit_dim = self._find_inherit_dim(input_node, v, self.trace_index.node_list[k]) inherit_dim = self._find_inherit_dim(
input_node, v, self.trace_index.node_list[k]
)
if inherit_dim: if inherit_dim:
input_dim_after_node[k] = inherit_dim input_dim_after_node[k] = inherit_dim
@ -217,7 +219,9 @@ class TraceFlow(object):
for arg in arg_list: for arg in arg_list:
if not ( if not (
start_idx start_idx
<= find_idx_by_name(arg.name, self.trace_index.node_list) <= find_idx_by_name(
arg.name, self.trace_index.node_list
)
< end_idx < end_idx
): ):
continue continue
@ -255,7 +259,9 @@ class TraceFlow(object):
if start_idx <= user_idx <= end_idx: if start_idx <= user_idx <= end_idx:
chunk_dim = all_node_info[user]["chunk_dim"] chunk_dim = all_node_info[user]["chunk_dim"]
if chunk_dim is not None: if chunk_dim is not None:
user_source = self.trace_index._find_source_trace_from_node(user)[chunk_dim] user_source = self.trace_index._find_source_trace_from_node(
user
)[chunk_dim]
if input_node_idx in user_source: if input_node_idx in user_source:
input_dict[user_idx] = user_source[input_node_idx] input_dict[user_idx] = user_source[input_node_idx]
else: else: