|
|
@ -1,5 +1,5 @@ |
|
|
|
import copy |
|
|
|
import copy |
|
|
|
from typing import Any, Dict, Iterable, List, Tuple |
|
|
|
from typing import Dict, List, Tuple |
|
|
|
|
|
|
|
|
|
|
|
from torch.fx.node import Node |
|
|
|
from torch.fx.node import Node |
|
|
|
|
|
|
|
|
|
|
@ -136,7 +136,24 @@ class SearchChunk(object): |
|
|
|
chunk_region_end = region[0] - 1 |
|
|
|
chunk_region_end = region[0] - 1 |
|
|
|
return chunk_region_start, chunk_region_end |
|
|
|
return chunk_region_start, chunk_region_end |
|
|
|
|
|
|
|
|
|
|
|
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): |
|
|
|
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List: |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
Find chunk info for a region. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
We are given the region start and region end, and need to find out all chunk info for it. |
|
|
|
|
|
|
|
We first loop every dim of start node and end node, to see if we can find dim pair, |
|
|
|
|
|
|
|
which is linked in a flow and not computed. |
|
|
|
|
|
|
|
If found, we then search flow in the whole region to find out all chunk infos. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
|
|
input_trace (List): node's input trace in region |
|
|
|
|
|
|
|
output_trace (List): node's output trace in region |
|
|
|
|
|
|
|
start_idx (int): region start node index |
|
|
|
|
|
|
|
end_idx (int): region end node index |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
|
|
|
chunk_infos: possible regions found |
|
|
|
|
|
|
|
""" |
|
|
|
start_traces = input_trace[start_idx] |
|
|
|
start_traces = input_trace[start_idx] |
|
|
|
end_trace = output_trace[end_idx] |
|
|
|
end_trace = output_trace[end_idx] |
|
|
|
end_node = self.trace_index.node_list[end_idx] |
|
|
|
end_node = self.trace_index.node_list[end_idx] |
|
|
@ -174,7 +191,19 @@ class SearchChunk(object): |
|
|
|
chunk_infos.append(chunk_info) |
|
|
|
chunk_infos.append(chunk_info) |
|
|
|
return chunk_infos |
|
|
|
return chunk_infos |
|
|
|
|
|
|
|
|
|
|
|
def _search_possible_chunk_regions(self, max_chunk_region, peak_node): |
|
|
|
def _search_possible_chunk_regions( |
|
|
|
|
|
|
|
self, max_chunk_region: Tuple, peak_node: Node |
|
|
|
|
|
|
|
) -> List: |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
Search every possible region within the max chunk region. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
|
|
max_chunk_region (Tuple) |
|
|
|
|
|
|
|
peak_node (Node): peak memory node |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
|
|
|
possible_chunk_region (List) |
|
|
|
|
|
|
|
""" |
|
|
|
possible_chunk_region = [] |
|
|
|
possible_chunk_region = [] |
|
|
|
output_trace = copy.deepcopy(self.trace_index.idx_trace_list) |
|
|
|
output_trace = copy.deepcopy(self.trace_index.idx_trace_list) |
|
|
|
input_trace = [] # trace of a node's input nodes |
|
|
|
input_trace = [] # trace of a node's input nodes |
|
|
@ -196,17 +225,39 @@ class SearchChunk(object): |
|
|
|
continue |
|
|
|
continue |
|
|
|
|
|
|
|
|
|
|
|
# select free dim |
|
|
|
# select free dim |
|
|
|
chunk_info = self._find_free_dim( |
|
|
|
chunk_info = self._find_chunk_info( |
|
|
|
input_trace, output_trace, start_idx, end_idx |
|
|
|
input_trace, output_trace, start_idx, end_idx |
|
|
|
) |
|
|
|
) |
|
|
|
if len(chunk_info) > 0: |
|
|
|
if len(chunk_info) > 0: |
|
|
|
possible_chunk_region.extend(chunk_info) |
|
|
|
possible_chunk_region.extend(chunk_info) |
|
|
|
return possible_chunk_region |
|
|
|
return possible_chunk_region |
|
|
|
|
|
|
|
|
|
|
|
def _step_search(self, mem_peak, active_node, chunk_regions): |
|
|
|
def _step_search( |
|
|
|
|
|
|
|
self, |
|
|
|
|
|
|
|
mem_peak: List[float], |
|
|
|
|
|
|
|
active_node: List[List[Node]], |
|
|
|
|
|
|
|
chunk_infos: List[Dict], |
|
|
|
|
|
|
|
) -> Dict: |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
Find one chunk region |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
The chunk search is as follows: |
|
|
|
|
|
|
|
1. find the peak memory node |
|
|
|
|
|
|
|
2. find the max chunk region according to the peak memory node |
|
|
|
|
|
|
|
3. find all possible chunk regions in the max chunk region |
|
|
|
|
|
|
|
4. find the best chunk region for current status |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
|
|
mem_peak (List): peak memory for every node |
|
|
|
|
|
|
|
active_node (List[List[Node]]): active node for every node |
|
|
|
|
|
|
|
chunk_infos (List[Dict]): all chunk info |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
|
|
|
best_chunk_region (Dict) |
|
|
|
|
|
|
|
""" |
|
|
|
peak_node = self._find_peak_node(mem_peak) |
|
|
|
peak_node = self._find_peak_node(mem_peak) |
|
|
|
max_chunk_region = self._search_max_chunk_region( |
|
|
|
max_chunk_region = self._search_max_chunk_region( |
|
|
|
active_node, peak_node, chunk_regions |
|
|
|
active_node, peak_node, chunk_infos |
|
|
|
) |
|
|
|
) |
|
|
|
if max_chunk_region == None: |
|
|
|
if max_chunk_region == None: |
|
|
|
return None |
|
|
|
return None |
|
|
@ -214,7 +265,7 @@ class SearchChunk(object): |
|
|
|
max_chunk_region, peak_node |
|
|
|
max_chunk_region, peak_node |
|
|
|
) |
|
|
|
) |
|
|
|
best_chunk_region = self.select_chunk._select_best_chunk_region( |
|
|
|
best_chunk_region = self.select_chunk._select_best_chunk_region( |
|
|
|
possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak |
|
|
|
possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak |
|
|
|
) |
|
|
|
) |
|
|
|
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region) |
|
|
|
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region) |
|
|
|
return best_chunk_region |
|
|
|
return best_chunk_region |
|
|
@ -225,7 +276,16 @@ class SearchChunk(object): |
|
|
|
return True |
|
|
|
return True |
|
|
|
return False |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
def search_region(self): |
|
|
|
def search_region(self) -> Dict: |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
Search all chunk regions: |
|
|
|
|
|
|
|
1. Estimate current memory |
|
|
|
|
|
|
|
2. Find best chunk for current memory |
|
|
|
|
|
|
|
3. goto 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
|
|
|
chunk_infos (Dict) |
|
|
|
|
|
|
|
""" |
|
|
|
chunk_infos = [] |
|
|
|
chunk_infos = [] |
|
|
|
( |
|
|
|
( |
|
|
|
init_mem_peak, |
|
|
|
init_mem_peak, |
|
|
|