|
|
@ -1,4 +1,4 @@
|
|
|
|
from typing import Any, Callable, Dict, Iterable, List, Tuple
|
|
|
|
from typing import Any, Dict, Iterable, List, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
from torch.fx.graph import (
|
|
|
|
from torch.fx.graph import (
|
|
|
@ -128,37 +128,42 @@ def _replace_input_var(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, bod
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def emit_code_with_chunk(
|
|
|
|
def emit_code_with_chunk(
|
|
|
|
body,
|
|
|
|
body: List[str],
|
|
|
|
nodes,
|
|
|
|
nodes: Iterable[Node],
|
|
|
|
emit_node_func,
|
|
|
|
emit_node_func,
|
|
|
|
delete_unused_value_func,
|
|
|
|
delete_unused_value_func,
|
|
|
|
search_chunk: SearchChunk,
|
|
|
|
search_chunk: SearchChunk,
|
|
|
|
chunk_infos,
|
|
|
|
chunk_infos: List,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
"""Emit code with nested activation checkpoint
|
|
|
|
"""
|
|
|
|
When we detect some of the node.activation_checkpoint is a List, we will use
|
|
|
|
Emit code with chunk according to chunk_infos.
|
|
|
|
this function to emit the activation checkpoint codes.
|
|
|
|
|
|
|
|
|
|
|
|
It will generate a for loop in chunk regions, and replace inputs
|
|
|
|
|
|
|
|
and outputs of regions with chunked variables.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
body: forward code
|
|
|
|
body: forward code
|
|
|
|
ckpt_func: checkpoint functions code
|
|
|
|
|
|
|
|
nodes: graph.nodes
|
|
|
|
nodes: graph.nodes
|
|
|
|
emit_node_func: function to emit node
|
|
|
|
emit_node_func: function to emit node
|
|
|
|
delete_unused_value_func: function to remove the unused value
|
|
|
|
delete_unused_value_func: function to remove the unused value
|
|
|
|
|
|
|
|
search_chunk: the class to search all chunks
|
|
|
|
|
|
|
|
chunk_infos: store all information about all chunks.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
node_list = list(nodes)
|
|
|
|
node_list = list(nodes)
|
|
|
|
|
|
|
|
|
|
|
|
chunk_regions = [i["region"] for i in chunk_infos]
|
|
|
|
# chunk region
|
|
|
|
chunk_starts = [i[0] for i in chunk_regions]
|
|
|
|
chunk_starts = [i["region"][0] for i in chunk_infos]
|
|
|
|
chunk_ends = [i[1] for i in chunk_regions]
|
|
|
|
chunk_ends = [i["region"][1] for i in chunk_infos]
|
|
|
|
|
|
|
|
|
|
|
|
chunk_inputs = [i["inputs"] for i in chunk_infos]
|
|
|
|
# chunk inputs
|
|
|
|
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
|
|
|
|
chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
|
|
|
|
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]
|
|
|
|
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
|
|
|
|
|
|
|
|
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
|
|
|
|
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
|
|
|
|
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
|
|
|
|
j.name for i in chunk_inputs_non_chunk for j in i
|
|
|
|
j.name for i in chunk_inputs_non_chunk for j in i
|
|
|
|
]
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# chunk outputs
|
|
|
|
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
|
|
|
|
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
|
|
|
|
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
|
|
|
|
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
|
|
|
|
|
|
|
|
|
|
|
@ -170,6 +175,7 @@ def emit_code_with_chunk(
|
|
|
|
while node_idx < len(node_list):
|
|
|
|
while node_idx < len(node_list):
|
|
|
|
node = node_list[node_idx]
|
|
|
|
node = node_list[node_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# if is chunk start, generate for loop start
|
|
|
|
if node_idx in chunk_starts:
|
|
|
|
if node_idx in chunk_starts:
|
|
|
|
within_chunk_region = True
|
|
|
|
within_chunk_region = True
|
|
|
|
region_idx = chunk_starts.index(node_idx)
|
|
|
|
region_idx = chunk_starts.index(node_idx)
|
|
|
@ -203,6 +209,7 @@ def emit_code_with_chunk(
|
|
|
|
if node_idx not in chunk_inputs:
|
|
|
|
if node_idx not in chunk_inputs:
|
|
|
|
delete_unused_value_func(node, body, chunk_inputs_names)
|
|
|
|
delete_unused_value_func(node, body, chunk_inputs_names)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# generate chunk region end
|
|
|
|
if node_idx in chunk_ends:
|
|
|
|
if node_idx in chunk_ends:
|
|
|
|
body.append(
|
|
|
|
body.append(
|
|
|
|
_gen_loop_end(
|
|
|
|
_gen_loop_end(
|
|
|
|