mirror of https://github.com/hpcaitech/ColossalAI
pass outproduct mean
parent 979e61db92
commit 929445116a
chunk_codegen.py (315 changes)
@@ -16,16 +16,31 @@ def _delete_free_var_from_last_use(user_to_last_uses):
             if n.op == 'placeholder':
                 user_to_last_uses[key].remove(n)


 def _get_node_shape(node):
     if hasattr(node.meta['tensor_meta'], "shape"):
         return node.meta['tensor_meta'].shape
     return None


+def _is_non_compute_node(node):
+    if any(i in node.op for i in ['placeholder', 'get_attr', 'output']) or \
+            any(i in node.name for i in ['getitem', 'getattr']):
+        return True
+    return False
+
+
+def _is_non_compute_node_except_placeholder(node):
+    if any(i in node.op for i in ['get_attr', 'output']) or \
+            any(i in node.name for i in ['getitem', 'getattr']):
+        return True
+    return False
+
+
 class FlowTracer(object):
     def __init__(self, gm) -> None:
         self.gm = gm
-        self.nodes_list = list(gm.graph.nodes)
+        self.node_list = list(gm.graph.nodes)
         self.flow_trace = {}

     def _add_trace(self, name):
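
Note: the two predicates added above classify torch.fx nodes that carry no real computation; `_is_non_compute_node_except_placeholder` differs only in letting placeholders through, so graph inputs still count as data sources. A minimal sketch of how they behave on a traced module (the `Toy` module is illustrative, not part of the commit):

    import torch
    from torch.fx import symbolic_trace

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x + 1

    gm = symbolic_trace(Toy())
    for n in gm.graph.nodes:
        # placeholder / get_attr / output ops and getitem/getattr calls
        # are flagged as non-compute by substring matching on op and name
        print(n.op, n.name, _is_non_compute_node(n))
    # placeholder x True / call_function add False / output output True
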
@@ -49,7 +64,7 @@ class FlowTracer(object):
         raise RuntimeError("node not found")

     def _init_trace(self):
-        for i in self.nodes_list:
+        for i in self.node_list:
             if i.op == 'placeholder':
                 self._add_trace(i.name)
                 self._add_node(i.name, i)
@@ -67,7 +82,7 @@ class FlowTracer(object):
         return False

     def _find_flow_for_node(self, node):
-        if type(self.nodes_list[0]) != type(node):
+        if type(self.node_list[0]) != type(node):
             return None
         if self._is_non_compute_node_except_placeholder(node):
             return None
@@ -117,7 +132,7 @@ class FlowTracer(object):
         # init trace
         self._init_trace()

-        for node in self.nodes_list:
+        for node in self.node_list:
             # skip if non compute node
             if all(type(arg) != type(node) or self._is_non_compute_node_except_placeholder(arg) for arg in node.args) \
                     or self._is_non_compute_node(node):
@@ -135,6 +150,41 @@ class FlowTracer(object):
             else:
                 self._add_outside_depend(node_domin_flow, node, arg, node_input_flow)
         return self.flow_trace

+    def _detect_flow(self, start_idx, start_dim, end_idx, end_dim):
+        inputs, outputs = _find_chunk_input_and_output_nodes(self.node_list[start_idx:end_idx + 1])
+        chunk_info = {'region': (start_idx, end_idx),
+                      'inputs': inputs, 'inputs_dim': start_dim,
+                      'outputs': outputs, 'outputs_dim': end_dim,
+                      'args': {}}
+        flow_flag = False
+
+        for idx in range(start_idx, end_idx + 1):
+            node = self.node_list[idx]
+            mix_flow_var = self.get_flow_mix(node)
+            if mix_flow_var is None:
+                continue
+
+            # if there is a flow mix, op must be in [mul, add, div, matmul]
+            # element-wise op requires dim to be equal in every dim
+            if any(n in node.name for n in ['mul', 'add']):
+                for i in node.args:
+                    if type(i) == type(mix_flow_var) and i != mix_flow_var:
+                        main_flow_var = i
+                # if mix flow is a broadcast in chunk dim,
+                # TODO need to move that flow out of the chunk
+                if mix_flow_var.meta['tensor_meta'].shape[dim_idx] == 1:
+                    flow_flag = True
+                    for i in self.get_same_flow_node(chunk_info['inputs'], mix_flow_var):
+                        chunk_info['inputs'].remove(i)
+                # else, we need to chunk mix var as well
+                else:
+                    # TODO chunk another value
+                    flow_flag = False
+                    break
+            else:
+                raise NotImplementedError("%s not implemented" % node.name)
+        return flow_flag, chunk_info
+
+
 class IndexTracer(object):
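
Note: `FlowTracer._detect_flow` packages a candidate region into a `chunk_info` record and flags regions where two data flows mix in a way chunking cannot preserve. One apparent leftover in the new code: the broadcast test reads `shape[dim_idx]`, but `dim_idx` is not among the new parameters (`start_dim`/`end_dim`), so it looks carried over from the old `_detect_flow` signature this commit removes from `ChunkRegionSearch` below. A hand-written example of the record's shape, assuming a region spanning node indices 5..9 chunked along dim 1 on both ends:

    # hypothetical chunk_info returned for one candidate region
    chunk_info = {
        'region': (5, 9),       # first and last node index of the region
        'inputs': [],           # boundary input nodes of the region
        'inputs_dim': 1,        # dim along which inputs are chunked
        'outputs': [],          # boundary output nodes of the region
        'outputs_dim': 1,       # dim along which outputs are chunked
        'args': {},             # reserved for per-node rewrite arguments
    }
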
@@ -153,7 +203,7 @@ class IndexTracer(object):
             cur_trace = {
                 'idx': [None for _ in range(len(_get_node_shape(n)))],
                 'compute': [[] for _ in range(len(_get_node_shape(n)))],
-                'source': [[] for _ in range(len(_get_node_shape(n)))],
+                'source': [{} for _ in range(len(_get_node_shape(n)))],
             }
         else:
             cur_trace = {'idx': [], 'compute': [], 'source': []}
@@ -178,7 +228,7 @@ class IndexTracer(object):
     def _add_dim(self, idx, dim_idx):
         self.idx_trace_list[idx]['idx'].insert(dim_idx, self._add_index())
         self.idx_trace_list[idx]['compute'].insert(dim_idx, [])
-        self.idx_trace_list[idx]['source'].insert(dim_idx, [])
+        self.idx_trace_list[idx]['source'].insert(dim_idx, {})

     def _transform_index(self, node, node_dim):
         node_idx = self._find_idx_trace_from_node(node)
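
Note: the two hunks above change the per-dimension `source` record from a list of one-entry dicts to a single dict mapping producer node index to producer dimension, which makes lookups and merging via `update` direct. A sketch of one trace record under the new layout (all values made up):

    # hypothetical idx_trace_list entry for a 3-d tensor node
    cur_trace = {
        'idx': [3, 4, 5],                  # global index ids per dimension
        'compute': [[], [7], []],          # node positions that computed each dim
        'source': [{0: 0}, {2: 1}, {}],    # producer node idx -> producer dim
    }
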
@@ -192,10 +242,7 @@ class IndexTracer(object):
         node_to_trace = self._find_trace_from_node(node_to)
         node_to_trace['idx'][node_to_dim] = node_from_trace['idx'][node_from_dim]
         node_to_trace['compute'][node_to_dim] = copy.deepcopy(node_from_trace['compute'][node_from_dim])
-        node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list)
-        node_to_trace['source'][node_to_dim] = []
-        node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim})
-        node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim])
+        self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)

     def _inherit_all_computation(self, node_from, node_to):
         node_from_compute = self._find_compute_trace_from_node(node_from)
@@ -205,14 +252,16 @@ class IndexTracer(object):
             self._add_source(node_from, i, node_to, i)
             node_to_compute[i] = copy.deepcopy(node_from_compute[i])

-    def _add_source(self, node_from, node_from_dim, node_to, node_to_dim):
+    def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
         node_from_dim = self._transform_index(node_from, node_from_dim)
         node_from_trace = self._find_trace_from_node(node_from)
         node_to_dim = self._transform_index(node_to, node_to_dim)
         node_to_trace = self._find_trace_from_node(node_to)
         node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list)
-        node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim})
-        node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim])
+        if init:
+            node_to_trace['source'][node_to_dim] = {}
+        node_to_trace['source'][node_to_dim][node_from_idx] = node_from_dim
+        node_to_trace['source'][node_to_dim].update(node_from_trace['source'][node_from_dim])

     def _mark_computation_from_node(self, node_from, node_to, exclude=None):
         if exclude == None:
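
Note: `_inherit_index` now delegates to `_add_source(..., init=True)`, which resets the target dimension's source map before recording the direct producer and merging its transitive sources; with the default `init=False`, new entries are merged into whatever is already recorded. A standalone illustration of the dict semantics with made-up values:

    node_from_sources = {2: 1, 0: 0}   # transitive producers of the source dim
    target = {5: 3}                    # pre-existing entry, kept when init=False

    init = True
    if init:
        target = {}                    # init=True starts from a clean map
    target[4] = 1                      # direct producer: node 4, dim 1
    target.update(node_from_sources)   # inherit transitive sources
    assert target == {4: 1, 2: 1, 0: 0}
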
@@ -485,11 +534,11 @@ class IndexTracer(object):
                     source_idx = left_str.index(right_indice)
                     self._inherit_index(input_nodes[left_idx], source_idx, node, right_idx)

-        for i in sum_index:
-            for left_idx, left_str in enumerate(left):
-                if i in left_str:
-                    self._mark_computation(node, idx, left_str.index(i))
-                    break
+        # for i in sum_index:
+        #     for left_idx, left_str in enumerate(left):
+        #         if i in left_str:
+        #             self._mark_computation(node, idx, left_str.index(i))
+        #             break

     def _assign_softmax_index(self, node, idx):
         """
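
Note: in the einsum handling above, `sum_index` holds the subscript letters that appear on the left of `->` but not on the right, i.e. the dimensions reduced away; the now commented-out loop used to mark those dimensions as computed. A small sketch of how such indices fall out of an einsum equation (hand-rolled, not the tracer's actual parser):

    equation = "ik,kj->ij"
    left, right = equation.split('->')
    left = left.split(',')                       # ['ik', 'kj']
    sum_index = set(''.join(left)) - set(right)
    print(sum_index)                             # {'k'}: reduced by the matmul
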
@@ -679,18 +728,56 @@ class IndexTracer(object):
             raise NotImplementedError(node.op, "op not implemented yet!")
         # self._merge_equal_idx()

-    def check_index(self, trace_idx, start_idx, end_idx):
-        for i in range(start_idx, end_idx + 1):
-            cur_idx = self.idx_trace_list[i]['idx']
-            cur_compute = self.idx_trace_list[i]['compute']
-            if trace_idx in cur_compute:
-                for j in cur_compute[trace_idx]:
-                    if j < start_idx or j > end_idx:
-                        return False
-            # same_idx = [1 if j == trace_idx else 0 for j in cur_idx]
-            # if sum(same_idx) > 1:
-            #     return False
+    def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
+        """
+        Check two given indices: one should be the source of the other.
+        Args:
+            start_dim(int): start node chunk dim
+            start_node(node): start node
+            start_idx(int): start node position in the node list
+            end_dim(int): end node chunk dim
+            end_node(node): end node
+
+        Returns:
+            bool: True if check pass
+        """
+        start_node_idx = _find_idx_by_name(start_node.name, self.nodes_list)
+        end_node_trace = self._find_trace_from_node(end_node)
+        end_node_trace_source = end_node_trace['source'][end_dim]
+        sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
+        for node_idx, node_dim in sorted_source:
+            if node_idx == start_node_idx and node_dim == start_dim:
+                return True
+            # it means we meet a node outside the loop, and the node is not input node
+            if node_idx < start_idx:
+                return False
+        return False
+
+    def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
+        """
+        Check two given indices: they must not have been computed in the source trace.
+        Args:
+            start_idx(int): start node position in the node list
+            end_dim(int): end node chunk dim
+            end_node(node): end node
+            end_idx(int): end node position in the node list
+
+        Returns:
+            bool: True if check pass
+        """
+        end_node_trace = self._find_trace_from_node(end_node)
+        end_node_compute = end_node_trace['compute'][end_dim]
+        if any(start_idx <= i <= end_idx for i in end_node_compute):
+            return False
         return True
+        # end_node_trace_source = end_node_trace['source'][end_dim]
+        # for node_idx, node_dim in end_node_trace_source.items():
+        #     if node_idx < start_node_idx or node_idx > end_node_idx:
+        #         continue
+        #     compute_list = self.idx_trace_list[node_idx]['compute'][node_dim]
+        #     if any(start_node_idx <= i <= end_node_idx for i in compute_list):
+        #         return False
+        # return True


 class MemoryEstimator(object):

     def __init__(self) -> None:
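
Note: `check_index_source` walks the end dimension's source map from the highest producer position downward: success as soon as the start node/dim pair is found, failure once it crosses below `start_idx`, since that producer lies outside the candidate region. A toy run with hand-made values in place of the real trace structures:

    start_node_idx, start_dim, start_idx = 6, 1, 5
    end_source = {8: 1, 6: 1, 3: 0}    # hypothetical source map of the end dim

    ok = False
    for node_idx, node_dim in sorted(end_source.items(), reverse=True):
        if node_idx == start_node_idx and node_dim == start_dim:
            ok = True                  # the end dim traces back to the start dim
            break
        if node_idx < start_idx:
            break                      # left the region without a match
    print(ok)                          # True: (8, 1) is rejected, then (6, 1) matches
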
@@ -951,88 +1038,81 @@ class ChunkRegionSearch(object):
                 return True
         return False

-    def _detect_flow(self, before_trace, after_trace, start_idx, end_idx, dim_idx):
-        inputs, outputs = _find_input_and_output_nodes(self.node_list[start_idx:end_idx + 1])
-        chunk_info = {'inputs': inputs, 'outputs': outputs}
-        flow_flag = False
-
-        for idx in range(start_idx, end_idx + 1):
-            node = self.node_list[idx]
-            mix_flow_var = self.flow_tracer.get_flow_mix(node)
-            if mix_flow_var is None:
-                continue
-
-            # if there is a flow mix, op must be in [mul, add, div, matmul]
-            # element-wise op requires dim to be equal in every dim
-            if any(n in node.name for n in ['mul', 'add']):
-                for i in node.args:
-                    if type(i) == type(mix_flow_var) and i != mix_flow_var:
-                        main_flow_var = i
-                # if mix flow is a broadcast in chunk dim,
-                # TODO need to move that flow out of the chunk
-                if mix_flow_var.meta['tensor_meta'].shape[dim_idx] == 1:
-                    flow_flag = True
-                    for i in self.flow_tracer.get_same_flow_node(chunk_info['inputs'], mix_flow_var):
-                        chunk_info['inputs'].remove(i)
-                # else, we need to chunk mix var as well
-                else:
-                    # TODO chunk another value
-                    flow_flag = False
-                    break
-            else:
-                raise NotImplementedError("%s not implemented" % node.name)
-        return flow_flag, chunk_info
+    def _check_duplicate_map(self, chunk_infos):
+        dim_map = [(i['inputs_dim'], i['outputs_dim']) for i in chunk_infos]
+        remove_list = []
+        for idx1, (input_dim1, output_dim1) in enumerate(dim_map):
+            for idx2, (input_dim2, output_dim2) in enumerate(dim_map):
+                if idx1 == idx2:
+                    continue
+                # it means an index create 2 copy of itself
+                # eg. a = torch.matmul(x, x.transpose(-1, -2))
+                # TODO currently remove it, deal with this in future
+                if input_dim1 == input_dim2 and output_dim1 != output_dim2:
+                    remove_list.append(chunk_infos[idx1])
+                    remove_list.append(chunk_infos[idx2])
+        for i in remove_list:
+            if i in chunk_infos:
+                chunk_infos.remove(i)
+        return chunk_infos

     def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
-        before_trace = input_trace[start_idx]
-        after_trace = output_trace[end_idx]
-        free_dim = []
+        start_traces = input_trace[start_idx]
+        end_trace = output_trace[end_idx]
+        end_node = self.node_list[end_idx]
         chunk_infos = []
-        for i in range(min(len(before_trace['idx']), len(after_trace['idx']))):
-            if not (before_trace['idx'][i] == after_trace['idx'][i] and
-                    self._is_not_compute(before_trace, (start_idx, end_idx), i) and
-                    self._is_not_compute(after_trace, (start_idx, end_idx), i) and
-                    self.node_list[end_idx].meta['tensor_meta'].shape[i] != 1):
+        for end_dim, end_trace_idx in enumerate(end_trace['idx']):
+            if len(start_traces) > 1:
+                # TODO implement multi input chunk
                 continue
-            if not self.index_tracer.check_index(before_trace['idx'][i], start_idx, end_idx):
-                continue
-            flow_flag, chunk_info = self._detect_flow(before_trace, after_trace, start_idx, end_idx, i)
-            if flow_flag == None:
-                continue
-            chunk_infos.append(chunk_info)
-            free_dim.append(i)
-        return free_dim, chunk_infos
+            for start_node, start_trace in start_traces.items():
+                for start_dim, start_trace_idx in enumerate(start_trace['idx']):
+                    # must be same trace idx
+                    if start_trace_idx != end_trace_idx:
+                        continue
+                    # dim size cannot be 1
+                    if _get_node_shape(end_node)[end_dim] == 1 or \
+                            _get_node_shape(start_node)[start_dim] == 1:
+                        continue
+                    # check index source align
+                    if not self.index_tracer.check_index_source(
+                            start_dim, start_node, start_idx, end_dim, end_node):
+                        continue
+                    # check index compute
+                    if not self.index_tracer.check_index_compute(
+                            start_idx, end_dim, end_node, end_idx):
+                        continue
+                    # detect flow meet
+                    flow_flag, chunk_info = self.flow_tracer._detect_flow(
+                        start_idx, start_dim, end_idx, end_dim)
+                    if flow_flag:
+                        continue
+                    chunk_infos.append(chunk_info)
+        chunk_infos = self._check_duplicate_map(chunk_infos)
+        return chunk_infos

     def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
         possible_chunk_region = []
         output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
-        input_trace = []
-        for i, n in enumerate(self.node_list):
-            if len(n.args) > 0 and n.op != 'output':
-                if isinstance(n.args[0], str):
-                    input_idx = _find_idx_by_name(n.args[1].name, self.node_list)
-                else:
-                    input_idx = _find_idx_by_name(n.args[0].name, self.node_list)
-                input_trace.append(output_trace[input_idx])
-            else:
-                input_trace.append(None)
+        input_trace = []    # trace of a node's input nodes
+        for _, n in enumerate(self.node_list):
+            cur_trace = {}
+            for arg in n.args:
+                if type(arg) == type(n) and not _is_non_compute_node_except_placeholder(arg):
+                    cur_trace[arg] = self.index_tracer._find_trace_from_node(arg)
+            input_trace.append(cur_trace)

-        for start_idx in range(max_chunk_region[0], peak_node):
+        for start_idx in range(max_chunk_region[0], peak_node + 1):
             for end_idx in range(peak_node, max_chunk_region[1] + 1):
                 # skip non compute nodes
-                if any(op in ['placeholder', 'get_attr', 'output'] for op in
-                       [self.node_list[start_idx].op, self.node_list[end_idx].op]):
-                    continue
-                if any(any(i in name for i in ['getitem', 'getattr']) for name in
-                       [self.node_list[start_idx].name, self.node_list[end_idx].name]):
+                if _is_non_compute_node(self.node_list[start_idx]) or \
+                        _is_non_compute_node(self.node_list[end_idx]):
                     continue

                 # select free dim
-                free_dim, chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx)
-                if len(free_dim) > 0:
-                    free_dim = [free_dim[0]]
-                    chunk_info = [chunk_info[0]]
-                    possible_chunk_region.append({'region': (start_idx, end_idx), 'dim': free_dim, 'chunk_info': chunk_info})
+                chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx)
+                if len(chunk_info) > 0:
+                    possible_chunk_region.extend(chunk_info)
         return possible_chunk_region

     def _search_best_chunk_region(self, possible_chunk_regions):
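
Note: `_check_duplicate_map` guards against one tensor feeding a region twice (e.g. `a = torch.matmul(x, x.transpose(-1, -2))`): two candidates that chunk the same input dimension into different output dimensions are both dropped. A toy run on minimal stand-in dicts:

    chunk_infos = [
        {'inputs_dim': 1, 'outputs_dim': 1},
        {'inputs_dim': 1, 'outputs_dim': 2},   # same input dim, new output dim
    ]
    dim_map = [(i['inputs_dim'], i['outputs_dim']) for i in chunk_infos]
    remove_list = []
    for idx1, (in1, out1) in enumerate(dim_map):
        for idx2, (in2, out2) in enumerate(dim_map):
            if idx1 == idx2:
                continue
            if in1 == in2 and out1 != out2:
                remove_list.append(chunk_infos[idx1])
                remove_list.append(chunk_infos[idx2])
    for i in remove_list:
        if i in chunk_infos:
            chunk_infos.remove(i)
    print(chunk_infos)                         # []: both candidates removed
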
@@ -1044,7 +1124,8 @@ class ChunkRegionSearch(object):
             max_region_range = i['region'][1] - i['region'][0]
         return best_regions

-    def _step_search(self, peak_node, active_node):
+    def _step_search(self, mem_peak, active_node):
+        peak_node = self._find_peak_node(mem_peak)
         max_chunk_region = self._search_max_chunk_region(active_node, peak_node)
         possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
         best_chunk_region = self._search_best_chunk_region(possible_chunk_regions)
@@ -1062,19 +1143,16 @@ class ChunkRegionSearch(object):
         mem_peak = init_mem_peak

         while True:
-            peak_node = self._find_peak_node(mem_peak)
-            chunk_region = self._step_search(peak_node, active_node)
-            if chunk_region is None or len(chunk_region['dim']) == 0:
+            chunk_region = self._step_search(mem_peak, active_node)
+            if chunk_region is None:
                 break

             chunk_regions.append(chunk_region)
             mem_peak, _, active_node = self.memory_estimator.estimate_chunk_inference_mem(
                 self.gm, [i['region'][0] for i in chunk_regions],
-                [i['region'][1] for i in chunk_regions], [i['dim'][0] for i in chunk_regions], [1] * len(chunk_regions))
+                [i['region'][1] for i in chunk_regions], [i['inputs_dim'] for i in chunk_regions], [1] * len(chunk_regions))

             if self._stop_search(init_mem_peak, mem_peak):
                 break

         return chunk_regions
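
Note: the two hunks above fold peak-node selection into `_step_search` and key the memory re-estimate on each region's `inputs_dim` instead of the dropped `dim` list. The surrounding driver keeps the same shape: estimate, search one region, re-estimate, stop when the peak no longer improves. A schematic of that loop with stubbed collaborators (all names illustrative, not the real signatures):

    def search_all(estimate, step_search, stop_search):
        mem_peak, active_node = estimate([])        # no chunk regions yet
        init_mem_peak = mem_peak
        chunk_regions = []
        while True:
            region = step_search(mem_peak, active_node)
            if region is None:
                break
            chunk_regions.append(region)
            mem_peak, active_node = estimate(chunk_regions)
            if stop_search(init_mem_peak, mem_peak):
                break
        return chunk_regions
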
@@ -1164,6 +1242,35 @@ def _find_input_and_output_nodes(nodes: List[Node]):
     return input_nodes, output_nodes


+def _find_chunk_input_and_output_nodes(nodes: List[Node]):
+    """
+    Find non-compute input and output nodes.
+    Input nodes are nodes outside the list whose outputs are used in it;
+    output nodes are nodes in the list whose outputs are used outside it.
+    """
+    input_nodes = []
+    output_nodes = []
+
+    # if a node has an input node which is not in the node list
+    # we treat that input node as the input of the checkpoint function
+    for node in nodes:
+        for input_node in node._input_nodes.keys():
+            if input_node not in nodes and input_node not in input_nodes \
+                    and not _is_non_compute_node_except_placeholder(input_node):
+                input_nodes.append(input_node)
+
+    # if a node has a user node which is not in the node list
+    # we treat that user node as the node receiving the current node output
+    # TODO it is unsafe to remove non compute node here
+    for node in nodes:
+        for output_node in node.users.keys():
+            if output_node not in nodes and node not in output_nodes \
+                    and not _is_non_compute_node_except_placeholder(input_node):
+                output_nodes.append(node)
+
+    return input_nodes, output_nodes
+
+
 def _find_idx_by_name(name, nodes_list):
     for idx, node in enumerate(nodes_list):
         if node.name == name:
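
Note: `_find_chunk_input_and_output_nodes` collects the boundary of a node sublist: producers outside the sublist become inputs, and nodes whose users escape the sublist become outputs. The second loop's non-compute filter still tests `input_node`, the stale variable left over from the first loop; `output_node` was presumably intended. A minimal sketch of the boundary scan (the `Toy` module is illustrative, not part of the commit):

    import torch
    from torch.fx import symbolic_trace

    class Toy(torch.nn.Module):
        def forward(self, x):
            a = x + 1
            b = a * 2
            return b - 3

    gm = symbolic_trace(Toy())
    nodes = list(gm.graph.nodes)[1:3]   # the 'add' and 'mul' nodes
    inputs, outputs = _find_chunk_input_and_output_nodes(nodes)
    print([n.name for n in inputs])     # ['x']: produced outside the sublist
    print([n.name for n in outputs])    # ['mul']: its user 'sub' is outside
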