support output

pull/2364/head
oahzxl 2022-12-13 11:00:51 +08:00
parent cda3e8572a
commit de65e6c3e8
1 changed files with 12 additions and 3 deletions

View File

@ -56,6 +56,14 @@ def _is_non_compute_node_except_placeholder(node):
return False
def _is_non_compute_node_except_placeholder_output(node):
if any(i in node.op for i in ["get_attr"]) or any(
i in node.name for i in ["getitem", "getattr"]
):
return True
return False
class FlowTracer(object):
def __init__(self, gm) -> None:
self.gm = gm
@ -1083,13 +1091,14 @@ class MemoryEstimator(object):
i is not None for i in [start_nodes, end_nodes, chunk_dims, chunk_sizes]
)
chunk_within = False
chunk_region_idx = 0
chunk_region_idx = None
chunk_ratio = 1 # use it to estimate chunk mem
for idx, node in enumerate(node_list):
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
if use_chunk and idx in start_nodes:
chunk_within = True
chunk_region_idx = start_nodes.index(idx)
chunk_ratio = self._get_chunk_ratio(
node, chunk_dims[chunk_region_idx], chunk_sizes[chunk_region_idx]
)
@ -1149,7 +1158,7 @@ class MemoryEstimator(object):
)
chunk_within = False
chunk_ratio = 1
chunk_region_idx += 1
chunk_region_idx = None
act_memory_after_node_log.append(act_memory)
active_node_list_log.append(copy.deepcopy(active_node_list))
@ -1467,7 +1476,7 @@ def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
if (
output_node not in nodes
and node not in output_nodes
and not _is_non_compute_node_except_placeholder(output_node)
and not _is_non_compute_node_except_placeholder_output(output_node)
):
output_nodes.append(node)