add comments

pull/2364/head
oahzxl 2 years ago
parent 19cc64b1d3
commit 212b5b1b5f

@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Iterable, List, Tuple
from typing import Any, Dict, Iterable, List, Tuple
import torch
from torch.fx.graph import (
@ -128,37 +128,42 @@ def _replace_input_var(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, bod
def emit_code_with_chunk(
body,
nodes,
body: List[str],
nodes: Iterable[Node],
emit_node_func,
delete_unused_value_func,
search_chunk: SearchChunk,
chunk_infos,
chunk_infos: List,
):
"""Emit code with nested activation checkpoint
When we detect some of the node.activation_checkpoint is a List, we will use
this function to emit the activation checkpoint codes.
"""
Emit code with chunk according to chunk_infos.
It will generate a for loop in chunk regions, and replace inputs
and outputs of regions with chunked variables.
Args:
body: forward code
ckpt_func: checkpoint functions code
nodes: graph.nodes
emit_node_func: function to emit node
delete_unused_value_func: function to remove the unused value
search_chunk: the class to search all chunks
chunk_infos: store all information about all chunks.
"""
node_list = list(nodes)
chunk_regions = [i["region"] for i in chunk_infos]
chunk_starts = [i[0] for i in chunk_regions]
chunk_ends = [i[1] for i in chunk_regions]
# chunk region
chunk_starts = [i["region"][0] for i in chunk_infos]
chunk_ends = [i["region"][1] for i in chunk_infos]
chunk_inputs = [i["inputs"] for i in chunk_infos]
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]
# chunk inputs
chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
j.name for i in chunk_inputs_non_chunk for j in i
]
# chunk outputs
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
@ -170,6 +175,7 @@ def emit_code_with_chunk(
while node_idx < len(node_list):
node = node_list[node_idx]
# if is chunk start, generate for loop start
if node_idx in chunk_starts:
within_chunk_region = True
region_idx = chunk_starts.index(node_idx)
@ -203,6 +209,7 @@ def emit_code_with_chunk(
if node_idx not in chunk_inputs:
delete_unused_value_func(node, body, chunk_inputs_names)
# generate chunk region end
if node_idx in chunk_ends:
body.append(
_gen_loop_end(

@ -115,4 +115,4 @@ def test_autochunk_codegen(msa_len, pair_len, max_memory):
if __name__ == "__main__":
_test_autochunk_codegen(0, 32, 64, None)
_test_autochunk_codegen(0, 32, 64, 25)

Loading…
Cancel
Save