mirror of https://github.com/hpcaitech/ColossalAI
format code
parent d361d533e8
commit ded1005667
chunk_codegen.py | 184
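The hunks below are a formatting-only pass: long calls, signatures, and conditions are rewrapped so each argument or clause sits on its own line with a trailing comma, and single-quoted strings become double-quoted; the logic is untouched. As a minimal, self-contained sketch of that convention (the function and values here are invented for illustration and do not appear in chunk_codegen.py):

# Hypothetical example only: shows the wrapping and quoting style applied in the hunks below.
def describe_region(start_idx, end_idx, chunk_dim):
    # old style would be: info = {'region': (start_idx, end_idx), 'chunk_dim': chunk_dim}
    info = {"region": (start_idx, end_idx), "chunk_dim": chunk_dim}
    # a call that no longer fits on one line is exploded, one argument per line,
    # with a trailing comma before the closing parenthesis
    return "chunk region {}-{} on dim {}: {}".format(
        start_idx,
        end_idx,
        chunk_dim,
        info,
    )

print(describe_region(199, 229, 2))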
@@ -144,7 +144,9 @@ class IndexTracer(object):
                 node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim]
             else:
                 if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]:
-                    node_to_trace["source"][node_to_dim][node_from_idx].append(node_from_dim)
+                    node_to_trace["source"][node_to_dim][node_from_idx].append(
+                        node_from_dim
+                    )
         # update inputs source
         node_to_trace["source"][node_to_dim].update(
             node_from_trace["source"][node_from_dim]
@@ -745,7 +747,6 @@ class IndexTracer(object):
         return True
 
 
-
 class FlowTracer(object):
     def __init__(self, gm) -> None:
         self.gm = gm
@@ -856,7 +857,9 @@ class FlowTracer(object):
         )
         return self.flow_trace
 
-    def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer):
+    def _detect_flow(
+        self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer
+    ):
         inputs, outputs = _find_chunk_compute_input_and_output_nodes(
             self.node_list[start_idx : end_idx + 1]
         )
@@ -945,8 +948,10 @@ class FlowTracer(object):
         for i in remove_inputs:
             if i in chunk_info["inputs"]:
                 chunk_info["inputs"].remove(i)
 
-        duplicate_result, duplicate_dim = index_tracer.check_index_duplicate(chunk_info, return_dim=True)
+        duplicate_result, duplicate_dim = index_tracer.check_index_duplicate(
+            chunk_info, return_dim=True
+        )
 
         # we need to log input nodes to avoid deleteing them in the loop
         non_chunk_inputs = _find_chunk_all_input_nodes(
@@ -958,15 +963,25 @@ class FlowTracer(object):
 
         return flow_block, chunk_info
 
-    def _assgin_single_node_flow(self, arg_node, start_idx, end_idx,
-                                 inputs, index_tracer, cur_node_dim,
-                                 cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info,
-                                 next_node_list):
+    def _assgin_single_node_flow(
+        self,
+        arg_node,
+        start_idx,
+        end_idx,
+        inputs,
+        index_tracer,
+        cur_node_dim,
+        cur_node_compute,
+        cur_node_source,
+        cur_node_fix_dim,
+        all_node_info,
+        next_node_list,
+    ):
         arg_idx = _find_idx_by_name(arg_node.name, index_tracer.nodes_list)
         # arg in chunk range or be inputs
         if not (start_idx <= arg_idx < end_idx):
             return True
 
         # find arg dim
         if cur_node_dim is not None:
             # dim is computed
@@ -978,7 +993,7 @@ class FlowTracer(object):
                 arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
             else:
                 arg_dim = None
-
+
         # get fix dim
         arg_fix_dim = []
         if cur_node_dim is not None:
@@ -986,44 +1001,52 @@ class FlowTracer(object):
                 fix_dim_source = cur_node_source[i]
                 if arg_idx in fix_dim_source:
                     arg_fix_dim.append(fix_dim_source[arg_idx][0])
 
         # if already in node_info, arg dim must be same
         if arg_node in all_node_info:
             if all_node_info[arg_node] != arg_dim:
                 return False
-            all_node_info[arg_node]['fix_dim'] = list(set(all_node_info[arg_node]['fix_dim'] + arg_fix_dim))
+            all_node_info[arg_node]["fix_dim"] = list(
+                set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
+            )
         # else add it to list
         else:
-            all_node_info[arg_node] = {'chunk_dim': arg_dim, 'fix_dim': arg_fix_dim}
+            all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
 
         next_node_list.append(arg_node)
         return True
 
-    def flow_search(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer):
+    def flow_search(
+        self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer
+    ):
         inputs, outputs = _find_chunk_compute_input_and_output_nodes(
             self.node_list[start_idx : end_idx + 1]
         )
         # only single ouput
         if len(outputs) > 1:
             return None
 
         cur_node_list = [index_tracer.nodes_list[end_idx]]  # start from the last node
-        all_node_info = {cur_node_list[0]: {'chunk_dim': end_dim, 'fix_dim': []}}
+        all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
 
         while len(cur_node_list) > 0:
             next_node_list = []
 
             for cur_node in cur_node_list:
                 # get cur node info
-                cur_node_chunk_dim = all_node_info[cur_node]['chunk_dim']
-                cur_node_fix_dim = all_node_info[cur_node]['fix_dim']
+                cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
+                cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
                 cur_node_idx = _find_idx_by_name(cur_node.name, index_tracer.nodes_list)
                 if cur_node_chunk_dim:
-                    cur_node_compute = index_tracer._find_compute_trace_from_node(cur_node)
-                    cur_node_source = index_tracer._find_source_trace_from_node(cur_node)
+                    cur_node_compute = index_tracer._find_compute_trace_from_node(
+                        cur_node
+                    )
+                    cur_node_source = index_tracer._find_source_trace_from_node(
+                        cur_node
+                    )
                 else:
                     cur_node_compute = cur_node_source = None
 
                 # get all valid args
                 arg_list = []
                 for arg in cur_node.args:
@@ -1032,20 +1055,33 @@ class FlowTracer(object):
                     if _is_non_compute_node(arg):
                         continue
                     arg_list.append(arg)
-                    flow_flag = self._assgin_single_node_flow(arg, start_idx, end_idx,
-                                                              inputs, index_tracer, cur_node_chunk_dim,
-                                                              cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info,
-                                                              next_node_list)
+                    flow_flag = self._assgin_single_node_flow(
+                        arg,
+                        start_idx,
+                        end_idx,
+                        inputs,
+                        index_tracer,
+                        cur_node_chunk_dim,
+                        cur_node_compute,
+                        cur_node_source,
+                        cur_node_fix_dim,
+                        all_node_info,
+                        next_node_list,
+                    )
                     if flow_flag == False:
                         return None
 
                 if len(arg_list) == 2:
                     if any(i in cur_node.name for i in ["add", "mul"]):
                         for arg in arg_list:
-                            if not (start_idx <= _find_idx_by_name(arg.name, index_tracer.nodes_list) < end_idx):
+                            if not (
+                                start_idx
+                                <= _find_idx_by_name(arg.name, index_tracer.nodes_list)
+                                < end_idx
+                            ):
                                 continue
-                            arg_chunk_dim = all_node_info[arg]['chunk_dim']
-                            arg_fix_dim = all_node_info[arg]['fix_dim']
+                            arg_chunk_dim = all_node_info[arg]["chunk_dim"]
+                            arg_fix_dim = all_node_info[arg]["fix_dim"]
                             arg_shape = _get_node_shape(arg)
                             # add all dim as fix dim except chunk dim
                             for i, shape in enumerate(arg_shape):
@@ -1061,7 +1097,7 @@ class FlowTracer(object):
                     else:
                         raise NotImplementedError()
             cur_node_list = next_node_list
-
+
         inputs_dim = []
         remove_inputs = []
         for input_node in inputs:
@@ -1071,7 +1107,7 @@ class FlowTracer(object):
                     continue
                 user_idx = _find_idx_by_name(user.name, self.node_list)
                 if start_idx <= user_idx <= end_idx:
-                    chunk_dim = all_node_info[user]['chunk_dim']
+                    chunk_dim = all_node_info[user]["chunk_dim"]
                     if chunk_dim is not None:
                         input_dict[user_idx] = chunk_dim
             if len(input_dict) == 0:
@@ -1081,7 +1117,7 @@ class FlowTracer(object):
         for i in remove_inputs:
             if i in inputs:
                 inputs.remove(i)
-
+
         chunk_info = {
             "region": (start_idx, end_idx),
             "inputs": inputs,
@@ -1091,7 +1127,7 @@ class FlowTracer(object):
             "outputs_dim": end_dim,
             "args": {},
         }
-
+
         # we need to log input nodes to avoid deleteing them in the loop
         non_chunk_inputs = _find_chunk_all_input_nodes(
             self.node_list[start_idx : end_idx + 1]
@@ -1129,7 +1165,7 @@ class MemoryEstimator(object):
 
     def _add_active_node(self, n, active_list):
         new_active = self._get_output_node(n)[1]
-        if n.op == 'placeholder':
+        if n.op == "placeholder":
             new_active.append(n.name)
         for i in new_active:
             if i not in active_list:
@@ -1168,12 +1204,16 @@ class MemoryEstimator(object):
         for i in delete_node:
             if i in active_list:
                 active_list.remove(i)
 
-    def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
+    def _get_chunk_inputs_size(
+        self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
+    ):
         nodes_to_delete = []
         for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
             chunk_input_users = chunk_input.users.keys()
-            chunk_input_users_idx = [_find_idx_by_name(i.name, node_list) for i in chunk_input_users]
+            chunk_input_users_idx = [
+                _find_idx_by_name(i.name, node_list) for i in chunk_input_users
+            ]
             if all(i <= chunk_end_idx for i in chunk_input_users_idx):
                 if chunk_input not in nodes_to_delete:
                     nodes_to_delete.append(chunk_input)
@@ -1226,7 +1266,9 @@ class MemoryEstimator(object):
         for (input_node, input_node_dim) in zip(chunk_inputs, chunk_inputs_dim):
             for k, v in input_node_dim.items():
                 # TODO: inherit dim should be list too, int now
-                inherit_dim = self.index_tracer._find_inherit_dim(input_node, v, self.index_tracer.nodes_list[k])
+                inherit_dim = self.index_tracer._find_inherit_dim(
+                    input_node, v, self.index_tracer.nodes_list[k]
+                )
                 if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list):
                     chunk_ratio = float(chunk_size) / node_shape[inherit_dim]
                     return chunk_ratio
@@ -1234,7 +1276,7 @@ class MemoryEstimator(object):
                     if k in source and inherit_dim in source[k]:
                         chunk_ratio = float(chunk_size) / node_shape[dim]
                         return chunk_ratio
-        return 1.
+        return 1.0
 
     def _get_chunk_delete_node_size(
         self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
@@ -1295,7 +1337,7 @@ class MemoryEstimator(object):
         chunk_ratio = 1  # use it to estimate chunk mem
         chunk_size = 1
         chunk_inputs_names = []
-
+
         if use_chunk:
             chunk_regions = [i["region"] for i in chunk_infos]
             chunk_starts = [i[0] for i in chunk_regions]
@@ -1313,12 +1355,17 @@ class MemoryEstimator(object):
             if use_chunk and idx in chunk_starts:
                 chunk_within = True
                 chunk_region_idx = chunk_starts.index(idx)
-                act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)
+                act_memory += self._get_output_node_size(
+                    chunk_outputs[chunk_region_idx]
+                ) / (1024**2)
 
             # determine chunk ratio for current node
             if chunk_within:
                 chunk_ratio = self._get_chunk_ratio(
-                    node, chunk_inputs[chunk_region_idx], chunk_inputs_dim[chunk_region_idx], chunk_size
+                    node,
+                    chunk_inputs[chunk_region_idx],
+                    chunk_inputs_dim[chunk_region_idx],
+                    chunk_size,
                 )
 
             # if node is placeholder, just add the size of the node
@@ -1353,18 +1400,18 @@ class MemoryEstimator(object):
                     / (1024**2)
                 )
             # delete unused vars not in chunk_input_list
-            # we can't delete input nodes until chunk ends
+            # we can't delete input nodes until chunk ends
             if chunk_within:
                 act_memory -= self._get_chunk_delete_node_size(
                     node,
                     user_to_last_uses_no_free_var,
                     chunk_ratio,
-                    chunk_inputs_names
+                    chunk_inputs_names,
                 ) / (1024**2)
             else:
-                act_memory -= (self._get_delete_node_size(
+                act_memory -= self._get_delete_node_size(
                     node, user_to_last_uses_no_free_var, chunk_inputs_names
-                ) / (1024**2))
+                ) / (1024**2)
 
             # log active node, only effective without chunk
             self._add_active_node(node, active_node_list)
@@ -1376,11 +1423,11 @@ class MemoryEstimator(object):
                     self._get_output_node_size(node) * chunk_ratio / (1024**2)
                 )
                 act_memory -= self._get_chunk_inputs_size(
-                    chunk_inputs[chunk_region_idx],
-                    chunk_inputs_non_chunk[chunk_region_idx],
+                    chunk_inputs[chunk_region_idx],
+                    chunk_inputs_non_chunk[chunk_region_idx],
                     node_list,
-                    chunk_regions[chunk_region_idx][1]
-                ) / (1024**2)
+                    chunk_regions[chunk_region_idx][1],
+                ) / (1024**2)
                 chunk_within = False
                 chunk_ratio = 1
                 chunk_region_idx = None
@@ -1436,7 +1483,7 @@ class ChunkRegionSearch(object):
         active_node_num = [len(i) for i in active_node]
         min_active_node_num = min(active_node_num[free_var_num:])
        threshold = max(free_var_num, min_active_node_num)
-
+
         # from peak_node to free_var
         inside_flag = False
         chunk_region_start = free_var_num
@@ -1494,7 +1541,12 @@ class ChunkRegionSearch(object):
                 continue
             for start_node, start_trace in start_traces.items():
                 for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
-                    if start_idx == 199 and end_idx == 229 and start_dim == 2 and end_dim == 2:
+                    if (
+                        start_idx == 199
+                        and end_idx == 229
+                        and start_dim == 2
+                        and end_dim == 2
+                    ):
                         print(1)
                     self.flow_tracer.flow_search(
                         start_idx, start_dim, end_idx, end_dim, self.index_tracer
@@ -1576,7 +1628,7 @@ class ChunkRegionSearch(object):
                 max_region_range = 0
                 best_region = None
         return best_region
-
+
     def _is_legal_region(self, cur_chunk_info, chunk_infos):
         (chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
         if cur_chunk_info in chunk_infos:
@@ -1585,11 +1637,13 @@ class ChunkRegionSearch(object):
             return False
         for i in chunk_infos:
             region = i["region"]
-            if not ((chunk_region_start > region[1] and chunk_region_end > region[1])
-                    or (chunk_region_start < region[0] and chunk_region_end < region[0])):
+            if not (
+                (chunk_region_start > region[1] and chunk_region_end > region[1])
+                or (chunk_region_start < region[0] and chunk_region_end < region[0])
+            ):
                 return False
         return True
 
     def _step_search(self, mem_peak, active_node, chunk_regions):
         peak_node = self._find_peak_node(mem_peak)
         max_chunk_region = self._search_max_chunk_region(
@@ -1600,7 +1654,9 @@ class ChunkRegionSearch(object):
         possible_chunk_regions = self._search_possible_chunk_regions(
             max_chunk_region, peak_node
         )
-        best_chunk_region = self._search_best_chunk_region(possible_chunk_regions, chunk_regions)
+        best_chunk_region = self._search_best_chunk_region(
+            possible_chunk_regions, chunk_regions
+        )
         return best_chunk_region
 
     def _stop_search(self, init_mem_peak, mem_peak):
@@ -1667,7 +1723,11 @@ def _gen_loop_end(
     chunk_slice = _gen_chunk_slice_dim(
         chunk_outputs_dim, "chunk_idx", chunk_output_shape
     )
-    context = " chunk_result%s = %s; %s = None\n" % (chunk_slice, chunk_outputs_name, chunk_outputs_name)
+    context = " chunk_result%s = %s; %s = None\n" % (
+        chunk_slice,
+        chunk_outputs_name,
+        chunk_outputs_name,
+    )
     context += (
         chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
     )