|
|
|
@ -1,4 +1,7 @@
|
|
|
|
|
import copy
|
|
|
|
|
from typing import Any, Dict, Iterable, List, Tuple
|
|
|
|
|
|
|
|
|
|
from torch.fx.node import Node
|
|
|
|
|
|
|
|
|
|
from .estimate_memory import EstimateMemory
|
|
|
|
|
from .reorder_graph import ReorderGraph
|
|
|
|
@ -13,6 +16,34 @@ from .utils import (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SearchChunk(object):
|
|
|
|
|
"""
|
|
|
|
|
This is the core class for AutoChunk.
|
|
|
|
|
|
|
|
|
|
It defines the framework of the strategy of AutoChunk.
|
|
|
|
|
Chunks will be selected one by one until the search stops.
|
|
|
|
|
|
|
|
|
|
The chunk search is as follows:
|
|
|
|
|
1. find the peak memory node
|
|
|
|
|
2. find the max chunk region according to the peak memory node
|
|
|
|
|
3. find all possible chunk regions in the max chunk region
|
|
|
|
|
4. find the best chunk region for current status
|
|
|
|
|
5. goto 1
|
|
|
|
|
|
|
|
|
|
Attributes:
|
|
|
|
|
gm: graph model
|
|
|
|
|
print_mem (bool): print estimated memory
|
|
|
|
|
trace_index: trace the flow of every dim of every node to find all free dims
|
|
|
|
|
trace_flow: determine the region chunk strategy
|
|
|
|
|
reorder_graph: reorder nodes to improve chunk efficiency
|
|
|
|
|
estimate_memory: estimate memory with chunk
|
|
|
|
|
select_chunk: select the best chunk region
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
gm: graph model
|
|
|
|
|
max_memory (int): max memory in MB
|
|
|
|
|
print_mem (bool): print estimated memory
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, gm, max_memory=None, print_mem=False) -> None:
|
|
|
|
|
self.gm = gm
|
|
|
|
|
self.print_mem = print_mem
|
|
|
|
@ -33,24 +64,37 @@ class SearchChunk(object):
|
|
|
|
|
max_idx = mem_peak.index(max_value)
|
|
|
|
|
return max_idx
|
|
|
|
|
|
|
|
|
|
def _get_free_var(self):
|
|
|
|
|
def _get_free_var_idx(self) -> List:
|
|
|
|
|
"""
|
|
|
|
|
Get free var index
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
free_var_idx (List): all indexs of free vars
|
|
|
|
|
"""
|
|
|
|
|
free_var_idx = []
|
|
|
|
|
for idx, n in enumerate(self.trace_index.node_list):
|
|
|
|
|
if n.op == "placeholder":
|
|
|
|
|
free_var_idx.append(idx)
|
|
|
|
|
return free_var_idx
|
|
|
|
|
|
|
|
|
|
def _get_min_free_var(self, active_node_list, free_vars):
|
|
|
|
|
min_len = 999
|
|
|
|
|
for idx, n in enumerate(active_node_list):
|
|
|
|
|
if idx in free_vars:
|
|
|
|
|
continue
|
|
|
|
|
if len(n) < min_len:
|
|
|
|
|
min_len = len(n)
|
|
|
|
|
return min_len
|
|
|
|
|
def _search_max_chunk_region(
|
|
|
|
|
self, active_node: List, peak_node: Node, chunk_regions: List
|
|
|
|
|
) -> Tuple:
|
|
|
|
|
"""
|
|
|
|
|
Search max chunk region according to peak memory node
|
|
|
|
|
|
|
|
|
|
def _search_max_chunk_region(self, active_node, peak_node, chunk_regions):
|
|
|
|
|
free_vars = self._get_free_var()
|
|
|
|
|
Chunk region starts extending from the peak node, stops where free var num is min
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
active_node (List): active node status for every node
|
|
|
|
|
peak_node (Node): peak memory node
|
|
|
|
|
chunk_regions (List): chunk region info
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
chunk_region_start (int)
|
|
|
|
|
chunk_region_end (int)
|
|
|
|
|
"""
|
|
|
|
|
free_vars = self._get_free_var_idx()
|
|
|
|
|
free_var_num = len(free_vars)
|
|
|
|
|
active_node_num = [len(i) for i in active_node]
|
|
|
|
|
min_active_node_num = min(active_node_num[free_var_num:])
|
|
|
|
@ -92,16 +136,6 @@ class SearchChunk(object):
|
|
|
|
|
chunk_region_end = region[0] - 1
|
|
|
|
|
return chunk_region_start, chunk_region_end
|
|
|
|
|
|
|
|
|
|
def _is_not_compute(self, trace, chunk_range, dim_idx):
|
|
|
|
|
if trace["idx"][dim_idx] not in trace["compute"]:
|
|
|
|
|
return True
|
|
|
|
|
if trace["idx"][dim_idx] in trace["compute"] and all(
|
|
|
|
|
i < chunk_range[0] or i > chunk_range[1]
|
|
|
|
|
for i in trace["compute"][trace["idx"][dim_idx]]
|
|
|
|
|
):
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
|
|
|
|
|
start_traces = input_trace[start_idx]
|
|
|
|
|
end_trace = output_trace[end_idx]
|
|
|
|
|