mirror of https://github.com/hpcaitech/ColossalAI
refactor flow tracer
parent
d734529a39
commit
d361d533e8
281
chunk_codegen.py
281
chunk_codegen.py
|
@ -139,7 +139,13 @@ class IndexTracer(object):
|
|||
node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list)
|
||||
if init:
|
||||
node_to_trace["source"][node_to_dim] = {}
|
||||
node_to_trace["source"][node_to_dim][node_from_idx] = node_from_dim
|
||||
# add dim to cur new source
|
||||
if node_from_idx not in node_to_trace["source"][node_to_dim]:
|
||||
node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim]
|
||||
else:
|
||||
if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]:
|
||||
node_to_trace["source"][node_to_dim][node_from_idx].append(node_from_dim)
|
||||
# update inputs source
|
||||
node_to_trace["source"][node_to_dim].update(
|
||||
node_from_trace["source"][node_from_dim]
|
||||
)
|
||||
|
@ -654,7 +660,7 @@ class IndexTracer(object):
|
|||
end_node_trace_source.items(), key=lambda d: d[0], reverse=True
|
||||
)
|
||||
for node_idx, node_dim in sorted_source:
|
||||
if node_idx == start_node_idx and node_dim == start_dim:
|
||||
if node_idx == start_node_idx and start_dim in node_dim:
|
||||
return True
|
||||
# it means we meet a node outside the loop, and the node is not input node
|
||||
if node_idx < start_idx:
|
||||
|
@ -694,12 +700,12 @@ class IndexTracer(object):
|
|||
for node_dim in range(len(_get_node_shape(node))):
|
||||
if (
|
||||
input_node_idx in node_trace_source[node_dim]
|
||||
and node_trace_source[node_dim][input_node_idx] == input_dim
|
||||
and input_dim in node_trace_source[node_dim][input_node_idx]
|
||||
):
|
||||
return node_dim
|
||||
return None
|
||||
|
||||
def check_index_duplicate(self, chunk_infos):
|
||||
def check_index_duplicate(self, chunk_infos, return_dim=False):
|
||||
input_dim_after_node = {}
|
||||
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
|
||||
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
|
||||
|
@ -713,17 +719,30 @@ class IndexTracer(object):
|
|||
if _is_non_compute_node_except_placeholder(node):
|
||||
continue
|
||||
count = 0
|
||||
duplicate_dims = []
|
||||
node_trace_source = self._find_source_trace_from_node(node)
|
||||
for node_dim in range(len(_get_node_shape(node))):
|
||||
duplicate_dim = []
|
||||
duplicate_flag = False
|
||||
dim_source = node_trace_source[node_dim]
|
||||
for k, v in dim_source.items():
|
||||
if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
|
||||
if k in input_dim_after_node and input_dim_after_node[k] == v:
|
||||
count += 1
|
||||
break
|
||||
if k in input_dim_after_node and input_dim_after_node[k] in v:
|
||||
duplicate_flag = True
|
||||
duplicate_dim.append((k, v))
|
||||
duplicate_dims.append(duplicate_dim)
|
||||
if duplicate_flag:
|
||||
count += 1
|
||||
|
||||
if count > 1:
|
||||
return False
|
||||
return True
|
||||
if return_dim:
|
||||
return False, duplicate_dims
|
||||
else:
|
||||
return False
|
||||
if return_dim:
|
||||
return True, None
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
|
||||
|
@ -857,43 +876,45 @@ class FlowTracer(object):
|
|||
flow_block = True
|
||||
return flow_block, chunk_info
|
||||
|
||||
for idx in range(start_idx, end_idx + 1):
|
||||
node = self.node_list[idx]
|
||||
mix_flow_node = self._get_flow_mix_node(node)
|
||||
if mix_flow_node is None:
|
||||
continue
|
||||
# for idx in range(start_idx, end_idx + 1):
|
||||
# node = self.node_list[idx]
|
||||
# mix_flow_node = self._get_flow_mix_node(node)
|
||||
# if mix_flow_node is None:
|
||||
# continue
|
||||
|
||||
# if there is a flow mix, op must be in [mul, add, matmul]
|
||||
# element-wise op requires dim to be equal in every dim
|
||||
if any(n in node.name for n in ["mul", "add"]):
|
||||
for i in node.args:
|
||||
if type(i) == type(mix_flow_node) and i != mix_flow_node:
|
||||
main_flow_var = i
|
||||
# if mix flow is a broadcast in chunk dim,
|
||||
# TODO: need to move that flow out of the chunk
|
||||
mix_flow_node_dim = index_tracer.get_node_chunk_dim(
|
||||
self.node_list[end_idx], end_dim, node
|
||||
)
|
||||
if mix_flow_node_dim is None:
|
||||
flow_block = True
|
||||
break
|
||||
if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1:
|
||||
flow_block = False
|
||||
for i in self._get_same_flow_node(
|
||||
chunk_info["inputs"], mix_flow_node
|
||||
):
|
||||
chunk_info["inputs"].remove(i)
|
||||
# else, we need to chunk mix var as well
|
||||
else:
|
||||
# TODO chunk another value
|
||||
flow_block = True
|
||||
break
|
||||
else:
|
||||
raise NotImplementedError("%s not implemented" % node.name)
|
||||
|
||||
if flow_block:
|
||||
flow_block = True
|
||||
return flow_block, chunk_info
|
||||
# # if there is a flow mix, op must be in [mul, add, matmul]
|
||||
# # element-wise op requires dim to be equal in every dim
|
||||
# if any(n in node.name for n in ["mul", "add"]):
|
||||
# for i in node.args:
|
||||
# if type(i) == type(mix_flow_node) and i != mix_flow_node:
|
||||
# main_flow_var = i
|
||||
# # if mix flow is a broadcast in chunk dim,
|
||||
# # TODO: need to move that flow out of the chunk
|
||||
# mix_flow_node_dim = index_tracer.get_node_chunk_dim(
|
||||
# self.node_list[end_idx], end_dim, node
|
||||
# )
|
||||
# # TODO: we need to loop every dim
|
||||
# if isinstance(mix_flow_node_dim, list):
|
||||
# mix_flow_node_dim = mix_flow_node_dim[0]
|
||||
# if mix_flow_node_dim is None:
|
||||
# flow_block = True
|
||||
# break
|
||||
# if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1:
|
||||
# flow_block = False
|
||||
# for i in self._get_same_flow_node(
|
||||
# chunk_info["inputs"], mix_flow_node
|
||||
# ):
|
||||
# chunk_info["inputs"].remove(i)
|
||||
# # else, we need to chunk mix var as well
|
||||
# else:
|
||||
# # TODO chunk another value
|
||||
# flow_block = True
|
||||
# break
|
||||
# else:
|
||||
# raise NotImplementedError("%s not implemented" % node.name)
|
||||
# if flow_block:
|
||||
# flow_block = True
|
||||
# return flow_block, chunk_info
|
||||
|
||||
inputs_dim = []
|
||||
remove_inputs = []
|
||||
|
@ -908,6 +929,9 @@ class FlowTracer(object):
|
|||
dim = index_tracer.get_node_chunk_dim(
|
||||
self.node_list[end_idx], end_dim, input_node
|
||||
)
|
||||
# TODO: we need to loop every dim
|
||||
if isinstance(dim, list):
|
||||
dim = dim[0]
|
||||
elif user_idx == end_idx:
|
||||
dim = end_dim
|
||||
# n has relation with chunk dim
|
||||
|
@ -922,6 +946,8 @@ class FlowTracer(object):
|
|||
if i in chunk_info["inputs"]:
|
||||
chunk_info["inputs"].remove(i)
|
||||
|
||||
duplicate_result, duplicate_dim = index_tracer.check_index_duplicate(chunk_info, return_dim=True)
|
||||
|
||||
# we need to log input nodes to avoid deleteing them in the loop
|
||||
non_chunk_inputs = _find_chunk_all_input_nodes(
|
||||
self.node_list[start_idx : end_idx + 1]
|
||||
|
@ -932,6 +958,150 @@ class FlowTracer(object):
|
|||
|
||||
return flow_block, chunk_info
|
||||
|
||||
def _assgin_single_node_flow(self, arg_node, start_idx, end_idx,
|
||||
inputs, index_tracer, cur_node_dim,
|
||||
cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info,
|
||||
next_node_list):
|
||||
arg_idx = _find_idx_by_name(arg_node.name, index_tracer.nodes_list)
|
||||
# arg in chunk range or be inputs
|
||||
if not (start_idx <= arg_idx < end_idx):
|
||||
return True
|
||||
|
||||
# find arg dim
|
||||
if cur_node_dim is not None:
|
||||
# dim is computed
|
||||
if arg_idx in cur_node_compute[cur_node_dim]:
|
||||
return False
|
||||
if arg_idx not in cur_node_source[cur_node_dim]:
|
||||
arg_dim = None
|
||||
else:
|
||||
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
|
||||
else:
|
||||
arg_dim = None
|
||||
|
||||
# get fix dim
|
||||
arg_fix_dim = []
|
||||
if cur_node_dim is not None:
|
||||
for i in cur_node_fix_dim:
|
||||
fix_dim_source = cur_node_source[i]
|
||||
if arg_idx in fix_dim_source:
|
||||
arg_fix_dim.append(fix_dim_source[arg_idx][0])
|
||||
|
||||
# if already in node_info, arg dim must be same
|
||||
if arg_node in all_node_info:
|
||||
if all_node_info[arg_node] != arg_dim:
|
||||
return False
|
||||
all_node_info[arg_node]['fix_dim'] = list(set(all_node_info[arg_node]['fix_dim'] + arg_fix_dim))
|
||||
# else add it to list
|
||||
else:
|
||||
all_node_info[arg_node] = {'chunk_dim': arg_dim, 'fix_dim': arg_fix_dim}
|
||||
|
||||
next_node_list.append(arg_node)
|
||||
return True
|
||||
|
||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer):
|
||||
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
|
||||
self.node_list[start_idx : end_idx + 1]
|
||||
)
|
||||
# only single ouput
|
||||
if len(outputs) > 1:
|
||||
return None
|
||||
|
||||
cur_node_list = [index_tracer.nodes_list[end_idx]] # start from the last node
|
||||
all_node_info = {cur_node_list[0]: {'chunk_dim': end_dim, 'fix_dim': []}}
|
||||
|
||||
while len(cur_node_list) > 0:
|
||||
next_node_list = []
|
||||
|
||||
for cur_node in cur_node_list:
|
||||
# get cur node info
|
||||
cur_node_chunk_dim = all_node_info[cur_node]['chunk_dim']
|
||||
cur_node_fix_dim = all_node_info[cur_node]['fix_dim']
|
||||
cur_node_idx = _find_idx_by_name(cur_node.name, index_tracer.nodes_list)
|
||||
if cur_node_chunk_dim:
|
||||
cur_node_compute = index_tracer._find_compute_trace_from_node(cur_node)
|
||||
cur_node_source = index_tracer._find_source_trace_from_node(cur_node)
|
||||
else:
|
||||
cur_node_compute = cur_node_source = None
|
||||
|
||||
# get all valid args
|
||||
arg_list = []
|
||||
for arg in cur_node.args:
|
||||
if type(arg) != type(cur_node):
|
||||
continue
|
||||
if _is_non_compute_node(arg):
|
||||
continue
|
||||
arg_list.append(arg)
|
||||
flow_flag = self._assgin_single_node_flow(arg, start_idx, end_idx,
|
||||
inputs, index_tracer, cur_node_chunk_dim,
|
||||
cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info,
|
||||
next_node_list)
|
||||
if flow_flag == False:
|
||||
return None
|
||||
|
||||
if len(arg_list) == 2:
|
||||
if any(i in cur_node.name for i in ["add", "mul"]):
|
||||
for arg in arg_list:
|
||||
if not (start_idx <= _find_idx_by_name(arg.name, index_tracer.nodes_list) < end_idx):
|
||||
continue
|
||||
arg_chunk_dim = all_node_info[arg]['chunk_dim']
|
||||
arg_fix_dim = all_node_info[arg]['fix_dim']
|
||||
arg_shape = _get_node_shape(arg)
|
||||
# add all dim as fix dim except chunk dim
|
||||
for i, shape in enumerate(arg_shape):
|
||||
if shape != 1 and i != cur_node_chunk_dim:
|
||||
if i == arg_chunk_dim:
|
||||
return None
|
||||
if i not in arg_fix_dim:
|
||||
arg_fix_dim.append(i)
|
||||
elif "einsum" in cur_node.name:
|
||||
pass
|
||||
elif "matmul" in cur_node.name:
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
cur_node_list = next_node_list
|
||||
|
||||
inputs_dim = []
|
||||
remove_inputs = []
|
||||
for input_node in inputs:
|
||||
input_dict = {}
|
||||
for user in input_node.users.keys():
|
||||
if _is_non_compute_node(user):
|
||||
continue
|
||||
user_idx = _find_idx_by_name(user.name, self.node_list)
|
||||
if start_idx <= user_idx <= end_idx:
|
||||
chunk_dim = all_node_info[user]['chunk_dim']
|
||||
if chunk_dim is not None:
|
||||
input_dict[user_idx] = chunk_dim
|
||||
if len(input_dict) == 0:
|
||||
remove_inputs.append(input_node)
|
||||
else:
|
||||
inputs_dim.append(input_dict)
|
||||
for i in remove_inputs:
|
||||
if i in inputs:
|
||||
inputs.remove(i)
|
||||
|
||||
chunk_info = {
|
||||
"region": (start_idx, end_idx),
|
||||
"inputs": inputs,
|
||||
"inputs_non_chunk": [],
|
||||
"inputs_dim": inputs_dim,
|
||||
"outputs": outputs,
|
||||
"outputs_dim": end_dim,
|
||||
"args": {},
|
||||
}
|
||||
|
||||
# we need to log input nodes to avoid deleteing them in the loop
|
||||
non_chunk_inputs = _find_chunk_all_input_nodes(
|
||||
self.node_list[start_idx : end_idx + 1]
|
||||
)
|
||||
for i in non_chunk_inputs:
|
||||
if i not in chunk_info["inputs"]:
|
||||
chunk_info["inputs_non_chunk"].append(i)
|
||||
|
||||
return chunk_info
|
||||
|
||||
|
||||
class MemoryEstimator(object):
|
||||
def __init__(self, index_tracer: IndexTracer) -> None:
|
||||
|
@ -1055,12 +1225,13 @@ class MemoryEstimator(object):
|
|||
node_source = self.index_tracer._find_source_trace_from_node(node)
|
||||
for (input_node, input_node_dim) in zip(chunk_inputs, chunk_inputs_dim):
|
||||
for k, v in input_node_dim.items():
|
||||
# TODO: inherit dim should be list too, int now
|
||||
inherit_dim = self.index_tracer._find_inherit_dim(input_node, v, self.index_tracer.nodes_list[k])
|
||||
if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list):
|
||||
chunk_ratio = float(chunk_size) / node_shape[inherit_dim]
|
||||
return chunk_ratio
|
||||
for dim, source in enumerate(node_source):
|
||||
if k in source and source[k] == inherit_dim:
|
||||
if k in source and inherit_dim in source[k]:
|
||||
chunk_ratio = float(chunk_size) / node_shape[dim]
|
||||
return chunk_ratio
|
||||
return 1.
|
||||
|
@ -1323,9 +1494,11 @@ class ChunkRegionSearch(object):
|
|||
continue
|
||||
for start_node, start_trace in start_traces.items():
|
||||
for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
|
||||
# must be same trace idx
|
||||
if start_trace_idx != end_trace_idx:
|
||||
continue
|
||||
if start_idx == 199 and end_idx == 229 and start_dim == 2 and end_dim == 2:
|
||||
print(1)
|
||||
self.flow_tracer.flow_search(
|
||||
start_idx, start_dim, end_idx, end_dim, self.index_tracer
|
||||
)
|
||||
# dim size cannot be 1
|
||||
if (
|
||||
_get_node_shape(end_node)[end_dim] == 1
|
||||
|
@ -1343,10 +1516,16 @@ class ChunkRegionSearch(object):
|
|||
):
|
||||
continue
|
||||
# detect flow meet
|
||||
flow_block, chunk_info = self.flow_tracer._detect_flow(
|
||||
# flow_block, chunk_info = self.flow_tracer._detect_flow(
|
||||
# start_idx, start_dim, end_idx, end_dim, self.index_tracer
|
||||
# )
|
||||
# if flow_block:
|
||||
# continue
|
||||
# flow search
|
||||
chunk_info = self.flow_tracer.flow_search(
|
||||
start_idx, start_dim, end_idx, end_dim, self.index_tracer
|
||||
)
|
||||
if flow_block:
|
||||
if chunk_info is None:
|
||||
continue
|
||||
# check index copmute
|
||||
if not self.index_tracer.check_index_duplicate(chunk_info):
|
||||
|
|
|
@ -6,6 +6,13 @@ from .ops import OutProductMean
|
|||
from .triangle import PairStack
|
||||
|
||||
|
||||
def print_memory(init_mem, text=None):
|
||||
now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem
|
||||
max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem
|
||||
print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem))
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
class EvoformerBlock(nn.Module):
|
||||
|
||||
def __init__(self, d_node, d_pair):
|
||||
|
@ -16,9 +23,9 @@ class EvoformerBlock(nn.Module):
|
|||
self.pair_stack = PairStack(d_pair=d_pair)
|
||||
|
||||
def forward(self, node, pair):
|
||||
node = node + self.msa_stack(node, pair)
|
||||
node = self.msa_stack(node, pair)
|
||||
pair = pair + self.communication(node)
|
||||
pair = pair + self.pair_stack(pair)
|
||||
pair = self.pair_stack(pair)
|
||||
return node, pair
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue