diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py
index 613c28454..ff4c15878 100644
--- a/colossalai/autochunk/search_chunk.py
+++ b/colossalai/autochunk/search_chunk.py
@@ -1,5 +1,5 @@
 import copy
-from typing import Any, Dict, Iterable, List, Tuple
+from typing import Dict, List, Tuple
 
 from torch.fx.node import Node
 
@@ -136,7 +136,24 @@ class SearchChunk(object):
             chunk_region_end = region[0] - 1
         return chunk_region_start, chunk_region_end
 
-    def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
+    def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
+        """
+        Find chunk info for a region.
+
+        Given the region start and region end, find all possible chunk info for it.
+        We first loop over every dim of the start node and end node to see if we can find a dim pair
+        that is linked in a flow and not computed.
+        If found, we then search the flow in the whole region to collect all chunk infos.
+
+        Args:
+            input_trace (List): the node's input trace in the region
+            output_trace (List): the node's output trace in the region
+            start_idx (int): region start node index
+            end_idx (int): region end node index
+
+        Returns:
+            chunk_infos: possible regions found
+        """
         start_traces = input_trace[start_idx]
         end_trace = output_trace[end_idx]
         end_node = self.trace_index.node_list[end_idx]
@@ -174,7 +191,19 @@ class SearchChunk(object):
             chunk_infos.append(chunk_info)
         return chunk_infos
 
-    def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
+    def _search_possible_chunk_regions(
+        self, max_chunk_region: Tuple, peak_node: Node
+    ) -> List:
+        """
+        Search every possible region within the max chunk region.
+
+        Args:
+            max_chunk_region (Tuple): the max chunk region around the peak memory node
+            peak_node (Node): peak memory node
+
+        Returns:
+            possible_chunk_region (List)
+        """
         possible_chunk_region = []
         output_trace = copy.deepcopy(self.trace_index.idx_trace_list)
         input_trace = []  # trace of a node's input nodes
@@ -196,17 +225,39 @@ class SearchChunk(object):
                 continue
 
             # select free dim
-            chunk_info = self._find_free_dim(
+            chunk_info = self._find_chunk_info(
                 input_trace, output_trace, start_idx, end_idx
             )
             if len(chunk_info) > 0:
                 possible_chunk_region.extend(chunk_info)
         return possible_chunk_region
 
-    def _step_search(self, mem_peak, active_node, chunk_regions):
+    def _step_search(
+        self,
+        mem_peak: List[float],
+        active_node: List[List[Node]],
+        chunk_infos: List[Dict],
+    ) -> Dict:
+        """
+        Find one chunk region.
+
+        The chunk search is as follows:
+        1. find the peak memory node
+        2. find the max chunk region according to the peak memory node
+        3. find all possible chunk regions in the max chunk region
+        4. find the best chunk region for the current status
+
+        Args:
+            mem_peak (List): peak memory for every node
+            active_node (List[List[Node]]): active nodes for every node
+            chunk_infos (List[Dict]): all chunk infos
+
+        Returns:
+            best_chunk_region (Dict)
+        """
         peak_node = self._find_peak_node(mem_peak)
         max_chunk_region = self._search_max_chunk_region(
-            active_node, peak_node, chunk_regions
+            active_node, peak_node, chunk_infos
         )
         if max_chunk_region == None:
             return None
@@ -214,7 +265,7 @@ class SearchChunk(object):
             max_chunk_region, peak_node
         )
         best_chunk_region = self.select_chunk._select_best_chunk_region(
-            possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak
+            possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
         )
         best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
         return best_chunk_region
@@ -225,7 +276,16 @@ class SearchChunk(object):
             return True
         return False
 
-    def search_region(self):
+    def search_region(self) -> Dict:
+        """
+        Search all chunk regions:
+        1. Estimate current memory
+        2. Find the best chunk for the current memory
+        3. Go to step 1
+
+        Returns:
+            chunk_infos (Dict)
+        """
         chunk_infos = []
         (
             init_mem_peak,