[autochunk] support parsing blocks (#2506)

pull/2509/head
oahzxl 2023-01-20 11:18:17 +08:00 committed by GitHub
parent 35c0c0006e
commit c04f183237
7 changed files with 314 additions and 22 deletions


@@ -22,7 +22,7 @@ if CODEGEN_AVAILABLE:
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
from .search_chunk import SearchChunk
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_logger, get_node_shape
def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
@@ -276,11 +276,17 @@ if CODEGEN_AVAILABLE:
class AutoChunkCodeGen(CodeGen):
def __init__(self, meta_graph, max_memory=None, print_mem=False):
def __init__(self,
meta_graph,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False) -> None:
super().__init__()
# find the chunk regions
self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem)
self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress)
self.chunk_infos = self.search_chunk.search_region()
if print_progress:
get_logger().info("AutoChunk start codegen")
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []

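A minimal usage sketch of the constructor extended above, assuming a `meta_graph` that has already been traced with meta args and annotated by MetaInfoProp (the same setup as the test added at the bottom of this commit, which is also where the import path comes from):

from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen

# Hedged sketch, not part of this commit: build the codegen with the new
# print_progress flag so the search and codegen stages log their progress.
codegen = AutoChunkCodeGen(
    meta_graph,           # fx graph traced with meta args (obtained elsewhere)
    max_memory=None,      # optional memory budget; None lets the search decide
    print_mem=False,      # keep the estimated-memory printout off
    print_progress=True,  # new in this commit: emit "AutoChunk ..." log lines
)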

@@ -43,6 +43,8 @@ class EstimateMemory(object):
delete_node = []
if user.op not in ("output",):
nodes_to_delete = user_to_last_uses.get(user, [])
if len(user.users) == 0:
nodes_to_delete.append(user)
if to_keep is not None:
keep_list = []
for n in nodes_to_delete:
@@ -135,6 +137,8 @@ class EstimateMemory(object):
if user.op in ("placeholder", "output"):
return 0
nodes_to_delete = user_to_last_uses.get(user, [])
if len(user.users) == 0:
nodes_to_delete.append(user)
delete_size = 0
for n in nodes_to_delete:
if n.name in chunk_inputs_names:
@@ -294,3 +298,26 @@ class EstimateMemory(object):
# param_memory = parameter_size(gm)
# all_memory = act_memory + param_memory
return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
def get_active_nodes(self, node_list: List) -> List:
"""
Get the active nodes after every node
Args:
node_list (List): list of fx graph nodes
Returns:
active_node_list_log (List): active nodes after every node. Active nodes refer to
nodes that have been generated but not yet deleted.
"""
active_node_list = []
active_node_list_log = []
user_to_last_uses = self._get_last_usr(node_list)
user_to_last_uses_no_free_var = self._get_last_usr(node_list)
delete_free_var_from_last_use(user_to_last_uses_no_free_var)
for _, node in enumerate(node_list):
# log active node, only effective without chunk
self._add_active_node(node, active_node_list)
self._remove_deactive_node(node, user_to_last_uses, active_node_list)
active_node_list_log.append(copy.deepcopy(active_node_list))
return active_node_list_log

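A plain-Python sketch of the idea behind get_active_nodes above (a conceptual stand-in with a toy op list, not the fx-based implementation): after each node runs, the log records which values have been produced but have not yet reached their last use.

# Conceptual stand-in for get_active_nodes, assuming a toy op list instead of fx nodes.
ops = [
    ("a", []),        # a = placeholder()
    ("b", ["a"]),     # b = f(a)   -> last use of a
    ("c", ["b"]),     # c = g(b)   -> last use of b
    ("out", ["c"]),   # out = c    -> last use of c
]

last_use = {}
for name, deps in ops:
    for d in deps:
        last_use[d] = name            # remember the final consumer of each value

active, log = [], []
for name, _ in ops:
    active.append(name)               # the node's output becomes active
    active = [v for v in active if last_use.get(v) != name]   # drop values at their last use
    log.append(list(active))

print(log)   # [['a'], ['b'], ['c'], ['out']] -- one "active set" per node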

@@ -8,7 +8,7 @@ from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
from .utils import get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
class SearchChunk(object):
@@ -40,14 +40,14 @@ class SearchChunk(object):
print_mem (bool): print estimated memory
"""
def __init__(self, gm, max_memory=None, print_mem=False) -> None:
self.gm = gm
def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
self.print_mem = print_mem
self.print_progress = print_progress
self.trace_indice = TraceIndice(list(gm.graph.nodes))
self.trace_indice.trace_indice()
self.estimate_memory = EstimateMemory()
self._init_trace()
self.trace_flow = TraceFlow(self.trace_indice)
self.reorder_graph = ReorderGraph(self.trace_indice)
self.estimate_memory = EstimateMemory()
self.select_chunk = SelectChunk(
self.trace_indice,
self.estimate_memory,
@@ -55,7 +55,33 @@ class SearchChunk(object):
max_memory=max_memory,
)
def _find_peak_node(self, mem_peak):
def _init_trace(self) -> None:
"""
Find the max trace range for every node to reduce the computation complexity of trace_indice.
"""
# find all max ranges
active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list)
cur_node_idx = len(self._get_free_var_idx())
max_chunk_region_list = []
while True:
max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
cur_node_idx = max_chunk_region[1]
if cur_node_idx == len(active_nodes) - 1:
break
max_chunk_region_list.append(max_chunk_region)
# nothing to limit for the first range
max_chunk_region_list = max_chunk_region_list[1:]
max_chunk_region_list[0] = (0, max_chunk_region_list[0][1])
# set trace range and do the trace
if self.print_progress:
get_logger().info("AutoChunk start tracing indice")
self.trace_indice.set_trace_range(max_chunk_region_list, active_nodes)
self.trace_indice.trace_indice()
def _find_peak_node(self, mem_peak: List) -> int:
max_value = max(mem_peak)
max_idx = mem_peak.index(max_value)
return max_idx
@@ -73,7 +99,7 @@ class SearchChunk(object):
free_var_idx.append(idx)
return free_var_idx
def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_regions: List = None) -> Tuple:
"""
Search max chunk region according to peak memory node
@@ -81,7 +107,7 @@
Args:
active_node (List): active node status for every node
peak_node (Node): peak memory node
peak_node_idx (int): peak memory node idx
chunk_regions (List): chunk region infos
Returns:
@@ -97,7 +123,7 @@
# from peak_node to free_var
inside_flag = False
chunk_region_start = free_var_num
for i in range(peak_node, -1, -1):
for i in range(peak_node_idx, -1, -1):
if active_node_num[i] <= threshold:
inside_flag = True
if inside_flag and active_node_num[i] > threshold:
@@ -107,21 +133,23 @@
# from peak_node to len-2
inside_flag = False
chunk_region_end = len(active_node) - 1
for i in range(peak_node, len(active_node)):
for i in range(peak_node_idx, len(active_node)):
if active_node_num[i] <= threshold:
inside_flag = True
if inside_flag and active_node_num[i] > threshold:
chunk_region_end = i
break
for i in chunk_regions:
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
chunk_region_start = region[1] + 1
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
chunk_region_end = region[0] - 1
# avoid overlapping chunk regions
if chunk_regions is not None:
for i in chunk_regions:
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
chunk_region_start = region[1] + 1
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
@@ -154,6 +182,9 @@
# dim size cannot be 1
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
continue
# must have users
if len(end_node.users) == 0:
continue
# check index source align
if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
continue
@@ -253,6 +284,9 @@
Returns:
chunk_infos (Dict)
"""
if self.print_progress:
get_logger().info("AutoChunk start searching chunk regions")
chunk_infos = []
(
init_mem_peak,
@@ -272,6 +306,11 @@
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
if self.print_progress:
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
(len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
if self._stop_search(init_mem_peak, mem_peak):
break
if self.print_mem:

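A self-contained sketch of the threshold walk that _search_max_chunk_region performs on the per-node active counts: starting from the peak-memory node, expand outwards while the count stays at or below a threshold, and stop once it rises above it again. The numbers and the threshold below are made up, and the left-hand assignment (cut off in the hunk above) is assumed to be i + 1:

# Hedged illustration of the region walk; not the SearchChunk class itself.
active_node_num = [1, 4, 2, 2, 6, 2, 2, 5, 1]   # live tensors after each node (made up)
peak_node_idx = 4
threshold = 2                                    # hypothetical; computed elsewhere in the class

# from peak_node towards the inputs
chunk_region_start, inside_flag = 0, False
for i in range(peak_node_idx, -1, -1):
    if active_node_num[i] <= threshold:
        inside_flag = True
    if inside_flag and active_node_num[i] > threshold:
        chunk_region_start = i + 1               # assumed; the assignment is not shown in the hunk
        break

# from peak_node towards the outputs (mirrors the loop shown in the diff)
chunk_region_end, inside_flag = len(active_node_num) - 1, False
for i in range(peak_node_idx, len(active_node_num)):
    if active_node_num[i] <= threshold:
        inside_flag = True
    if inside_flag and active_node_num[i] > threshold:
        chunk_region_end = i
        break

print(chunk_region_start, chunk_region_end)      # 2 7 for the numbers above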

@@ -281,7 +281,10 @@ class TraceFlow(object):
if chunk_dim is not None:
user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
if input_node_idx in user_source:
input_dict[user_idx] = user_source[input_node_idx]
if get_node_shape(input_node)[user_source[input_node_idx][0]] == 1:
input_dict[user_idx] = [None]
else:
input_dict[user_idx] = user_source[input_node_idx]
else:
return None, None
if len(input_dict) == 0:

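The size-1 branch added to check_index_source above presumably covers broadcasting: when an input's dimension has size 1, the matching dimension of the user is not actually indexed through that input, so the source is recorded as [None] rather than a real chunk dimension. A small torch illustration (hypothetical tensors, not from this repo):

import torch

a = torch.randn(1, 8)    # dim 0 has size 1 and is broadcast
b = torch.randn(4, 8)
c = a + b                # dim 0 of c (size 4) is driven by b, not a
print(c.shape)           # torch.Size([4, 8]); chunking c along dim 0 never slices `a`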

@@ -33,6 +33,8 @@ class TraceIndice(object):
self.indice_trace_list = self._init_indice_trace_list()
self.indice_view_list = {}
self.indice_count = -1
self.trace_range = []
self.active_node_list = []
def _init_indice_trace_list(self):
indice_trace_list = []
@@ -48,6 +50,10 @@
indice_trace_list.append(cur_trace)
return indice_trace_list
def set_trace_range(self, trace_range: List, active_node_list: List) -> None:
self.trace_range = trace_range
self.active_node_list = active_node_list
def _add_indice(self):
"""
Update the count and return it. To record the idx number.
@@ -493,6 +499,9 @@
new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
for _ in range(new_dim_num):
self._del_dim(node_idx, 0)
delete_dim_num = sum([1 if str(i) == "0" else 0 for i in node_args])
for _ in range(delete_dim_num):
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx)
for _, node_arg in enumerate(node_args):
@@ -513,6 +522,9 @@
elif "None" == node_arg_str:
self._add_dim(node_idx, new_idx_count)
new_idx_count += 1
elif "0" == node_arg_str:
self._del_dim(node_idx, new_idx_count)
origin_idx_count += 1
else:
raise NotImplementedError()
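The new "0" branches above appear to handle integer indices inside getitem arguments, which drop a dimension from the output (the mirror image of a None, which inserts one). A short torch illustration of both cases, using a hypothetical tensor:

import torch

x = torch.randn(2, 3, 4)
print(x[0].shape)           # torch.Size([3, 4])       -- integer index deletes a dim
print(x[None, :].shape)     # torch.Size([1, 2, 3, 4]) -- None inserts a dim
print(x[0, :, None].shape)  # torch.Size([3, 1, 4])    -- both at once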
@@ -596,6 +608,37 @@
}
self.indice_view_list[node] = view_dict
def _clear_trace(self, node_idx: int) -> None:
"""
Clear traces that are too far away (outside the node's max trace range) to speed up computation.
"""
trace_range = None
for i in range(len(self.trace_range)):
if self.trace_range[i][1] == node_idx:
trace_range = (self.trace_range[i][0], self.trace_range[i][1])
break
if self.trace_range[i][1] > node_idx:
break
if trace_range is None:
return
active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
active_nodes = set(flat_list(active_nodes))
active_nodes = [find_idx_by_name(i, self.node_list) for i in active_nodes]
for i in range(trace_range[0], trace_range[1] + 1):
trace = self.indice_trace_list[i]
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
if dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes:
dim_compute.pop(i)
continue
# clear source
for dim_source in trace["source"]:
for k in list(dim_source.keys()):
if k < trace_range[0] and k not in active_nodes:
dim_source.pop(k)
def trace_indice(self):
for idx, node in enumerate(self.node_list):
if node.op == "placeholder":
@@ -655,3 +698,6 @@
continue
else:
raise NotImplementedError(node.op, "op not implemented yet!")
# limit trace range
self._clear_trace(idx)


@@ -2,6 +2,14 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
from torch.fx.node import Node
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
def get_logger():
return logger
def flat_list(inputs: Any) -> List:
"""

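flat_list (whose body is truncated in the hunk above) is used by _clear_trace to collapse the nested per-node active lists into one flat list of names. A hypothetical stand-in, only to illustrate the behavior assumed there:

def flat_list_sketch(inputs):
    # recursively flatten nested lists/tuples/sets into a single flat list
    if not isinstance(inputs, (list, tuple, set)):
        return [inputs]
    out = []
    for item in inputs:
        out.extend(flat_list_sketch(item))
    return out

print(flat_list_sketch([["a", "b"], ["c", ["d"]]]))   # ['a', 'b', 'c', 'd']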

@@ -0,0 +1,163 @@
from functools import partial
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import EvoformerStack
HAS_REPO = True
except:
HAS_REPO = False
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if CODEGEN_AVAILABLE and is_compatible_with_meta():
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
# for memory test
# model = model.cuda()
# torch.cuda.reset_peak_memory_stats()
# now_mem = torch.cuda.memory_allocated() / 1024**2
# with torch.no_grad():
# node1 = node.clone()
# pair1 = pair.clone()
# node_mask1 = node_mask.clone()
# pair_mask1 = pair_mask.clone()
# gm(node1, pair1, node_mask1, pair_mask1, None)
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
# test forward
model = model.cuda()
with torch.no_grad():
non_fx_out = model(node, pair, node_mask, pair_mask, None)
fx_out = gm(node, pair, node_mask, pair_mask, None)
assert torch.allclose(non_fx_out[0], fx_out[0],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[0] - fx_out[0]))
assert torch.allclose(non_fx_out[1], fx_out[1],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[1] - fx_out[1]))
def _build_openfold():
model = EvoformerStack(
c_m=256,
c_z=128,
c_hidden_msa_att=32,
c_hidden_opm=32,
c_hidden_mul=128,
c_hidden_pair_att=32,
c_s=384,
no_heads_msa=8,
no_heads_pair=4,
no_blocks=2, # 48
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.25,
blocks_per_ckpt=None,
inf=1000000000.0,
eps=1e-08,
clear_cache_between_blocks=False,
is_multimer=False,
).eval().cuda()
return model
def _test_evoformer_stack_codegen(rank, msa_len, pair_len, max_memory):
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
model = _build_openfold()
node = torch.randn(1, msa_len, pair_len, 256).cuda()
node_mask = torch.randn(1, msa_len, pair_len).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
# trace the meta graph and setup codegen
meta_graph = symbolic_trace(
model,
meta_args={
"m": node.to(torch.device("meta")),
"z": pair.to(torch.device("meta")),
"msa_mask": node_mask.to(torch.device("meta")),
"pair_mask": pair_mask.to(torch.device("meta")),
},
concrete_args={
"chunk_size": None,
"_mask_trans": True,
},
)
interp = MetaInfoProp(meta_graph)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"),
MetaTensor(node_mask, fake_device="cuda:0"), MetaTensor(pair_mask, fake_device="cuda:0"), None)
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False, print_progress=False)
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph = ColoTracer().trace(
model,
meta_args={
"m": node.to(torch.device("meta")),
"z": pair.to(torch.device("meta")),
"msa_mask": node_mask.to(torch.device("meta")),
"pair_mask": pair_mask.to(torch.device("meta")),
},
concrete_args={
"chunk_size": None,
"_mask_trans": True,
},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
# assert that chunk code has been inserted
code = graph.python_code("self").src
# print(code)
assert "chunk_result = None; chunk_size = None;" in code
_test_fwd(model, gm, node, pair, node_mask, pair_mask)
gpc.destroy()
@pytest.mark.skipif(
not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_stack_codegen(msa_len, pair_len, max_memory):
run_func = partial(
_test_evoformer_stack_codegen,
msa_len=msa_len,
pair_len=pair_len,
max_memory=max_memory,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
_test_evoformer_stack_codegen(0, 32, 64, None)