mirror of https://github.com/hpcaitech/ColossalAI

refactor structure

parent 71e72c4890
commit 27ab524096
@@ -1967,13 +1967,11 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
 def emit_code_with_chunk(
     body,
     ckpt_func,
     nodes,
     emit_node_func,
     delete_unused_value_func,
     meta_nodes,
     meta_graph,
-    max_memory=None,
+    chunk_region_search,
+    chunk_infos
 ):
     """Emit code with nested activation checkpoint
     When we detect some of the node.activation_checkpoint is a List, we will use
@@ -1988,23 +1986,19 @@ def emit_code_with_chunk(
     """
     node_list = list(nodes)

-    # find the chunk regions
-    chunk_region_search = ChunkRegionSearch(meta_graph, max_memory)
-    chunk_search = chunk_region_search.search_region()
-
-    chunk_regions = [i["region"] for i in chunk_search]
+    chunk_regions = [i["region"] for i in chunk_infos]
     chunk_starts = [i[0] for i in chunk_regions]
     chunk_ends = [i[1] for i in chunk_regions]

-    chunk_inputs = [i["inputs"] for i in chunk_search]
-    chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_search]
-    chunk_inputs_dim = [i["inputs_dim"] for i in chunk_search]
+    chunk_inputs = [i["inputs"] for i in chunk_infos]
+    chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
+    chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]
     chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
         j.name for i in chunk_inputs_non_chunk for j in i
     ]

-    chunk_outputs = [i["outputs"][0] for i in chunk_search]
-    chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search]
+    chunk_outputs = [i["outputs"][0] for i in chunk_infos]
+    chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]

     node_list = chunk_region_search.index_tracer.reorder_node_list(node_list)
     node_idx = 0
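The comprehensions above read the per-region fields "region", "inputs", "inputs_non_chunk", "inputs_dim", "outputs" and "outputs_dim" from chunk_infos; later hunks also read "chunk_size", "node_chunk_dim" and "reshape_size". A minimal sketch of the assumed layout of one entry is given below; key names come from the diff, the values are made up, and the real entries hold torch.fx.Node objects where strings appear here:

# Hypothetical example of one entry of chunk_infos, the list returned by
# ChunkRegionSearch.search_region(). Illustrative values only.
chunk_info = {
    "region": (12, 47),              # (start_idx, end_idx) of the region in node_list
    "inputs": ["node_a", "node_b"],  # inputs sliced along a chunk dimension
    "inputs_non_chunk": ["node_c"],  # inputs passed through whole, without slicing
    "inputs_dim": [2, 2],            # chunked dimension of each sliced input
    "outputs": ["node_out"],         # region outputs; the codegen above uses outputs[0]
    "outputs_dim": 2,                # dimension along which chunked outputs are reassembled
    "chunk_size": 64,                # slice size for each iteration of the generated loop
    "node_chunk_dim": {"node_a": {"chunk_dim": 2}},  # per-node chunk-dimension metadata
    "reshape_size": {},              # replacement sizes for reshape/view calls
}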
@@ -2022,7 +2016,7 @@ def emit_code_with_chunk(
                     chunk_inputs[region_idx],
                     chunk_outputs[region_idx],
                     chunk_outputs_dim[region_idx],
-                    chunk_search[region_idx]["chunk_size"],
+                    chunk_infos[region_idx]["chunk_size"],
                 )
             )
@@ -2041,14 +2035,14 @@ def emit_code_with_chunk(
             # ones like
             if "ones_like" in node.name:
                 meta_node = chunk_region_search.index_tracer.node_list[node_idx]
-                chunk_dim = chunk_search[region_idx]["node_chunk_dim"][meta_node][
+                chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][
                     "chunk_dim"
                 ]
                 if _get_node_shape(meta_node)[chunk_dim] != 1:
                     source_node = meta_node.args[0].args[0]
                     if (
-                        source_node not in chunk_search[region_idx]["node_chunk_dim"]
-                        or chunk_search[region_idx]["node_chunk_dim"][source_node][
+                        source_node not in chunk_infos[region_idx]["node_chunk_dim"]
+                        or chunk_infos[region_idx]["node_chunk_dim"][source_node][
                             "chunk_dim"
                         ]
                         is None
@@ -2060,7 +2054,7 @@ def emit_code_with_chunk(
                 body[-1], node.args[0].name, node.args[0].name + chunk_slice
             )
             body[-1] = _replace_reshape_size(
-                body[-1], node.name, chunk_search[region_idx]["reshape_size"]
+                body[-1], node.name, chunk_infos[region_idx]["reshape_size"]
             )
             body[-1] = " " + body[-1]
             delete_unused_value_func(node, body, chunk_inputs_names)
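The re-indentation of body[-1] above is what nests the emitted statements inside a generated chunk loop. The snippet below is only an illustration of the shape of code this pass aims to produce, not the literal output of this commit; the tensor names, the softmax body and the chunk dimension are invented:

# Illustrative only: hand-written equivalent of a chunked region.
import torch

def chunked_region(inp: torch.Tensor, chunk_size: int = 64) -> torch.Tensor:
    # Preallocate the region output, then fill it slice by slice so that only
    # one chunk's worth of intermediate activations is alive at a time.
    out = torch.empty_like(inp)
    dim_len = inp.shape[2]                   # assumed chunk dimension
    for start in range(0, dim_len, chunk_size):
        end = min(start + chunk_size, dim_len)
        part = inp[:, :, start:end]          # chunk_slice applied to a chunked input
        part = torch.softmax(part, dim=-1)   # stand-in for the real region body
        out[:, :, start:end] = part          # write the chunk back into the output
    return out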
@@ -2092,6 +2086,9 @@ if CODEGEN_AVAILABLE:
             self.meta_graph = meta_graph
             self.max_memory = max_memory
             self.meta_node = list(meta_graph.graph.nodes)
+            # find the chunk regions
+            self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory)
+            self.chunk_infos = self.chunk_region_search.search_region()

         def _gen_python_code(
             self, nodes, root_module: str, namespace: _Namespace
@@ -2323,13 +2320,11 @@ if CODEGEN_AVAILABLE:
             # will use nested type of activation checkpoint codegen
             emit_code_with_chunk(
                 body,
                 ckpt_func,
                 nodes,
                 emit_node,
                 delete_unused_values,
                 self.meta_node,
                 self.meta_graph,
-                self.max_memory,
+                self.chunk_region_search,
+                self.chunk_infos
             )

             if len(body) == 0:
@@ -3,13 +3,13 @@ import time
 import torch
 import torch.fx

-from chunk_codegen import ChunkCodeGen
+from autochunk.chunk_codegen import ChunkCodeGen
 from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.profiler import MetaTensor
-from evoformer.evoformer import evoformer_base
-from openfold.evoformer import EvoformerBlock
+from autochunk.evoformer.evoformer import evoformer_base
+from autochunk.openfold.evoformer import EvoformerBlock


 def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
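The benchmark's _build_autochunk helper is not shown in this diff. Below is a hedged sketch of what such a helper might do, pieced together from the imports above; the helper name build_autochunk_module, the meta_args keys and the ChunkCodeGen(meta_graph, max_memory) constructor are assumptions, and exact colossalai.fx call signatures may differ between versions:

# Hypothetical sketch only -- NOT the repository's _build_autochunk.
import torch
import torch.fx

from autochunk.chunk_codegen import ChunkCodeGen
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor


def build_autochunk_module(model, max_memory, node, pair):
    # Trace the model with meta tensors so no real activation memory is allocated.
    graph = ColoTracer().trace(
        model,
        meta_args={"node": node.to("meta"), "pair": pair.to("meta")},
    )

    # Propagate tensor metadata; the chunk-region search relies on it.
    meta_graph = torch.fx.GraphModule(model, graph)
    MetaInfoProp(meta_graph).run(MetaTensor(node), MetaTensor(pair))

    # Attach the chunk-aware code generator and rebuild the module with it.
    graph.set_codegen(ChunkCodeGen(meta_graph, max_memory))
    return ColoGraphModule(model, graph)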
@@ -94,23 +94,23 @@ def _build_openfold():
 def benchmark_evoformer():
     # init data and model
     msa_len = 256
-    pair_len = 2048
+    pair_len = 1024
     node = torch.randn(1, msa_len, pair_len, 256).cuda()
     pair = torch.randn(1, pair_len, pair_len, 128).cuda()
     model = evoformer_base().cuda()

     # build autochunk model
-    max_memory = 10000 # MB fit memory mode
-    # max_memory = None # min memory mode
+    # max_memory = 10000 # MB fit memory mode
+    max_memory = None # min memory mode
     autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)

     # build openfold
     chunk_size = 64
-    openfold = _build_openfold()
+    # openfold = _build_openfold()

     # benchmark
-    _benchmark_evoformer(model, node, pair, "base")
-    _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
+    # _benchmark_evoformer(model, node, pair, "base")
+    # _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
     _benchmark_evoformer(autochunk, node, pair, "autochunk")
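_benchmark_evoformer itself is also outside this diff. The following is a minimal sketch of a timing loop consistent with its call signature (model, node, pair, title, chunk_size=None), assuming warm-up iterations and CUDA synchronization around the timed region; the warm-up and repeat counts are invented:

# Sketch of a timing helper; the real _benchmark_evoformer is not part of this diff.
import time

import torch


def benchmark_forward(model, node, pair, title, chunk_size=None, warmup=3, repeats=10):
    kwargs = {"chunk_size": chunk_size} if chunk_size is not None else {}
    with torch.no_grad():
        for _ in range(warmup):            # warm up kernels and the allocator
            model(node, pair, **kwargs)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(repeats):
            model(node, pair, **kwargs)
        torch.cuda.synchronize()
    elapsed_ms = (time.time() - start) / repeats * 1000
    print(f"{title}: {elapsed_ms:.2f} ms per forward")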
@@ -12,8 +12,8 @@ from colossalai.core import global_context as gpc
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
 from colossalai.fx.profiler import MetaTensor
-from evoformer.evoformer import evoformer_base
-from chunk_codegen import ChunkCodeGen
+from autochunk.evoformer.evoformer import evoformer_base
+from autochunk.chunk_codegen import ChunkCodeGen

 with_codegen = True
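The test file's imports (ColoGraphModule, MetaInfoProp, MetaTensor, ChunkCodeGen, evoformer_base) point to an output-equivalence test. A hedged sketch of the kind of check such a test presumably performs follows; the helper name and tolerances are invented, and the repository's actual test may be structured differently:

# Sketch of an output-equivalence check between the original model and the
# autochunk-generated module; illustrative only.
import torch


def check_autochunk_equivalence(model, autochunk_model, node, pair, atol=1e-4, rtol=1e-4):
    with torch.no_grad():
        expected = model(node, pair)
        actual = autochunk_model(node, pair)
    if not isinstance(expected, (tuple, list)):
        expected, actual = (expected,), (actual,)
    # Chunked execution should match the unchunked forward up to floating-point noise.
    for e, a in zip(expected, actual):
        assert torch.allclose(e, a, atol=atol, rtol=rtol)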