From a68d240ed56dcd62a0726621c50233f733e79367 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Mon, 9 Jan 2023 16:54:08 +0800
Subject: [PATCH] add doc for search chunk

---
 colossalai/autochunk/search_chunk.py | 76 ++++++++++++++++++++--------
 1 file changed, 55 insertions(+), 21 deletions(-)

diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py
index 21b967497..613c28454 100644
--- a/colossalai/autochunk/search_chunk.py
+++ b/colossalai/autochunk/search_chunk.py
@@ -1,4 +1,7 @@
 import copy
+from typing import Any, Dict, Iterable, List, Tuple
+
+from torch.fx.node import Node
 
 from .estimate_memory import EstimateMemory
 from .reorder_graph import ReorderGraph
@@ -13,6 +16,34 @@ from .utils import (
 
 
 class SearchChunk(object):
+    """
+    This is the core class for AutoChunk.
+
+    It defines the framework of AutoChunk's search strategy.
+    Chunks are selected one by one until the search stops.
+
+    The chunk search proceeds as follows:
+    1. find the peak memory node
+    2. find the max chunk region according to the peak memory node
+    3. find all possible chunk regions in the max chunk region
+    4. find the best chunk region for the current status
+    5. goto 1
+
+    Attributes:
+        gm: graph model
+        print_mem (bool): print estimated memory
+        trace_index: trace the flow of every dim of every node to find all free dims
+        trace_flow: determine the region chunk strategy
+        reorder_graph: reorder nodes to improve chunk efficiency
+        estimate_memory: estimate memory with chunk
+        select_chunk: select the best chunk region
+
+    Args:
+        gm: graph model
+        max_memory (int): max memory in MB
+        print_mem (bool): print estimated memory
+    """
+
     def __init__(self, gm, max_memory=None, print_mem=False) -> None:
         self.gm = gm
         self.print_mem = print_mem
@@ -33,24 +64,37 @@ class SearchChunk(object):
         max_idx = mem_peak.index(max_value)
         return max_idx
 
-    def _get_free_var(self):
+    def _get_free_var_idx(self) -> List:
+        """
+        Get free var indices
+
+        Returns:
+            free_var_idx (List): all indices of free vars
+        """
         free_var_idx = []
         for idx, n in enumerate(self.trace_index.node_list):
             if n.op == "placeholder":
                 free_var_idx.append(idx)
         return free_var_idx
 
-    def _get_min_free_var(self, active_node_list, free_vars):
-        min_len = 999
-        for idx, n in enumerate(active_node_list):
-            if idx in free_vars:
-                continue
-            if len(n) < min_len:
-                min_len = len(n)
-        return min_len
+    def _search_max_chunk_region(
+        self, active_node: List, peak_node: Node, chunk_regions: List
+    ) -> Tuple:
+        """
+        Search the max chunk region according to the peak memory node
+
+        The region extends outward from the peak node and stops where the active node count is minimal
 
-    def _search_max_chunk_region(self, active_node, peak_node, chunk_regions):
-        free_vars = self._get_free_var()
+        Args:
+            active_node (List): active node status for every node
+            peak_node (Node): peak memory node
+            chunk_regions (List): chunk region info
+
+        Returns:
+            chunk_region_start (int)
+            chunk_region_end (int)
+        """
+        free_vars = self._get_free_var_idx()
         free_var_num = len(free_vars)
         active_node_num = [len(i) for i in active_node]
         min_active_node_num = min(active_node_num[free_var_num:])
@@ -92,16 +136,6 @@ class SearchChunk(object):
             chunk_region_end = region[0] - 1
         return chunk_region_start, chunk_region_end
 
-    def _is_not_compute(self, trace, chunk_range, dim_idx):
-        if trace["idx"][dim_idx] not in trace["compute"]:
-            return True
-        if trace["idx"][dim_idx] in trace["compute"] and all(
-            i < chunk_range[0] or i > chunk_range[1]
-            for i in trace["compute"][trace["idx"][dim_idx]]
-        ):
-            return True
-        return False
-
     def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
         start_traces = input_trace[start_idx]
         end_trace = output_trace[end_idx]
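
Editor's note: the five-step loop in the new SearchChunk docstring is easier to follow in runnable form. The sketch below is illustrative only: it models per-node memory as a plain list of floats, pretends that chunking a region halves the memory of every node in it, and picks a fixed window around the peak instead of enumerating and scoring candidate regions the way SearchChunk actually does. None of these names exist in colossalai.

from typing import List, Tuple


def toy_chunk_search(node_mem: List[float], max_memory: float) -> List[Tuple[int, int]]:
    # Toy model of the AutoChunk loop; a chunk halves its region's memory.
    regions: List[Tuple[int, int]] = []
    while True:
        # 1. find the peak memory node
        peak = max(range(len(node_mem)), key=node_mem.__getitem__)
        if node_mem[peak] <= max_memory:
            break  # everything fits in the budget, search stops
        # 2./3. take a fixed window around the peak as the sole candidate
        #       (the real search enumerates many regions and checks legality)
        start, end = max(0, peak - 1), min(len(node_mem) - 1, peak + 1)
        # 4. select the candidate and apply its (pretend) memory effect
        regions.append((start, end))
        for i in range(start, end + 1):
            node_mem[i] /= 2
        # 5. goto 1: re-estimate memory with the new chunk applied
    return regions

For example, toy_chunk_search([1.0, 8.0, 3.0], max_memory=4.0) returns [(0, 2)]: one chunk around the peak brings every node under budget.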
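
The _search_max_chunk_region docstring is terse about its stopping rule, so here is a self-contained sketch of the extension logic as suggested by the context lines of the patch (min_active_node_num = min(active_node_num[free_var_num:])). The function name, signature, and loop details are hypothetical; the real method also accounts for existing chunk_regions, which this sketch ignores.

from typing import List, Tuple


def max_chunk_region(active_node_num: List[int], peak_idx: int,
                     free_var_num: int) -> Tuple[int, int]:
    # Smallest live-node count past the placeholders: the graph is
    # narrowest there, so those points bound the chunkable region.
    min_active = min(active_node_num[free_var_num:])
    # Extend left from the peak while the neighbouring node still has
    # more live nodes than that minimum; never enter the free vars.
    start = peak_idx
    while start > free_var_num and active_node_num[start - 1] > min_active:
        start -= 1
    # Extend right symmetrically.
    end = peak_idx
    while end < len(active_node_num) - 1 and active_node_num[end + 1] > min_active:
        end += 1
    return start, end

For example, max_chunk_region([1, 2, 3, 5, 3, 2, 1], peak_idx=3, free_var_num=1) returns (1, 5): both ends stop just before the live-node count drops to the minimum of the non-placeholder suffix.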