Browse Source

add doc for search

pull/2364/head
oahzxl 2 years ago
parent
commit
065f0b4c27
  1. 76
      colossalai/autochunk/search_chunk.py

76
colossalai/autochunk/search_chunk.py

@ -1,5 +1,5 @@
import copy import copy
from typing import Any, Dict, Iterable, List, Tuple from typing import Dict, List, Tuple
from torch.fx.node import Node from torch.fx.node import Node
@ -136,7 +136,24 @@ class SearchChunk(object):
chunk_region_end = region[0] - 1 chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end return chunk_region_start, chunk_region_end
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
"""
Find chunk info for a region.
Given the region start and the region end, find all chunk info for that region.
We first loop over every dim of the start node and the end node to look for a dim pair
that is linked in a flow and is not computed.
If such a pair is found, we then search the flow across the whole region to collect all chunk infos.
Args:
input_trace (List): node's input trace in region
output_trace (List): node's output trace in region
start_idx (int): region start node index
end_idx (int): region end node index
Returns:
chunk_infos (List): chunk infos of the possible regions found
"""
start_traces = input_trace[start_idx] start_traces = input_trace[start_idx]
end_trace = output_trace[end_idx] end_trace = output_trace[end_idx]
end_node = self.trace_index.node_list[end_idx] end_node = self.trace_index.node_list[end_idx]
@ -174,7 +191,19 @@ class SearchChunk(object):
chunk_infos.append(chunk_info) chunk_infos.append(chunk_info)
return chunk_infos return chunk_infos
def _search_possible_chunk_regions(self, max_chunk_region, peak_node): def _search_possible_chunk_regions(
self, max_chunk_region: Tuple, peak_node: Node
) -> List:
"""
Search every possible region within the max chunk region.
Args:
max_chunk_region (Tuple)
peak_node (Node): peak memory node
Returns:
possible_chunk_region (List)
"""
possible_chunk_region = [] possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_index.idx_trace_list) output_trace = copy.deepcopy(self.trace_index.idx_trace_list)
input_trace = [] # trace of a node's input nodes input_trace = [] # trace of a node's input nodes
@ -196,17 +225,39 @@ class SearchChunk(object):
continue continue
# select free dim # select free dim
chunk_info = self._find_free_dim( chunk_info = self._find_chunk_info(
input_trace, output_trace, start_idx, end_idx input_trace, output_trace, start_idx, end_idx
) )
if len(chunk_info) > 0: if len(chunk_info) > 0:
possible_chunk_region.extend(chunk_info) possible_chunk_region.extend(chunk_info)
return possible_chunk_region return possible_chunk_region
def _step_search(self, mem_peak, active_node, chunk_regions): def _step_search(
self,
mem_peak: List[float],
active_node: List[List[Node]],
chunk_infos: List[Dict],
) -> Dict:
"""
Find one chunk region.
The chunk search proceeds as follows:
1. Find the peak memory node.
2. Find the max chunk region according to the peak memory node.
3. Find all possible chunk regions within the max chunk region.
4. Select the best chunk region for the current state.
Args:
mem_peak (List): peak memory for every node
active_node (List[List[Node]]): active node for every node
chunk_infos (List[Dict]): all chunk info
Returns:
best_chunk_region (Dict): the best chunk region, or None if no max chunk region is found
"""
peak_node = self._find_peak_node(mem_peak) peak_node = self._find_peak_node(mem_peak)
max_chunk_region = self._search_max_chunk_region( max_chunk_region = self._search_max_chunk_region(
active_node, peak_node, chunk_regions active_node, peak_node, chunk_infos
) )
if max_chunk_region == None: if max_chunk_region == None:
return None return None
@ -214,7 +265,7 @@ class SearchChunk(object):
max_chunk_region, peak_node max_chunk_region, peak_node
) )
best_chunk_region = self.select_chunk._select_best_chunk_region( best_chunk_region = self.select_chunk._select_best_chunk_region(
possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
) )
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region) best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
return best_chunk_region return best_chunk_region
@ -225,7 +276,16 @@ class SearchChunk(object):
return True return True
return False return False
def search_region(self): def search_region(self) -> Dict:
"""
Search all chunk regions:
1. Estimate the current memory.
2. Find the best chunk for the current memory.
3. Repeat from step 1.
Returns:
chunk_infos (List[Dict]): all chunk infos found
"""
chunk_infos = [] chunk_infos = []
( (
init_mem_peak, init_mem_peak,

Loading…
Cancel
Save