|
|
|
@ -1,4 +1,7 @@
|
|
|
|
|
import copy |
|
|
|
|
from typing import Any, Dict, Iterable, List, Tuple |
|
|
|
|
|
|
|
|
|
from torch.fx.node import Node |
|
|
|
|
|
|
|
|
|
from .estimate_memory import EstimateMemory |
|
|
|
|
from .reorder_graph import ReorderGraph |
|
|
|
@ -13,6 +16,34 @@ from .utils import (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SearchChunk(object): |
|
|
|
|
""" |
|
|
|
|
This is the core class for AutoChunk. |
|
|
|
|
|
|
|
|
|
It defines the framework of the strategy of AutoChunk. |
|
|
|
|
Chunks will be selected one by one until search stops. |
|
|
|
|
|
|
|
|
|
The chunk search is as follows: |
|
|
|
|
1. find the peak memory node |
|
|
|
|
2. find the max chunk region according to the peak memory node |
|
|
|
|
3. find all possible chunk regions in the max chunk region |
|
|
|
|
4. find the best chunk region for current status |
|
|
|
|
5. goto 1 |
|
|
|
|
|
|
|
|
|
Attributes: |
|
|
|
|
gm: graph model |
|
|
|
|
print_mem (bool): print estimated memory |
|
|
|
|
trace_index: trace the flow of every dim of every node to find all free dims |
|
|
|
|
trace_flow: determine the region chunk strategy |
|
|
|
|
reorder_graph: reorder nodes to improve chunk efficiency |
|
|
|
|
estimate_memory: estimate memory with chunk |
|
|
|
|
select_chunk: select the best chunk region |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
gm: graph model |
|
|
|
|
max_memory (int): max memory in MB |
|
|
|
|
print_mem (bool): print estimated memory |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, gm, max_memory=None, print_mem=False) -> None: |
|
|
|
|
self.gm = gm |
|
|
|
|
self.print_mem = print_mem |
|
|
|
@ -33,24 +64,37 @@ class SearchChunk(object):
|
|
|
|
|
max_idx = mem_peak.index(max_value) |
|
|
|
|
return max_idx |
|
|
|
|
|
|
|
|
|
def _get_free_var(self): |
|
|
|
|
def _get_free_var_idx(self) -> List: |
|
|
|
|
""" |
|
|
|
|
Get free var index |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
free_var_idx (List): all indexs of free vars |
|
|
|
|
""" |
|
|
|
|
free_var_idx = [] |
|
|
|
|
for idx, n in enumerate(self.trace_index.node_list): |
|
|
|
|
if n.op == "placeholder": |
|
|
|
|
free_var_idx.append(idx) |
|
|
|
|
return free_var_idx |
|
|
|
|
|
|
|
|
|
def _get_min_free_var(self, active_node_list, free_vars): |
|
|
|
|
min_len = 999 |
|
|
|
|
for idx, n in enumerate(active_node_list): |
|
|
|
|
if idx in free_vars: |
|
|
|
|
continue |
|
|
|
|
if len(n) < min_len: |
|
|
|
|
min_len = len(n) |
|
|
|
|
return min_len |
|
|
|
|
def _search_max_chunk_region( |
|
|
|
|
self, active_node: List, peak_node: Node, chunk_regions: List |
|
|
|
|
) -> Tuple: |
|
|
|
|
""" |
|
|
|
|
Search max chunk region according to peak memory node |
|
|
|
|
|
|
|
|
|
Chunk region starts extending from the peak node, stops where free var num is min |
|
|
|
|
|
|
|
|
|
def _search_max_chunk_region(self, active_node, peak_node, chunk_regions): |
|
|
|
|
free_vars = self._get_free_var() |
|
|
|
|
Args: |
|
|
|
|
active_node (List): active node status for every node |
|
|
|
|
peak_node (Node): peak memory node |
|
|
|
|
chunk_regions (List): chunk region info |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
chunk_region_start (int) |
|
|
|
|
chunk_region_end (int) |
|
|
|
|
""" |
|
|
|
|
free_vars = self._get_free_var_idx() |
|
|
|
|
free_var_num = len(free_vars) |
|
|
|
|
active_node_num = [len(i) for i in active_node] |
|
|
|
|
min_active_node_num = min(active_node_num[free_var_num:]) |
|
|
|
@ -92,16 +136,6 @@ class SearchChunk(object):
|
|
|
|
|
chunk_region_end = region[0] - 1 |
|
|
|
|
return chunk_region_start, chunk_region_end |
|
|
|
|
|
|
|
|
|
def _is_not_compute(self, trace, chunk_range, dim_idx): |
|
|
|
|
if trace["idx"][dim_idx] not in trace["compute"]: |
|
|
|
|
return True |
|
|
|
|
if trace["idx"][dim_idx] in trace["compute"] and all( |
|
|
|
|
i < chunk_range[0] or i > chunk_range[1] |
|
|
|
|
for i in trace["compute"][trace["idx"][dim_idx]] |
|
|
|
|
): |
|
|
|
|
return True |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): |
|
|
|
|
start_traces = input_trace[start_idx] |
|
|
|
|
end_trace = output_trace[end_idx] |
|
|
|
|