From 7d4abaa5257758011f0f4ba1c5943f492e650a55 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Tue, 10 Jan 2023 09:59:47 +0800
Subject: [PATCH] add doc

---
 colossalai/autochunk/autochunk_codegen.py | 99 ++++++++++++++++++++---
 colossalai/autochunk/estimate_memory.py   | 22 ++++-
 colossalai/autochunk/reorder_graph.py     |  8 +-
 3 files changed, 113 insertions(+), 16 deletions(-)

diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py
index 6e0cfb9cb..73b6bf524 100644
--- a/colossalai/autochunk/autochunk_codegen.py
+++ b/colossalai/autochunk/autochunk_codegen.py
@@ -20,11 +20,22 @@ from .search_chunk import SearchChunk
 from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
 
 
-def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
+def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
+    """
+    Generate a chunk slice string, e.g. [:, :, chunk_indice_name:chunk_indice_name + chunk_size, :]
+
+    Args:
+        chunk_dim (int): the dimension to chunk along
+        chunk_indice_name (str): name of the chunk index variable
+        shape (List): node shape
+
+    Returns:
+        new_shape (str): the generated slice string
+    """
     new_shape = "["
-    for idx, i in enumerate(shape):
+    for idx, _ in enumerate(shape):
         if idx == chunk_dim:
-            new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name)
+            new_shape += "%s:%s + chunk_size" % (chunk_indice_name, chunk_indice_name)
         else:
             new_shape += ":"
         new_shape += ", "
@@ -32,7 +43,26 @@ def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
     return new_shape
 
 
-def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2):
+def _gen_loop_start(
+    chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size: int = 2
+) -> str:
+    """
+    Generate the chunk loop start.
+
+    e.g. chunk_result = torch.empty([100, 100], dtype=input_node.dtype, device=input_node.device)
+         chunk_size = 32
+         for chunk_idx in range(0, 100, 32):
+             ...
+
+    Args:
+        chunk_input (List[Node]): chunk input nodes
+        chunk_output (Node): chunk output node
+        chunk_ouput_dim (int): chunk dimension of the output node
+        chunk_size (int): chunk size. Defaults to 2.
+
+    Returns:
+        context (str): the generated code string
+    """
     input_node = chunk_input[0]
     out_shape = get_node_shape(chunk_output)
     out_str = str(list(out_shape))
@@ -45,8 +75,28 @@ def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2):
 
 
 def _gen_loop_end(
-    chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list
-):
+    chunk_inputs: List[Node],
+    chunk_non_compute_inputs: List[Node],
+    chunk_outputs: Node,
+    chunk_outputs_dim: int,
+    node_list: List[Node],
+) -> str:
+    """
+    Generate the chunk loop end.
+
+    e.g.
+        chunk_result[chunk_idx:chunk_idx + chunk_size] = output_node
+        output_node = chunk_result; xx = None; xx = None
+
+    Args:
+        chunk_inputs (List[Node]): chunk input nodes
+        chunk_non_compute_inputs (List[Node]): input nodes that are not chunked
+        chunk_outputs (Node): chunk output node
+        chunk_outputs_dim (int): chunk dimension of the output node
+        node_list (List[Node]): the graph node list
+
+    Returns:
+        context (str): the generated code string
+    """
     chunk_outputs_name = chunk_outputs.name
     chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
     chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
@@ -76,7 +126,10 @@ def _gen_loop_end(
     return context
 
 
-def _replace_name(context, name_from, name_to):
+def _replace_name(context: str, name_from: str, name_to: str) -> str:
+    """
+    Replace a node name in the generated code.
+    """
     patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
     for p in patterns:
         source = p[0] + name_from + p[1]
@@ -86,7 +139,10 @@ def _replace_name(context, name_from, name_to):
     return context
 
 
-def _replace_reshape_size(context, node_name, reshape_size_dict):
+def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict) -> str:
+    """
+    Replace reshape sizes; some may have changed due to chunking.
+    """
     if node_name not in reshape_size_dict:
         return context
     for size_name, size_value in reshape_size_dict[node_name].items():
@@ -94,7 +150,17 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
     return context
 
 
-def _replace_ones_like(search_chunk: SearchChunk, chunk_infos, region_idx, node_idx, node, body):
+def _replace_ones_like(
+    search_chunk: SearchChunk,
+    chunk_infos: List[Dict],
+    region_idx: int,
+    node_idx: int,
+    node: Node,
+    body: List[str],
+) -> List[str]:
+    """
+    Add a chunk slice for tensor-creation ops such as ones_like.
+    """
     if "ones_like" in node.name:
         meta_node = search_chunk.trace_indice.node_list[node_idx]
         chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
@@ -114,7 +180,16 @@ def _replace_ones_like(search_chunk: SearchChunk, chunk_infos, region_idx, node_
     return body
 
 
-def _replace_input_var(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body):
+def _replace_input_node(
+    chunk_inputs: List[Node],
+    region_idx: int,
+    chunk_inputs_dim: Dict,
+    node_idx: int,
+    body: List[str],
+) -> List[str]:
+    """
+    Add chunk slices for input nodes.
+    """
     for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
         for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
             if idx == node_idx:
@@ -138,7 +213,7 @@ def emit_code_with_chunk(
     """
     Emit code with chunk according to chunk_infos.
 
-    It will generate a for loop in chunk regions, and
+    It will generate a for loop in chunk regions, and 
     replace inputs and outputs of regions with chunked variables.
 
     Args:
@@ -193,7 +268,7 @@
         if within_chunk_region:
             emit_node_func(node, body)
             # replace input var with chunk var
-            body = _replace_input_var(
+            body = _replace_input_node(
                 chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body
             )
             # ones like
diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py
index 90cfd66a0..62b23cf9f 100644
--- a/colossalai/autochunk/estimate_memory.py
+++ b/colossalai/autochunk/estimate_memory.py
@@ -15,6 +15,10 @@ from .utils import (
 
 
 class EstimateMemory(object):
+    """
+    Estimate memory consumption with chunking.
+    """
+
     def __init__(self) -> None:
         pass
 
@@ -31,8 +35,6 @@ class EstimateMemory(object):
         }
         out_size = activation_size(fwd_out)
         out_node = [n.name] if out_size > 0 else []
-        # if any(i in n.name for i in ['transpose', 'permute', 'view']):
-        #     out_size = 0
         return out_size, out_node
 
     def _get_output_node_size(self, n):
@@ -184,10 +186,24 @@ class EstimateMemory(object):
 
     def estimate_chunk_inference_mem(
         self,
-        node_list,
+        node_list: List,
         chunk_infos=None,
         print_mem=False,
     ):
+        """
+        Estimate inference memory with chunking.
+
+        Args:
+            node_list (List): the graph node list
+            chunk_infos (Dict): chunk information. Defaults to None.
+            print_mem (bool): whether to print the peak memory of every node. Defaults to False.
+
+        Returns:
+            act_memory_peak_log (List): peak memory at every node
+            act_memory_after_node_log (List): memory after executing every node
+            active_node_list_log (List): active nodes at every node. Active nodes refer to
+                nodes that have been generated but not yet deleted.
+        """
         act_memory = 0.0
         act_memory_peak_log = []
         act_memory_after_node_log = []
diff --git a/colossalai/autochunk/reorder_graph.py b/colossalai/autochunk/reorder_graph.py
index 2ece0126e..0343e52ee 100644
--- a/colossalai/autochunk/reorder_graph.py
+++ b/colossalai/autochunk/reorder_graph.py
@@ -3,6 +3,10 @@ from .utils import find_idx_by_name
 
 
 class ReorderGraph(object):
+    """
+    Reorder the node list and the indice trace list.
+    """
+
     def __init__(self, trace_indice: TraceIndice) -> None:
         self.trace_indice = trace_indice
         self.all_reorder_map = {
@@ -60,7 +64,9 @@ class ReorderGraph(object):
 
     def _reorder_idx_trace(self, reorder_map):
         # reorder list
-        new_idx_trace_list = [None for _ in range(len(self.trace_indice.indice_trace_list))]
+        new_idx_trace_list = [
+            None for _ in range(len(self.trace_indice.indice_trace_list))
+        ]
         for old_idx, new_idx in reorder_map.items():
             new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
         self.trace_indice.indice_trace_list = new_idx_trace_list
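
Note for reviewers: the codegen helpers documented above assemble Python source
strings rather than executing anything, so the docstring fragments can be hard to
picture in isolation. The following is a minimal hand-written sketch, not part of
the patch, of the kind of loop that _gen_loop_start, _gen_chunk_slice_dim, and
_gen_loop_end emit together for a tensor chunked along dim 0 with chunk_size = 32;
the shapes and the toy computation are made up for illustration:

    # Illustrative sketch only; mirrors the docstring examples above.
    import torch

    input_node = torch.randn(100, 100)

    # emitted by _gen_loop_start: allocate the full result, open the chunk loop
    chunk_result = torch.empty([100, 100], dtype=input_node.dtype, device=input_node.device)
    chunk_size = 32
    for chunk_idx in range(0, 100, chunk_size):
        # slice produced by _gen_chunk_slice_dim for chunk_dim=0:
        # "[chunk_idx:chunk_idx + chunk_size, :]"
        chunk_tensor = input_node[chunk_idx:chunk_idx + chunk_size, :]
        output_node = chunk_tensor * 2  # stands in for the region's real computation
        # emitted by _gen_loop_end: write the chunk back into the full result
        chunk_result[chunk_idx:chunk_idx + chunk_size] = output_node
    output_node = chunk_result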
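
Similarly, a hypothetical usage of the newly documented
estimate_chunk_inference_mem, assuming `gm` is a traced torch.fx.GraphModule
whose nodes already carry tensor metadata (e.g. after shape propagation); the
variable names here are illustrative, and the return values follow the docstring
added by this patch:

    from colossalai.autochunk.estimate_memory import EstimateMemory

    node_list = list(gm.graph.nodes)  # `gm`: an fx GraphModule, assumed already traced
    peak_log, after_node_log, active_log = EstimateMemory().estimate_chunk_inference_mem(
        node_list, chunk_infos=None, print_mem=True
    )
    print("estimated peak activation memory:", max(peak_log))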