refactor structure

pull/2364/head
oahzxl 2023-01-06 11:07:57 +08:00
parent 71e72c4890
commit 27ab524096
19 changed files with 29 additions and 34 deletions

View File

@@ -1967,13 +1967,11 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
 def emit_code_with_chunk(
 body,
-ckpt_func,
 nodes,
 emit_node_func,
 delete_unused_value_func,
-meta_nodes,
-meta_graph,
-max_memory=None,
+chunk_region_search,
+chunk_infos
 ):
 """Emit code with nested activation checkpoint
 When we detect some of the node.activation_checkpoint is a List, we will use
@@ -1988,23 +1986,19 @@ def emit_code_with_chunk(
 """
 node_list = list(nodes)
-# find the chunk regions
-chunk_region_search = ChunkRegionSearch(meta_graph, max_memory)
-chunk_search = chunk_region_search.search_region()
-chunk_regions = [i["region"] for i in chunk_search]
+chunk_regions = [i["region"] for i in chunk_infos]
 chunk_starts = [i[0] for i in chunk_regions]
 chunk_ends = [i[1] for i in chunk_regions]
-chunk_inputs = [i["inputs"] for i in chunk_search]
-chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_search]
-chunk_inputs_dim = [i["inputs_dim"] for i in chunk_search]
+chunk_inputs = [i["inputs"] for i in chunk_infos]
+chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
+chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]
 chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
 j.name for i in chunk_inputs_non_chunk for j in i
 ]
-chunk_outputs = [i["outputs"][0] for i in chunk_search]
-chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search]
+chunk_outputs = [i["outputs"][0] for i in chunk_infos]
+chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
 node_list = chunk_region_search.index_tracer.reorder_node_list(node_list)
 node_idx = 0
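
Judging only from the keys accessed in this function (and in the hunks below), each chunk_infos entry, the output of ChunkRegionSearch.search_region() now computed once in ChunkCodeGen.__init__, is a dict shaped roughly like the sketch here. The key names come from this diff; the values and the per-key comments are illustrative guesses, not taken from the source.

    # hypothetical chunk_infos entry; keys from this diff, values are placeholders
    chunk_info_example = {
        "region": (12, 47),       # (start, end) indices of the node region to chunk
        "inputs": [],             # fx.Node inputs that get sliced along a chunk dimension
        "inputs_non_chunk": [],   # inputs forwarded into the region without slicing
        "inputs_dim": [],         # which dimension of each chunked input is sliced
        "outputs": [],            # region outputs; the code above uses outputs[0]
        "outputs_dim": 2,         # dimension along which output chunks are written back
        "chunk_size": 64,         # elements per chunk along that dimension
        "node_chunk_dim": {},     # per-node {"chunk_dim": ...} bookkeeping (see the ones_like hunk)
        "reshape_size": {},       # replacement sizes for reshape calls inside the region
    }
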
@@ -2022,7 +2016,7 @@ def emit_code_with_chunk(
 chunk_inputs[region_idx],
 chunk_outputs[region_idx],
 chunk_outputs_dim[region_idx],
-chunk_search[region_idx]["chunk_size"],
+chunk_infos[region_idx]["chunk_size"],
 )
 )
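
The call above (presumably the helper that opens the chunk loop; its name falls outside this hunk) is where the generated code starts iterating over chunks. As a mental model only, and under the assumption that the emitted code follows the usual chunking pattern, what the generated Python ends up doing for one region is roughly the sketch below; the names, the slicing details and the placeholder region_forward are invented for illustration.

    import torch

    def region_forward(x_chunk: torch.Tensor) -> torch.Tensor:
        # stand-in for the traced sub-graph of one chunk region
        return x_chunk * 2.0

    def chunked_region(x: torch.Tensor, chunk_size: int, chunk_dim: int = 1) -> torch.Tensor:
        out = None
        for start in range(0, x.size(chunk_dim), chunk_size):
            length = min(chunk_size, x.size(chunk_dim) - start)
            x_chunk = x.narrow(chunk_dim, start, length)          # slice the chunked input
            y_chunk = region_forward(x_chunk)                     # run the region on one chunk
            if out is None:
                shape = list(y_chunk.shape)
                shape[chunk_dim] = x.size(chunk_dim)
                out = y_chunk.new_empty(shape)                    # allocate the full output once
            out.narrow(chunk_dim, start, length).copy_(y_chunk)   # write this chunk back
        return out

The real generator also has to handle non-chunk inputs, reshape sizes and the ones_like special case dealt with in the next hunk, which is why chunk_infos carries that extra bookkeeping.
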
@@ -2041,14 +2035,14 @@ def emit_code_with_chunk(
 # ones like
 if "ones_like" in node.name:
 meta_node = chunk_region_search.index_tracer.node_list[node_idx]
-chunk_dim = chunk_search[region_idx]["node_chunk_dim"][meta_node][
+chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][
 "chunk_dim"
 ]
 if _get_node_shape(meta_node)[chunk_dim] != 1:
 source_node = meta_node.args[0].args[0]
 if (
-source_node not in chunk_search[region_idx]["node_chunk_dim"]
-or chunk_search[region_idx]["node_chunk_dim"][source_node][
+source_node not in chunk_infos[region_idx]["node_chunk_dim"]
+or chunk_infos[region_idx]["node_chunk_dim"][source_node][
 "chunk_dim"
 ]
 is None
@@ -2060,7 +2054,7 @@ def emit_code_with_chunk(
 body[-1], node.args[0].name, node.args[0].name + chunk_slice
 )
 body[-1] = _replace_reshape_size(
-body[-1], node.name, chunk_search[region_idx]["reshape_size"]
+body[-1], node.name, chunk_infos[region_idx]["reshape_size"]
 )
 body[-1] = " " + body[-1]
 delete_unused_value_func(node, body, chunk_inputs_names)
@@ -2092,6 +2086,9 @@ if CODEGEN_AVAILABLE:
 self.meta_graph = meta_graph
 self.max_memory = max_memory
 self.meta_node = list(meta_graph.graph.nodes)
+# find the chunk regions
+self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory)
+self.chunk_infos = self.chunk_region_search.search_region()
 def _gen_python_code(
 self, nodes, root_module: str, namespace: _Namespace
@@ -2323,13 +2320,11 @@ if CODEGEN_AVAILABLE:
 # will use nested type of activation checkpoint codegen
 emit_code_with_chunk(
 body,
-ckpt_func,
 nodes,
 emit_node,
 delete_unused_values,
-self.meta_node,
-self.meta_graph,
-self.max_memory,
+self.chunk_region_search,
+self.chunk_infos
 )
 if len(body) == 0:

View File

@@ -3,13 +3,13 @@ import time
 import torch
 import torch.fx
-from chunk_codegen import ChunkCodeGen
+from autochunk.chunk_codegen import ChunkCodeGen
 from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.profiler import MetaTensor
-from evoformer.evoformer import evoformer_base
-from openfold.evoformer import EvoformerBlock
+from autochunk.evoformer.evoformer import evoformer_base
+from autochunk.openfold.evoformer import EvoformerBlock
 def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
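
The import rewrites in this file (and in the test file further down) are the "structure" part of the refactor: the modules now live inside an autochunk package. Read purely from the import paths touched in this commit, the layout is roughly:

    autochunk/
        chunk_codegen.py      # ChunkCodeGen (plus ChunkRegionSearch and emit_code_with_chunk, per the first file above)
        evoformer/
            evoformer.py      # evoformer_base
        openfold/
            evoformer.py      # EvoformerBlock
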
@@ -94,23 +94,23 @@ def _build_openfold():
 def benchmark_evoformer():
 # init data and model
 msa_len = 256
-pair_len = 2048
+pair_len = 1024
 node = torch.randn(1, msa_len, pair_len, 256).cuda()
 pair = torch.randn(1, pair_len, pair_len, 128).cuda()
 model = evoformer_base().cuda()
 # build autochunk model
-max_memory = 10000 # MB fit memory mode
-# max_memory = None # min memory mode
+# max_memory = 10000 # MB fit memory mode
+max_memory = None # min memory mode
 autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)
 # build openfold
 chunk_size = 64
-openfold = _build_openfold()
+# openfold = _build_openfold()
 # benchmark
-_benchmark_evoformer(model, node, pair, "base")
-_benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
+# _benchmark_evoformer(model, node, pair, "base")
+# _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
 _benchmark_evoformer(autochunk, node, pair, "autochunk")
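
With this change the benchmark defaults to the min-memory search and comments out the base and OpenFold baselines. For comparing both autochunk modes instead, a hypothetical driver could look like the following; it reuses this file's imports and helpers with the signatures visible in this diff, _build_autochunk(model, max_memory, node, pair) and _benchmark_evoformer(model, node, pair, title, chunk_size=None).

    # hypothetical comparison run, inside this benchmark module
    msa_len, pair_len = 256, 1024
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()

    # "fit memory" mode: the chunk search targets an explicit budget in MB
    autochunk_fit = _build_autochunk(evoformer_base().cuda(), 10000, node, pair)
    # "min memory" mode: max_memory=None lets the search minimize peak memory
    autochunk_min = _build_autochunk(evoformer_base().cuda(), None, node, pair)

    _benchmark_evoformer(autochunk_fit, node, pair, "autochunk fit-memory")
    _benchmark_evoformer(autochunk_min, node, pair, "autochunk min-memory")
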

View File

@@ -12,8 +12,8 @@ from colossalai.core import global_context as gpc
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
 from colossalai.fx.profiler import MetaTensor
-from evoformer.evoformer import evoformer_base
-from chunk_codegen import ChunkCodeGen
+from autochunk.evoformer.evoformer import evoformer_base
+from autochunk.chunk_codegen import ChunkCodeGen
 with_codegen = True
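
For context on how these imports fit together: ChunkCodeGen plugs into torch.fx code generation on a traced graph. The sketch below shows the assumed wiring; the ChunkCodeGen constructor arguments are inferred from the __init__ shown in the first file, the meta_args names follow the evoformer inputs used in the benchmark, and the meta-shape propagation step (MetaInfoProp / MetaTensor) is only indicated in a comment because its exact call is not part of this diff.

    import torch
    from colossalai.fx import ColoTracer
    from colossalai.fx.graph_module import ColoGraphModule
    from autochunk.chunk_codegen import ChunkCodeGen

    def build_autochunk_module(model: torch.nn.Module, max_memory, node, pair):
        # trace with meta tensors so shapes are known without allocating GPU memory
        graph = ColoTracer().trace(
            model,
            meta_args={"node": node.to(torch.device("meta")),
                       "pair": pair.to(torch.device("meta"))},
        )
        # assumption: ChunkCodeGen expects a graph module whose nodes already carry
        # meta shape information (e.g. annotated via MetaInfoProp with MetaTensor inputs)
        meta_graph = torch.fx.symbolic_trace(model)
        # inferred constructor arguments: ChunkCodeGen(meta_graph, max_memory)
        graph.set_codegen(ChunkCodeGen(meta_graph, max_memory))
        gm = ColoGraphModule(model, graph)
        gm.recompile()
        return gm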