mirror of https://github.com/hpcaitech/ColossalAI
[autochunk] support parsing blocks (#2506)
parent 35c0c0006e
commit c04f183237
@@ -22,7 +22,7 @@ if CODEGEN_AVAILABLE:
 from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg

 from .search_chunk import SearchChunk
-from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
+from .utils import delete_free_var_from_last_use, find_idx_by_name, get_logger, get_node_shape


 def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
@@ -276,11 +276,17 @@ if CODEGEN_AVAILABLE:

     class AutoChunkCodeGen(CodeGen):

-        def __init__(self, meta_graph, max_memory=None, print_mem=False):
+        def __init__(self,
+                     meta_graph,
+                     max_memory: int = None,
+                     print_mem: bool = False,
+                     print_progress: bool = False) -> None:
             super().__init__()
             # find the chunk regions
-            self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem)
+            self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress)
             self.chunk_infos = self.search_chunk.search_region()
+            if print_progress:
+                get_logger().info("AutoChunk start codegen")

         def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
             free_vars: List[str] = []
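The two hunks above are the user-facing change: AutoChunkCodeGen gains a print_progress flag, forwards it to SearchChunk, and reports each stage through get_logger(). A minimal sketch of the new keyword in context; meta_graph, graph and model are assumed to be prepared exactly as in the new test at the bottom of this diff, and the comments on max_memory describe how the search uses it rather than documented API guarantees:

    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.graph_module import ColoGraphModule

    # meta_graph / graph / model: traced and annotated as in the test file below (not built here)
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=None,        # memory budget in MB; None lets the search minimize memory on its own
        print_mem=False,        # print the estimated activation-memory curve
        print_progress=True,    # new in this commit: log tracing / searching / codegen stages
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()              # the regenerated forward runs the found regions chunk by chunk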
@@ -43,6 +43,8 @@ class EstimateMemory(object):
         delete_node = []
         if user.op not in ("output",):
             nodes_to_delete = user_to_last_uses.get(user, [])
+            if len(user.users) == 0:
+                nodes_to_delete.append(user)
             if to_keep is not None:
                 keep_list = []
                 for n in nodes_to_delete:
@@ -135,6 +137,8 @@ class EstimateMemory(object):
         if user.op in ("placeholder", "output"):
             return 0
         nodes_to_delete = user_to_last_uses.get(user, [])
+        if len(user.users) == 0:
+            nodes_to_delete.append(user)
         delete_size = 0
         for n in nodes_to_delete:
             if n.name in chunk_inputs_names:
@@ -294,3 +298,26 @@ class EstimateMemory(object):
         # param_memory = parameter_size(gm)
         # all_memory = act_memory + param_memory
         return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
+
+    def get_active_nodes(self, node_list: List) -> List:
+        """
+        Get active nodes for every node
+
+        Args:
+            node_list (List): _description_
+
+        Returns:
+            active_node_list_log (List): active nodes of every node. active nodes refer to
+                nodes generated but not deleted.
+        """
+        active_node_list = []
+        active_node_list_log = []
+        user_to_last_uses = self._get_last_usr(node_list)
+        user_to_last_uses_no_free_var = self._get_last_usr(node_list)
+        delete_free_var_from_last_use(user_to_last_uses_no_free_var)
+        for _, node in enumerate(node_list):
+            # log active node, only effective without chunk
+            self._add_active_node(node, active_node_list)
+            self._remove_deactive_node(node, user_to_last_uses, active_node_list)
+            active_node_list_log.append(copy.deepcopy(active_node_list))
+        return active_node_list_log
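get_active_nodes walks the node list once and, after every node, records which previously produced values are still alive, i.e. produced but not yet past their last use. A self-contained toy of the same bookkeeping on plain strings instead of torch.fx nodes (all names here are illustrative, not part of the ColossalAI API):

    import copy
    from typing import Dict, List

    def get_active_values(values: List[str], last_use: Dict[str, str]) -> List[List[str]]:
        """For each step, log the values that have been produced but are not yet dead.

        last_use[v] names the step after which v is no longer needed.
        """
        active: List[str] = []
        log: List[List[str]] = []
        for v in values:
            active.append(v)    # the value is produced at this step
            # drop every value whose last use is this step
            active = [a for a in active if last_use.get(a) != v]
            log.append(copy.deepcopy(active))
        return log

    # c is the last use of a; d is the last use of b and c
    print(get_active_values(["a", "b", "c", "d"], {"a": "c", "b": "d", "c": "d"}))
    # [['a'], ['a', 'b'], ['b', 'c'], ['d']]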
@@ -8,7 +8,7 @@ from .reorder_graph import ReorderGraph
 from .select_chunk import SelectChunk
 from .trace_flow import TraceFlow
 from .trace_indice import TraceIndice
-from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
+from .utils import get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder


 class SearchChunk(object):
@@ -40,14 +40,14 @@ class SearchChunk(object):
         print_mem (bool): print estimated memory
     """

-    def __init__(self, gm, max_memory=None, print_mem=False) -> None:
-        self.gm = gm
+    def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
         self.print_mem = print_mem
+        self.print_progress = print_progress
         self.trace_indice = TraceIndice(list(gm.graph.nodes))
-        self.trace_indice.trace_indice()
+        self.estimate_memory = EstimateMemory()
+        self._init_trace()
         self.trace_flow = TraceFlow(self.trace_indice)
         self.reorder_graph = ReorderGraph(self.trace_indice)
-        self.estimate_memory = EstimateMemory()
         self.select_chunk = SelectChunk(
             self.trace_indice,
             self.estimate_memory,
@@ -55,7 +55,33 @@ class SearchChunk(object):
             max_memory=max_memory,
         )

-    def _find_peak_node(self, mem_peak):
+    def _init_trace(self) -> None:
+        """
+        find the max trace range for every node
+        reduce the computation complexity of trace_indice
+        """
+        # find all max ranges
+        active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list)
+        cur_node_idx = len(self._get_free_var_idx())
+        max_chunk_region_list = []
+        while True:
+            max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
+            cur_node_idx = max_chunk_region[1]
+            if cur_node_idx == len(active_nodes) - 1:
+                break
+            max_chunk_region_list.append(max_chunk_region)
+
+        # nothing to limit for the first range
+        max_chunk_region_list = max_chunk_region_list[1:]
+        max_chunk_region_list[0] = (0, max_chunk_region_list[0][1])
+
+        # set trace range and do the trace
+        if self.print_progress:
+            get_logger().info("AutoChunk start tracing indice")
+        self.trace_indice.set_trace_range(max_chunk_region_list, active_nodes)
+        self.trace_indice.trace_indice()
+
+    def _find_peak_node(self, mem_peak: List) -> int:
         max_value = max(mem_peak)
         max_idx = mem_peak.index(max_value)
         return max_idx
@@ -73,7 +99,7 @@ class SearchChunk(object):
                 free_var_idx.append(idx)
         return free_var_idx

-    def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
+    def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_regions: List = None) -> Tuple:
         """
         Search max chunk region according to peak memory node

@@ -81,7 +107,7 @@ class SearchChunk(object):

         Args:
             active_node (List): active node status for every node
-            peak_node (Node): peak memory node
+            peak_node_idx (int): peak memory node idx
             chunk_regions (List): chunk region infos

         Returns:
@@ -97,7 +123,7 @@ class SearchChunk(object):
         # from peak_node to free_var
         inside_flag = False
         chunk_region_start = free_var_num
-        for i in range(peak_node, -1, -1):
+        for i in range(peak_node_idx, -1, -1):
             if active_node_num[i] <= threshold:
                 inside_flag = True
             if inside_flag and active_node_num[i] > threshold:
@@ -107,21 +133,23 @@ class SearchChunk(object):
         # from peak_node to len-2
         inside_flag = False
         chunk_region_end = len(active_node) - 1
-        for i in range(peak_node, len(active_node)):
+        for i in range(peak_node_idx, len(active_node)):
             if active_node_num[i] <= threshold:
                 inside_flag = True
             if inside_flag and active_node_num[i] > threshold:
                 chunk_region_end = i
                 break

-        for i in chunk_regions:
-            region = i["region"]
-            if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
-                return None
-            elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
-                chunk_region_start = region[1] + 1
-            elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
-                chunk_region_end = region[0] - 1
+        # avoid chunk regions overlap
+        if chunk_regions is not None:
+            for i in chunk_regions:
+                region = i["region"]
+                if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
+                    return None
+                elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
+                    chunk_region_start = region[1] + 1
+                elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
+                    chunk_region_end = region[0] - 1
         return chunk_region_start, chunk_region_end

     def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
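The hunks above rename peak_node to peak_node_idx and make chunk_regions optional, so _init_trace can call _search_max_chunk_region before any chunk has been selected. The search itself scans outward from the peak-memory node in both directions and stops once the number of active nodes climbs back above a threshold. A standalone toy of that two-way scan, simplified from the method above (boundary handling here is illustrative, not a copy of the ColossalAI logic):

    from typing import List, Tuple

    def search_max_region(active_counts: List[int], peak_idx: int, threshold: int) -> Tuple[int, int]:
        """Expand left and right from the peak until the active-node count rises above the threshold."""
        start, end = 0, len(active_counts) - 1
        inside = False
        for i in range(peak_idx, -1, -1):    # scan towards the graph inputs
            if active_counts[i] <= threshold:
                inside = True
            if inside and active_counts[i] > threshold:
                start = i + 1
                break
        inside = False
        for i in range(peak_idx, len(active_counts)):    # scan towards the graph outputs
            if active_counts[i] <= threshold:
                inside = True
            if inside and active_counts[i] > threshold:
                end = i - 1
                break
        return start, end

    # peak at index 3; the quiet stretch around it spans indices 1..5
    print(search_max_region([9, 2, 3, 8, 3, 2, 9], peak_idx=3, threshold=4))    # (1, 5)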
@@ -154,6 +182,9 @@ class SearchChunk(object):
             # dim size cannot be 1
             if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
                 continue
+            # must have users
+            if len(end_node.users) == 0:
+                continue
             # check index source align
             if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
                 continue
@@ -253,6 +284,9 @@ class SearchChunk(object):
         Returns:
             chunk_infos (Dict)
         """
+        if self.print_progress:
+            get_logger().info("AutoChunk start searching chunk regions")
+
         chunk_infos = []
         (
             init_mem_peak,
@@ -272,6 +306,11 @@ class SearchChunk(object):
                 _,
                 active_node,
             ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
+
+            if self.print_progress:
+                get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
+                                  (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
+
             if self._stop_search(init_mem_peak, mem_peak):
                 break
             if self.print_mem:
@@ -281,7 +281,10 @@ class TraceFlow(object):
             if chunk_dim is not None:
                 user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
                 if input_node_idx in user_source:
-                    input_dict[user_idx] = user_source[input_node_idx]
+                    if get_node_shape(input_node)[user_source[input_node_idx][0]] == 1:
+                        input_dict[user_idx] = [None]
+                    else:
+                        input_dict[user_idx] = user_source[input_node_idx]
                 else:
                     return None, None
         if len(input_dict) == 0:
@@ -33,6 +33,8 @@ class TraceIndice(object):
         self.indice_trace_list = self._init_indice_trace_list()
         self.indice_view_list = {}
         self.indice_count = -1
+        self.trace_range = []
+        self.active_node_list = []

     def _init_indice_trace_list(self):
         indice_trace_list = []
@@ -48,6 +50,10 @@ class TraceIndice(object):
             indice_trace_list.append(cur_trace)
         return indice_trace_list

+    def set_trace_range(self, trace_range: List, active_node_list: List) -> None:
+        self.trace_range = trace_range
+        self.active_node_list = active_node_list
+
     def _add_indice(self):
         """
         Update the count and return it. To record the idx number.
@@ -493,6 +499,9 @@ class TraceIndice(object):
         new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
         for _ in range(new_dim_num):
             self._del_dim(node_idx, 0)
+        delete_dim_num = sum([1 if str(i) == "0" else 0 for i in node_args])
+        for _ in range(delete_dim_num):
+            self._add_dim(node_idx, 0)
         self._assign_indice_as_input(node, node_idx)

         for _, node_arg in enumerate(node_args):
@@ -513,6 +522,9 @@ class TraceIndice(object):
             elif "None" == node_arg_str:
                 self._add_dim(node_idx, new_idx_count)
                 new_idx_count += 1
+            elif "0" == node_arg_str:
+                self._del_dim(node_idx, new_idx_count)
+                origin_idx_count += 1
             else:
                 raise NotImplementedError()

@@ -596,6 +608,37 @@ class TraceIndice(object):
         }
         self.indice_view_list[node] = view_dict

+    def _clear_trace(self, node_idx: int) -> None:
+        """
+        clear too far trace to speed up computation
+        """
+        trace_range = None
+        for i in range(len(self.trace_range)):
+            if self.trace_range[i][1] == node_idx:
+                trace_range = (self.trace_range[i][0], self.trace_range[i][1])
+                break
+            if self.trace_range[i][1] > node_idx:
+                break
+        if trace_range is None:
+            return
+
+        active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
+        active_nodes = set(flat_list(active_nodes))
+        active_nodes = [find_idx_by_name(i, self.node_list) for i in active_nodes]
+        for i in range(trace_range[0], trace_range[1] + 1):
+            trace = self.indice_trace_list[i]
+            # clear compute
+            for dim_compute in trace["compute"]:
+                for i in range(len(dim_compute) - 1, -1, -1):
+                    if dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes:
+                        dim_compute.pop(i)
+                        continue
+            # clear source
+            for dim_source in trace["source"]:
+                for k in list(dim_source.keys()):
+                    if k < trace_range[0] and k not in active_nodes:
+                        dim_source.pop(k)
+
     def trace_indice(self):
         for idx, node in enumerate(self.node_list):
             if node.op == "placeholder":
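_clear_trace is what keeps tracing affordable on long graphs: once a node's max chunk region is closed, compute and source records that point at nodes before the region start are dropped, unless those nodes are still in the region's active set. A small toy of that pruning rule (names are illustrative, not the ColossalAI API):

    from typing import Dict, List, Set

    def prune_stale_entries(source: Dict[int, List[int]], range_start: int, active: Set[int]) -> Dict[int, List[int]]:
        """Drop records that refer to nodes before the region start, unless those nodes are still active."""
        return {k: v for k, v in source.items() if k >= range_start or k in active}

    # node 3 lies before the region start (5) but is still active, so it survives; node 1 is dropped
    print(prune_stale_entries({1: [0], 3: [0, 1], 7: [2]}, range_start=5, active={3}))
    # {3: [0, 1], 7: [2]}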
@@ -655,3 +698,6 @@ class TraceIndice(object):
                 continue
             else:
                 raise NotImplementedError(node.op, "op not implemented yet!")
+
+            # limit trace range
+            self._clear_trace(idx)
@@ -2,6 +2,14 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple

 from torch.fx.node import Node

+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+
+def get_logger():
+    return logger
+

 def flat_list(inputs: Any) -> List:
     """
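With get_logger in place, every autochunk component reports through the same colossalai distributed logger instead of print. A one-line usage sketch (the module path is inferred from the imports elsewhere in this diff):

    from colossalai.autochunk.utils import get_logger

    # all autochunk passes share one distributed logger instance
    get_logger().info("AutoChunk start searching chunk regions")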
@@ -0,0 +1,163 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.fx
+import torch.multiprocessing as mp
+
+try:
+    from fastfold.model.nn.evoformer import EvoformerStack
+    HAS_REPO = True
+except:
+    HAS_REPO = False
+
+import colossalai
+from colossalai.core import global_context as gpc
+from colossalai.fx._compatibility import is_compatible_with_meta
+from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.utils import free_port
+
+if CODEGEN_AVAILABLE and is_compatible_with_meta():
+    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+    from colossalai.fx.profiler import MetaTensor
+    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
+
+
+def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
+    # for memory test
+    # model = model.cuda()
+    # torch.cuda.reset_peak_memory_stats()
+    # now_mem = torch.cuda.memory_allocated() / 1024**2
+    # with torch.no_grad():
+    #     node1 = node.clone()
+    #     pair1 = pair.clone()
+    #     node_mask1 = node_mask.clone()
+    #     pair_mask1 = pair_mask.clone()
+    #     gm(node1, pair1, node_mask1, pair_mask1, None)
+    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+    # print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
+
+    # test forward
+    model = model.cuda()
+    with torch.no_grad():
+        non_fx_out = model(node, pair, node_mask, pair_mask, None)
+        fx_out = gm(node, pair, node_mask, pair_mask, None)
+
+    assert torch.allclose(non_fx_out[0], fx_out[0],
+                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                              torch.abs(non_fx_out[0] - fx_out[0]))
+    assert torch.allclose(non_fx_out[1], fx_out[1],
+                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                              torch.abs(non_fx_out[1] - fx_out[1]))
+
+
+def _build_openfold():
+    model = EvoformerStack(
+        c_m=256,
+        c_z=128,
+        c_hidden_msa_att=32,
+        c_hidden_opm=32,
+        c_hidden_mul=128,
+        c_hidden_pair_att=32,
+        c_s=384,
+        no_heads_msa=8,
+        no_heads_pair=4,
+        no_blocks=2,    # 48
+        transition_n=4,
+        msa_dropout=0.15,
+        pair_dropout=0.25,
+        blocks_per_ckpt=None,
+        inf=1000000000.0,
+        eps=1e-08,
+        clear_cache_between_blocks=False,
+        is_multimer=False,
+    ).eval().cuda()
+    return model
+
+
+def _test_evoformer_stack_codegen(rank, msa_len, pair_len, max_memory):
+    # launch colossalai
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=1,
+        host="localhost",
+        port=free_port(),
+        backend="nccl",
+    )
+
+    # build model and input
+    model = _build_openfold()
+    node = torch.randn(1, msa_len, pair_len, 256).cuda()
+    node_mask = torch.randn(1, msa_len, pair_len).cuda()
+    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
+    pair_mask = torch.randn(1, pair_len, pair_len).cuda()
+
+    # trace the meta graph and setup codegen
+    meta_graph = symbolic_trace(
+        model,
+        meta_args={
+            "m": node.to(torch.device("meta")),
+            "z": pair.to(torch.device("meta")),
+            "msa_mask": node_mask.to(torch.device("meta")),
+            "pair_mask": pair_mask.to(torch.device("meta")),
+        },
+        concrete_args={
+            "chunk_size": None,
+            "_mask_trans": True,
+        },
+    )
+    interp = MetaInfoProp(meta_graph)
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"),
+                     MetaTensor(node_mask, fake_device="cuda:0"), MetaTensor(pair_mask, fake_device="cuda:0"), None)
+    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False, print_progress=False)
+
+    # trace and recompile
+    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
+    graph = ColoTracer().trace(
+        model,
+        meta_args={
+            "m": node.to(torch.device("meta")),
+            "z": pair.to(torch.device("meta")),
+            "msa_mask": node_mask.to(torch.device("meta")),
+            "pair_mask": pair_mask.to(torch.device("meta")),
+        },
+        concrete_args={
+            "chunk_size": None,
+            "_mask_trans": True,
+        },
+    )
+    graph.set_codegen(codegen)
+    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
+    gm.recompile()
+
+    # assert we have inserted chunk
+    code = graph.python_code("self").src
+    # print(code)
+    assert "chunk_result = None; chunk_size = None;" in code
+
+    _test_fwd(model, gm, node, pair, node_mask, pair_mask)
+    gpc.destroy()
+
+
+@pytest.mark.skipif(
+    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
+    reason="torch version is lower than 1.12.0",
+)
+@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
+@pytest.mark.parametrize("msa_len", [32])
+@pytest.mark.parametrize("pair_len", [64])
+def test_evoformer_stack_codegen(msa_len, pair_len, max_memory):
+    run_func = partial(
+        _test_evoformer_stack_codegen,
+        msa_len=msa_len,
+        pair_len=pair_len,
+        max_memory=max_memory,
+    )
+    mp.spawn(run_func, nprocs=1)
+
+
+if __name__ == "__main__":
+    _test_evoformer_stack_codegen(0, 32, 64, None)