mirror of https://github.com/hpcaitech/ColossalAI
add doc for search
parent
a68d240ed5
commit
065f0b4c27
|
@ -1,5 +1,5 @@
|
||||||
import copy
|
import copy
|
||||||
from typing import Any, Dict, Iterable, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
|
|
||||||
|
@ -136,7 +136,24 @@ class SearchChunk(object):
|
||||||
chunk_region_end = region[0] - 1
|
chunk_region_end = region[0] - 1
|
||||||
return chunk_region_start, chunk_region_end
|
return chunk_region_start, chunk_region_end
|
||||||
|
|
||||||
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
|
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
|
||||||
|
"""
|
||||||
|
Find chunk info for a region.
|
||||||
|
|
||||||
|
Given the region start and region end, we need to find all chunk infos for that region.
|
||||||
|
We first loop over every dim of the start node and the end node to see if we can find a dim pair
|
||||||
|
that is linked in a flow and not computed.
|
||||||
|
If one is found, we then search the flow across the whole region to find all chunk infos.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_trace (List): node's input trace in region
|
||||||
|
output_trace (List): node's output trace in region
|
||||||
|
start_idx (int): region start node index
|
||||||
|
end_idx (int): region end node index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
chunk_infos: possible regions found
|
||||||
|
"""
|
||||||
start_traces = input_trace[start_idx]
|
start_traces = input_trace[start_idx]
|
||||||
end_trace = output_trace[end_idx]
|
end_trace = output_trace[end_idx]
|
||||||
end_node = self.trace_index.node_list[end_idx]
|
end_node = self.trace_index.node_list[end_idx]
|
||||||
|
@ -174,7 +191,19 @@ class SearchChunk(object):
|
||||||
chunk_infos.append(chunk_info)
|
chunk_infos.append(chunk_info)
|
||||||
return chunk_infos
|
return chunk_infos
|
||||||
|
|
||||||
def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
|
def _search_possible_chunk_regions(
|
||||||
|
self, max_chunk_region: Tuple, peak_node: Node
|
||||||
|
) -> List:
|
||||||
|
"""
|
||||||
|
Search every possible region within the max chunk region.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_chunk_region (Tuple)
|
||||||
|
peak_node (Node): peak memory node
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
possible_chunk_region (List)
|
||||||
|
"""
|
||||||
possible_chunk_region = []
|
possible_chunk_region = []
|
||||||
output_trace = copy.deepcopy(self.trace_index.idx_trace_list)
|
output_trace = copy.deepcopy(self.trace_index.idx_trace_list)
|
||||||
input_trace = [] # trace of a node's input nodes
|
input_trace = [] # trace of a node's input nodes
|
||||||
|
@ -196,17 +225,39 @@ class SearchChunk(object):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# select free dim
|
# select free dim
|
||||||
chunk_info = self._find_free_dim(
|
chunk_info = self._find_chunk_info(
|
||||||
input_trace, output_trace, start_idx, end_idx
|
input_trace, output_trace, start_idx, end_idx
|
||||||
)
|
)
|
||||||
if len(chunk_info) > 0:
|
if len(chunk_info) > 0:
|
||||||
possible_chunk_region.extend(chunk_info)
|
possible_chunk_region.extend(chunk_info)
|
||||||
return possible_chunk_region
|
return possible_chunk_region
|
||||||
|
|
||||||
def _step_search(self, mem_peak, active_node, chunk_regions):
|
def _step_search(
|
||||||
|
self,
|
||||||
|
mem_peak: List[float],
|
||||||
|
active_node: List[List[Node]],
|
||||||
|
chunk_infos: List[Dict],
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
Find one chunk region
|
||||||
|
|
||||||
|
The chunk search is as follows:
|
||||||
|
1. find the peak memory node
|
||||||
|
2. find the max chunk region according to the peak memory node
|
||||||
|
3. find all possible chunk regions in the max chunk region
|
||||||
|
4. find the best chunk region for current status
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mem_peak (List): peak memory for every node
|
||||||
|
active_node (List[List[Node]]): active nodes for every node
|
||||||
|
chunk_infos (List[Dict]): all chunk info
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
best_chunk_region (Dict)
|
||||||
|
"""
|
||||||
peak_node = self._find_peak_node(mem_peak)
|
peak_node = self._find_peak_node(mem_peak)
|
||||||
max_chunk_region = self._search_max_chunk_region(
|
max_chunk_region = self._search_max_chunk_region(
|
||||||
active_node, peak_node, chunk_regions
|
active_node, peak_node, chunk_infos
|
||||||
)
|
)
|
||||||
if max_chunk_region == None:
|
if max_chunk_region == None:
|
||||||
return None
|
return None
|
||||||
|
@ -214,7 +265,7 @@ class SearchChunk(object):
|
||||||
max_chunk_region, peak_node
|
max_chunk_region, peak_node
|
||||||
)
|
)
|
||||||
best_chunk_region = self.select_chunk._select_best_chunk_region(
|
best_chunk_region = self.select_chunk._select_best_chunk_region(
|
||||||
possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak
|
possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
|
||||||
)
|
)
|
||||||
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
|
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
|
||||||
return best_chunk_region
|
return best_chunk_region
|
||||||
|
@ -225,7 +276,16 @@ class SearchChunk(object):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def search_region(self):
|
def search_region(self) -> Dict:
|
||||||
|
"""
|
||||||
|
Search all chunk regions:
|
||||||
|
1. Estimate current memory
|
||||||
|
2. Find best chunk for current memory
|
||||||
|
3. goto 1
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
chunk_infos (Dict)
|
||||||
|
"""
|
||||||
chunk_infos = []
|
chunk_infos = []
|
||||||
(
|
(
|
||||||
init_mem_peak,
|
init_mem_peak,
|
||||||
|
|
Loading…
Reference in New Issue