mirror of https://github.com/hpcaitech/ColossalAI
add doc for search
parent
a68d240ed5
commit
065f0b4c27
|
@ -1,5 +1,5 @@
|
|||
import copy
|
||||
from typing import Any, Dict, Iterable, List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from torch.fx.node import Node
|
||||
|
||||
|
@ -136,7 +136,24 @@ class SearchChunk(object):
|
|||
chunk_region_end = region[0] - 1
|
||||
return chunk_region_start, chunk_region_end
|
||||
|
||||
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
|
||||
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
|
||||
"""
|
||||
Find chunk info for a region.
|
||||
|
||||
Given the region start and region end, we need to find all chunk infos for that region.
|
||||
We first loop over every dim of the start node and the end node to see if we can find a dim pair,
|
||||
which is linked in a flow and not computed.
|
||||
If found, we then search the flow across the whole region to find all chunk infos.
|
||||
|
||||
Args:
|
||||
input_trace (List): node's input trace in region
|
||||
output_trace (List): node's output trace in region
|
||||
start_idx (int): region start node index
|
||||
end_idx (int): region end node index
|
||||
|
||||
Returns:
|
||||
chunk_infos: possible regions found
|
||||
"""
|
||||
start_traces = input_trace[start_idx]
|
||||
end_trace = output_trace[end_idx]
|
||||
end_node = self.trace_index.node_list[end_idx]
|
||||
|
@ -174,7 +191,19 @@ class SearchChunk(object):
|
|||
chunk_infos.append(chunk_info)
|
||||
return chunk_infos
|
||||
|
||||
def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
|
||||
def _search_possible_chunk_regions(
|
||||
self, max_chunk_region: Tuple, peak_node: Node
|
||||
) -> List:
|
||||
"""
|
||||
Search every possible region within the max chunk region.
|
||||
|
||||
Args:
|
||||
max_chunk_region (Tuple)
|
||||
peak_node (Node): peak memory node
|
||||
|
||||
Returns:
|
||||
possible_chunk_region (List)
|
||||
"""
|
||||
possible_chunk_region = []
|
||||
output_trace = copy.deepcopy(self.trace_index.idx_trace_list)
|
||||
input_trace = [] # trace of a node's input nodes
|
||||
|
@ -196,17 +225,39 @@ class SearchChunk(object):
|
|||
continue
|
||||
|
||||
# select free dim
|
||||
chunk_info = self._find_free_dim(
|
||||
chunk_info = self._find_chunk_info(
|
||||
input_trace, output_trace, start_idx, end_idx
|
||||
)
|
||||
if len(chunk_info) > 0:
|
||||
possible_chunk_region.extend(chunk_info)
|
||||
return possible_chunk_region
|
||||
|
||||
def _step_search(self, mem_peak, active_node, chunk_regions):
|
||||
def _step_search(
|
||||
self,
|
||||
mem_peak: List[float],
|
||||
active_node: List[List[Node]],
|
||||
chunk_infos: List[Dict],
|
||||
) -> Dict:
|
||||
"""
|
||||
Find one chunk region
|
||||
|
||||
The chunk search is as follows:
|
||||
1. find the peak memory node
|
||||
2. find the max chunk region according to the peak memory node
|
||||
3. find all possible chunk regions in the max chunk region
|
||||
4. find the best chunk region for current status
|
||||
|
||||
Args:
|
||||
mem_peak (List): peak memory for every node
|
||||
active_node (List[List[Node]]): active node for every node
|
||||
chunk_infos (List[Dict]): all chunk info
|
||||
|
||||
Returns:
|
||||
best_chunk_region (Dict)
|
||||
"""
|
||||
peak_node = self._find_peak_node(mem_peak)
|
||||
max_chunk_region = self._search_max_chunk_region(
|
||||
active_node, peak_node, chunk_regions
|
||||
active_node, peak_node, chunk_infos
|
||||
)
|
||||
if max_chunk_region == None:
|
||||
return None
|
||||
|
@ -214,7 +265,7 @@ class SearchChunk(object):
|
|||
max_chunk_region, peak_node
|
||||
)
|
||||
best_chunk_region = self.select_chunk._select_best_chunk_region(
|
||||
possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak
|
||||
possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
|
||||
)
|
||||
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
|
||||
return best_chunk_region
|
||||
|
@ -225,7 +276,16 @@ class SearchChunk(object):
|
|||
return True
|
||||
return False
|
||||
|
||||
def search_region(self):
|
||||
def search_region(self) -> Dict:
|
||||
"""
|
||||
Search all chunk regions:
|
||||
1. Estimate current memory
|
||||
2. Find best chunk for current memory
|
||||
3. goto 1
|
||||
|
||||
Returns:
|
||||
chunk_infos (List[Dict])
|
||||
"""
|
||||
chunk_infos = []
|
||||
(
|
||||
init_mem_peak,
|
||||
|
|
Loading…
Reference in New Issue