From 212b5b1b5f4f3debf983d8c47c58af507a554be4 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 16:29:33 +0800 Subject: [PATCH] add comments --- colossalai/autochunk/autochunk_codegen.py | 35 +++++++++++-------- .../test_autochunk/test_autochunk_codegen.py | 2 +- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 9ec59477b..5ef560ac2 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Tuple +from typing import Any, Dict, Iterable, List, Tuple import torch from torch.fx.graph import ( @@ -128,37 +128,42 @@ def _replace_input_var(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, bod def emit_code_with_chunk( - body, - nodes, + body: List[str], + nodes: Iterable[Node], emit_node_func, delete_unused_value_func, search_chunk: SearchChunk, - chunk_infos, + chunk_infos: List, ): - """Emit code with nested activation checkpoint - When we detect some of the node.activation_checkpoint is a List, we will use - this function to emit the activation checkpoint codes. + """ + Emit code with chunk according to chunk_infos. + + It will generate a for loop in chunk regions, and replace inputs + and outputs of regions with chunked variables. Args: body: forward code - ckpt_func: checkpoint functions code nodes: graph.nodes emit_node_func: function to emit node delete_unused_value_func: function to remove the unused value + search_chunk: the class to search all chunks + chunk_infos: store all information about all chunks. """ node_list = list(nodes) - chunk_regions = [i["region"] for i in chunk_infos] - chunk_starts = [i[0] for i in chunk_regions] - chunk_ends = [i[1] for i in chunk_regions] + # chunk region + chunk_starts = [i["region"][0] for i in chunk_infos] + chunk_ends = [i["region"][1] for i in chunk_infos] - chunk_inputs = [i["inputs"] for i in chunk_infos] - chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] - chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] + # chunk inputs + chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk + chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ j.name for i in chunk_inputs_non_chunk for j in i ] + # chunk outputs chunk_outputs = [i["outputs"][0] for i in chunk_infos] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos] @@ -170,6 +175,7 @@ def emit_code_with_chunk( while node_idx < len(node_list): node = node_list[node_idx] + # if is chunk start, generate for loop start if node_idx in chunk_starts: within_chunk_region = True region_idx = chunk_starts.index(node_idx) @@ -203,6 +209,7 @@ def emit_code_with_chunk( if node_idx not in chunk_inputs: delete_unused_value_func(node, body, chunk_inputs_names) + # generate chunk region end if node_idx in chunk_ends: body.append( _gen_loop_end( diff --git a/tests/test_autochunk/test_autochunk_codegen.py b/tests/test_autochunk/test_autochunk_codegen.py index c4f5cda67..53f62077c 100644 --- a/tests/test_autochunk/test_autochunk_codegen.py +++ b/tests/test_autochunk/test_autochunk_codegen.py @@ -115,4 +115,4 @@ def test_autochunk_codegen(msa_len, pair_len, max_memory): if __name__ == "__main__": - _test_autochunk_codegen(0, 32, 64, None) + _test_autochunk_codegen(0, 32, 64, 25)