refactor flow tracer

pull/2364/head
oahzxl 2022-12-21 15:01:03 +08:00
parent d734529a39
commit d361d533e8
2 changed files with 239 additions and 53 deletions


@@ -139,7 +139,13 @@ class IndexTracer(object):
node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list)
if init:
node_to_trace["source"][node_to_dim] = {}
node_to_trace["source"][node_to_dim][node_from_idx] = node_from_dim
# add dim to cur new source
if node_from_idx not in node_to_trace["source"][node_to_dim]:
node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim]
else:
if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]:
node_to_trace["source"][node_to_dim][node_from_idx].append(node_from_dim)
# update inputs source
node_to_trace["source"][node_to_dim].update(
node_from_trace["source"][node_from_dim]
)
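# Illustrative sketch only (hypothetical node indices): after the change above,
# a dim's source maps each upstream node index to a *list* of that node's dims,
# so a single source node can contribute more than one dim.
example_trace = {"source": [{}]}           # one dict per output dim
example_trace["source"][0][3] = [0]        # dim 0 traces back to node 3, dim 0
example_trace["source"][0][3].append(2)    # node 3 may now also contribute its dim 2
assert example_trace["source"][0] == {3: [0, 2]}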
@@ -654,7 +660,7 @@ class IndexTracer(object):
end_node_trace_source.items(), key=lambda d: d[0], reverse=True
)
for node_idx, node_dim in sorted_source:
if node_idx == start_node_idx and node_dim == start_dim:
if node_idx == start_node_idx and start_dim in node_dim:
return True
# it means we meet a node outside the loop, and the node is not an input node
if node_idx < start_idx:
@@ -694,12 +700,12 @@ class IndexTracer(object):
for node_dim in range(len(_get_node_shape(node))):
if (
input_node_idx in node_trace_source[node_dim]
and node_trace_source[node_dim][input_node_idx] == input_dim
and input_dim in node_trace_source[node_dim][input_node_idx]
):
return node_dim
return None
def check_index_duplicate(self, chunk_infos):
def check_index_duplicate(self, chunk_infos, return_dim=False):
input_dim_after_node = {}
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
@@ -713,17 +719,30 @@ class IndexTracer(object):
if _is_non_compute_node_except_placeholder(node):
continue
count = 0
duplicate_dims = []
node_trace_source = self._find_source_trace_from_node(node)
for node_dim in range(len(_get_node_shape(node))):
duplicate_dim = []
duplicate_flag = False
dim_source = node_trace_source[node_dim]
for k, v in dim_source.items():
if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
if k in input_dim_after_node and input_dim_after_node[k] == v:
count += 1
break
if k in input_dim_after_node and input_dim_after_node[k] in v:
duplicate_flag = True
duplicate_dim.append((k, v))
duplicate_dims.append(duplicate_dim)
if duplicate_flag:
count += 1
if count > 1:
return False
return True
if return_dim:
return False, duplicate_dims
else:
return False
if return_dim:
return True, None
else:
return True
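# Illustrative usage sketch only (chunk_info is assumed to be a dict produced by
# the flow tracer, as elsewhere in this file):
#     ok = index_tracer.check_index_duplicate(chunk_info)                        # bool, as before
#     ok, dup = index_tracer.check_index_duplicate(chunk_info, return_dim=True)  # (bool, dims)
# On success dup is None; on failure it is the per-dim duplicate_dims list of
# (node_idx, dims) pairs collected above.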
@@ -857,43 +876,45 @@ class FlowTracer(object):
flow_block = True
return flow_block, chunk_info
for idx in range(start_idx, end_idx + 1):
node = self.node_list[idx]
mix_flow_node = self._get_flow_mix_node(node)
if mix_flow_node is None:
continue
# for idx in range(start_idx, end_idx + 1):
# node = self.node_list[idx]
# mix_flow_node = self._get_flow_mix_node(node)
# if mix_flow_node is None:
# continue
# if there is a flow mix, op must be in [mul, add, matmul]
# element-wise op requires dim to be equal in every dim
if any(n in node.name for n in ["mul", "add"]):
for i in node.args:
if type(i) == type(mix_flow_node) and i != mix_flow_node:
main_flow_var = i
# if mix flow is a broadcast in chunk dim,
# TODO: need to move that flow out of the chunk
mix_flow_node_dim = index_tracer.get_node_chunk_dim(
self.node_list[end_idx], end_dim, node
)
if mix_flow_node_dim is None:
flow_block = True
break
if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1:
flow_block = False
for i in self._get_same_flow_node(
chunk_info["inputs"], mix_flow_node
):
chunk_info["inputs"].remove(i)
# else, we need to chunk mix var as well
else:
# TODO chunk another value
flow_block = True
break
else:
raise NotImplementedError("%s not implemented" % node.name)
if flow_block:
flow_block = True
return flow_block, chunk_info
# # if there is a flow mix, op must be in [mul, add, matmul]
# # element-wise op requires dim to be equal in every dim
# if any(n in node.name for n in ["mul", "add"]):
# for i in node.args:
# if type(i) == type(mix_flow_node) and i != mix_flow_node:
# main_flow_var = i
# # if mix flow is a broadcast in chunk dim,
# # TODO: need to move that flow out of the chunk
# mix_flow_node_dim = index_tracer.get_node_chunk_dim(
# self.node_list[end_idx], end_dim, node
# )
# # TODO: we need to loop every dim
# if isinstance(mix_flow_node_dim, list):
# mix_flow_node_dim = mix_flow_node_dim[0]
# if mix_flow_node_dim is None:
# flow_block = True
# break
# if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1:
# flow_block = False
# for i in self._get_same_flow_node(
# chunk_info["inputs"], mix_flow_node
# ):
# chunk_info["inputs"].remove(i)
# # else, we need to chunk mix var as well
# else:
# # TODO chunk another value
# flow_block = True
# break
# else:
# raise NotImplementedError("%s not implemented" % node.name)
# if flow_block:
# flow_block = True
# return flow_block, chunk_info
inputs_dim = []
remove_inputs = []
@@ -908,6 +929,9 @@ class FlowTracer(object):
dim = index_tracer.get_node_chunk_dim(
self.node_list[end_idx], end_dim, input_node
)
# TODO: we need to loop every dim
if isinstance(dim, list):
dim = dim[0]
elif user_idx == end_idx:
dim = end_dim
# n has relation with chunk dim
@@ -922,6 +946,8 @@ class FlowTracer(object):
if i in chunk_info["inputs"]:
chunk_info["inputs"].remove(i)
duplicate_result, duplicate_dim = index_tracer.check_index_duplicate(chunk_info, return_dim=True)
# we need to log input nodes to avoid deleting them in the loop
non_chunk_inputs = _find_chunk_all_input_nodes(
self.node_list[start_idx : end_idx + 1]
@@ -932,6 +958,150 @@ class FlowTracer(object):
return flow_block, chunk_info
def _assgin_single_node_flow(self, arg_node, start_idx, end_idx,
inputs, index_tracer, cur_node_dim,
cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info,
next_node_list):
arg_idx = _find_idx_by_name(arg_node.name, index_tracer.nodes_list)
# arg must be within the chunk range or be an input
if not (start_idx <= arg_idx < end_idx):
return True
# find arg dim
if cur_node_dim is not None:
# dim is computed
if arg_idx in cur_node_compute[cur_node_dim]:
return False
if arg_idx not in cur_node_source[cur_node_dim]:
arg_dim = None
else:
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
else:
arg_dim = None
# get fix dim
arg_fix_dim = []
if cur_node_dim is not None:
for i in cur_node_fix_dim:
fix_dim_source = cur_node_source[i]
if arg_idx in fix_dim_source:
arg_fix_dim.append(fix_dim_source[arg_idx][0])
# if already in node_info, arg dim must be same
if arg_node in all_node_info:
if all_node_info[arg_node] != arg_dim:
return False
all_node_info[arg_node]['fix_dim'] = list(set(all_node_info[arg_node]['fix_dim'] + arg_fix_dim))
# else add it to list
else:
all_node_info[arg_node] = {'chunk_dim': arg_dim, 'fix_dim': arg_fix_dim}
next_node_list.append(arg_node)
return True
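# Illustrative sketch only of the bookkeeping this helper fills in (real keys
# are fx nodes; plain strings and dims below are hypothetical stand-ins):
example_all_node_info = {
    "end_node": {"chunk_dim": 1, "fix_dim": []},      # seeded by flow_search from end_dim
    "arg_node": {"chunk_dim": None, "fix_dim": [2]},  # added per arg while walking backwards
}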
def flow_search(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer):
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
self.node_list[start_idx : end_idx + 1]
)
# only a single output is supported
if len(outputs) > 1:
return None
cur_node_list = [index_tracer.nodes_list[end_idx]] # start from the last node
all_node_info = {cur_node_list[0]: {'chunk_dim': end_dim, 'fix_dim': []}}
while len(cur_node_list) > 0:
next_node_list = []
for cur_node in cur_node_list:
# get cur node info
cur_node_chunk_dim = all_node_info[cur_node]['chunk_dim']
cur_node_fix_dim = all_node_info[cur_node]['fix_dim']
cur_node_idx = _find_idx_by_name(cur_node.name, index_tracer.nodes_list)
if cur_node_chunk_dim is not None:
cur_node_compute = index_tracer._find_compute_trace_from_node(cur_node)
cur_node_source = index_tracer._find_source_trace_from_node(cur_node)
else:
cur_node_compute = cur_node_source = None
# get all valid args
arg_list = []
for arg in cur_node.args:
if type(arg) != type(cur_node):
continue
if _is_non_compute_node(arg):
continue
arg_list.append(arg)
flow_flag = self._assgin_single_node_flow(arg, start_idx, end_idx,
inputs, index_tracer, cur_node_chunk_dim,
cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info,
next_node_list)
if not flow_flag:
return None
if len(arg_list) == 2:
if any(i in cur_node.name for i in ["add", "mul"]):
for arg in arg_list:
if not (start_idx <= _find_idx_by_name(arg.name, index_tracer.nodes_list) < end_idx):
continue
arg_chunk_dim = all_node_info[arg]['chunk_dim']
arg_fix_dim = all_node_info[arg]['fix_dim']
arg_shape = _get_node_shape(arg)
# add all dim as fix dim except chunk dim
for i, shape in enumerate(arg_shape):
if shape != 1 and i != cur_node_chunk_dim:
if i == arg_chunk_dim:
return None
if i not in arg_fix_dim:
arg_fix_dim.append(i)
elif "einsum" in cur_node.name:
pass
elif "matmul" in cur_node.name:
pass
else:
raise NotImplementedError()
cur_node_list = next_node_list
inputs_dim = []
remove_inputs = []
for input_node in inputs:
input_dict = {}
for user in input_node.users.keys():
if _is_non_compute_node(user):
continue
user_idx = _find_idx_by_name(user.name, self.node_list)
if start_idx <= user_idx <= end_idx:
chunk_dim = all_node_info[user]['chunk_dim']
if chunk_dim is not None:
input_dict[user_idx] = chunk_dim
if len(input_dict) == 0:
remove_inputs.append(input_node)
else:
inputs_dim.append(input_dict)
for i in remove_inputs:
if i in inputs:
inputs.remove(i)
chunk_info = {
"region": (start_idx, end_idx),
"inputs": inputs,
"inputs_non_chunk": [],
"inputs_dim": inputs_dim,
"outputs": outputs,
"outputs_dim": end_dim,
"args": {},
}
# we need to log input nodes to avoid deleting them in the loop
non_chunk_inputs = _find_chunk_all_input_nodes(
self.node_list[start_idx : end_idx + 1]
)
for i in non_chunk_inputs:
if i not in chunk_info["inputs"]:
chunk_info["inputs_non_chunk"].append(i)
return chunk_info
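# Illustrative sketch only of the inputs_dim entries built above (node indices
# are hypothetical): one dict per kept input node, mapping each in-region user's
# node index to the dim along which that user chunks the input.
example_inputs_dim = [
    {201: 1, 205: 1},  # first input: users at nodes 201 and 205 chunk it along dim 1
    {210: 0},          # second input: its only user chunks it along dim 0
]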
class MemoryEstimator(object):
def __init__(self, index_tracer: IndexTracer) -> None:
@@ -1055,12 +1225,13 @@ class MemoryEstimator(object):
node_source = self.index_tracer._find_source_trace_from_node(node)
for (input_node, input_node_dim) in zip(chunk_inputs, chunk_inputs_dim):
for k, v in input_node_dim.items():
# TODO: inherit_dim should be a list too; it is an int for now
inherit_dim = self.index_tracer._find_inherit_dim(input_node, v, self.index_tracer.nodes_list[k])
if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list):
chunk_ratio = float(chunk_size) / node_shape[inherit_dim]
return chunk_ratio
for dim, source in enumerate(node_source):
if k in source and source[k] == inherit_dim:
if k in source and inherit_dim in source[k]:
chunk_ratio = float(chunk_size) / node_shape[dim]
return chunk_ratio
return 1.
@@ -1323,9 +1494,11 @@ class ChunkRegionSearch(object):
continue
for start_node, start_trace in start_traces.items():
for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
# must be same trace idx
if start_trace_idx != end_trace_idx:
continue
if start_idx == 199 and end_idx == 229 and start_dim == 2 and end_dim == 2:
print(1)
self.flow_tracer.flow_search(
start_idx, start_dim, end_idx, end_dim, self.index_tracer
)
# dim size cannot be 1
if (
_get_node_shape(end_node)[end_dim] == 1
@@ -1343,10 +1516,16 @@ class ChunkRegionSearch(object):
):
continue
# detect flow meet
flow_block, chunk_info = self.flow_tracer._detect_flow(
# flow_block, chunk_info = self.flow_tracer._detect_flow(
# start_idx, start_dim, end_idx, end_dim, self.index_tracer
# )
# if flow_block:
# continue
# flow search
chunk_info = self.flow_tracer.flow_search(
start_idx, start_dim, end_idx, end_dim, self.index_tracer
)
if flow_block:
if chunk_info is None:
continue
# check index computation
if not self.index_tracer.check_index_duplicate(chunk_info):


@@ -6,6 +6,13 @@ from .ops import OutProductMean
from .triangle import PairStack
def print_memory(init_mem, text=None):
now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem
max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem
print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem))
torch.cuda.reset_peak_memory_stats()
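# Illustrative usage sketch only (assumes a CUDA device and that torch is
# imported at module level):
#     init_mem = torch.cuda.memory_allocated() / 1024 ** 2
#     torch.cuda.reset_peak_memory_stats()
#     node, pair = block(node, pair)                  # block is a hypothetical EvoformerBlock
#     print_memory(init_mem, text="evoformer block")  # prints now/max deltas in MB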
class EvoformerBlock(nn.Module):
def __init__(self, d_node, d_pair):
@@ -16,9 +23,9 @@ class EvoformerBlock(nn.Module):
self.pair_stack = PairStack(d_pair=d_pair)
def forward(self, node, pair):
node = node + self.msa_stack(node, pair)
node = self.msa_stack(node, pair)
pair = pair + self.communication(node)
pair = pair + self.pair_stack(pair)
pair = self.pair_stack(pair)
return node, pair