mirror of https://github.com/hpcaitech/ColossalAI
redesign index tracer, add source and change compute
parent
2b4ebcc278
commit
979e61db92
310
chunk_codegen.py
310
chunk_codegen.py
|
@ -16,6 +16,11 @@ def _delete_free_var_from_last_use(user_to_last_uses):
|
|||
if n.op == 'placeholder':
|
||||
user_to_last_uses[key].remove(n)
|
||||
|
||||
def _get_node_shape(node):
    """Return the shape stored in ``node.meta['tensor_meta']``.

    Returns ``None`` when the stored meta object has no ``shape`` attribute.
    """
    tensor_meta = node.meta['tensor_meta']
    return getattr(tensor_meta, "shape", None)
|
||||
|
||||
|
||||
class FlowTracer(object):
|
||||
def __init__(self, gm) -> None:
|
||||
|
@ -136,11 +141,25 @@ class IndexTracer(object):
|
|||
def __init__(self, gm) -> None:
|
||||
self.gm = gm
|
||||
self.nodes_list = list(gm.graph.nodes)
|
||||
self.idx_trace_list = [{'idx': [], 'compute': {}} for _ in range(len(self.nodes_list))]
|
||||
self.idx_trace_list = self._init_idx_trace_list()
|
||||
self.idx_trace_equal = []
|
||||
self.idx_view_list = []
|
||||
self.idx_count = -1
|
||||
|
||||
def _init_idx_trace_list(self):
    """Build the initial trace record for every node in ``self.nodes_list``.

    Each record is a dict with three parallel lists holding one entry per
    dim of the node's shape:

    * ``'idx'``     -- initialised to ``None`` for every dim
    * ``'compute'`` -- initialised to an empty list for every dim
    * ``'source'``  -- initialised to an empty list for every dim

    A node for which ``_get_node_shape`` returns ``None`` gets three empty
    lists.

    Returns:
        list: one record per node, in the same order as ``self.nodes_list``.
    """
    idx_trace_list = []
    for n in self.nodes_list:
        # Look the shape up once per node instead of once per field
        shape = _get_node_shape(n)
        n_dim = 0 if shape is None else len(shape)
        idx_trace_list.append({
            'idx': [None] * n_dim,
            # Each dim needs its own list object, so no ``[[]] * n_dim`` here
            'compute': [[] for _ in range(n_dim)],
            'source': [[] for _ in range(n_dim)],
        })
    return idx_trace_list
|
||||
|
||||
def _add_index(self):
|
||||
"""
|
||||
Update the count and return it. To record the idx number.
|
||||
|
@ -150,35 +169,81 @@ class IndexTracer(object):
|
|||
"""
|
||||
self.idx_count += 1
|
||||
return self.idx_count
|
||||
|
||||
def _inherit_computation(self, node_from, node_to):
|
||||
"""
|
||||
Inherit computed dim from node_from to node_to.
|
||||
If a dim in node_from is marked as computed and exists in node_to,
|
||||
still mark it as computed in node_to.
|
||||
|
||||
Args:
|
||||
node_from (node): node to be inherited
|
||||
node_to (node): new node to inherit
|
||||
"""
|
||||
_, compute_from = self._find_trace_from_node(node_from)
|
||||
idx_to, compute_to = self._find_trace_from_node(node_to)
|
||||
for k, v in compute_from.items():
|
||||
if k in idx_to:
|
||||
if k in compute_to:
|
||||
compute_to[k].extend(v)
|
||||
else:
|
||||
compute_to[k] = copy.deepcopy(v)
|
||||
|
||||
def _mark_idx_equal(self, idx1, idx2):
|
||||
def _del_dim(self, idx, dim_idx):
    """Drop dim ``dim_idx`` from the trace record of the node at ``idx``.

    The entry is removed from the ``'idx'``, ``'compute'`` and ``'source'``
    lists so that the three stay aligned.

    Args:
        idx (int): position of the node in ``self.idx_trace_list``
        dim_idx (int): dim to remove
    """
    record = self.idx_trace_list[idx]
    for field in ('idx', 'compute', 'source'):
        record[field].pop(dim_idx)
|
||||
|
||||
def _add_dim(self, idx, dim_idx):
    """Insert a new dim at ``dim_idx`` in the trace record of the node at ``idx``.

    The new dim receives a fresh index from ``self._add_index()`` and starts
    with empty ``'compute'`` and ``'source'`` lists.

    Args:
        idx (int): position of the node in ``self.idx_trace_list``
        dim_idx (int): position at which the dim is inserted
    """
    record = self.idx_trace_list[idx]
    fresh_index = self._add_index()
    record['idx'].insert(dim_idx, fresh_index)
    for field in ('compute', 'source'):
        record[field].insert(dim_idx, [])
|
||||
|
||||
def _transform_index(self, node, node_dim):
    """Translate a dim of ``node`` (negative values allowed) to its position.

    The number of dims is taken from the length of the node's idx trace.

    Args:
        node (node): node whose dim is being resolved
        node_dim (int): dim, counted from the front or (if negative) the back

    Returns:
        int: the matching non-negative dim position
    """
    n_dim = len(self._find_idx_trace_from_node(node))
    positions = list(range(n_dim))
    return positions[node_dim]
|
||||
|
||||
def _inherit_index(self, node_from, node_from_dim, node_to, node_to_dim):
    """Make one dim of ``node_to`` inherit the trace of one dim of ``node_from``.

    The target dim takes over the source dim's index and a deep copy of its
    ``'compute'`` entry. Its ``'source'`` entry is replaced by a new list
    holding ``{position of node_from in self.nodes_list: source dim}``
    followed by everything in the source dim's own ``'source'`` entry.

    Args:
        node_from (node): node to inherit from
        node_from_dim (int): dim of ``node_from`` (may be negative)
        node_to (node): node that inherits
        node_to_dim (int): dim of ``node_to`` (may be negative)
    """
    src_dim = self._transform_index(node_from, node_from_dim)
    dst_dim = self._transform_index(node_to, node_to_dim)
    src_trace = self._find_trace_from_node(node_from)
    dst_trace = self._find_trace_from_node(node_to)

    dst_trace['idx'][dst_dim] = src_trace['idx'][src_dim]
    dst_trace['compute'][dst_dim] = copy.deepcopy(src_trace['compute'][src_dim])

    src_node_pos = _find_idx_by_name(node_from.name, self.nodes_list)
    # Install the fresh list first, then fill it: direct parent, then ancestry
    inherited = []
    dst_trace['source'][dst_dim] = inherited
    inherited.append({src_node_pos: src_dim})
    inherited.extend(src_trace['source'][src_dim])
|
||||
|
||||
def _inherit_all_computation(self, node_from, node_to):
    """Copy the ``'compute'`` entry of every dim from ``node_from`` to ``node_to``.

    Both nodes must have the same number of dims. For each dim, a source
    link is recorded through ``self._add_source`` and the target's compute
    entry is replaced by a deep copy of the source's entry.

    Args:
        node_from (node): node to inherit from
        node_to (node): node that inherits
    """
    src_compute = self._find_compute_trace_from_node(node_from)
    dst_compute = self._find_compute_trace_from_node(node_to)
    assert len(src_compute) == len(dst_compute)
    for dim, computed in enumerate(src_compute):
        self._add_source(node_from, dim, node_to, dim)
        dst_compute[dim] = copy.deepcopy(computed)
|
||||
|
||||
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim):
    """Record that a dim of ``node_to`` originates from a dim of ``node_from``.

    Appends ``{position of node_from in self.nodes_list: source dim}`` to
    the target dim's ``'source'`` entry, followed by every item already in
    the source dim's own ``'source'`` entry. Existing items of the target
    entry are kept.

    Args:
        node_from (node): originating node
        node_from_dim (int): dim of ``node_from`` (may be negative)
        node_to (node): node whose source list is extended
        node_to_dim (int): dim of ``node_to`` (may be negative)
    """
    src_dim = self._transform_index(node_from, node_from_dim)
    src_trace = self._find_trace_from_node(node_from)
    dst_dim = self._transform_index(node_to, node_to_dim)
    dst_trace = self._find_trace_from_node(node_to)
    src_node_pos = _find_idx_by_name(node_from.name, self.nodes_list)

    dst_sources = dst_trace['source'][dst_dim]
    dst_sources.append({src_node_pos: src_dim})
    dst_sources.extend(src_trace['source'][src_dim])
|
||||
|
||||
def _mark_computation_from_node(self, node_from, node_to, exclude=None):
    """Merge the computation marks of ``node_from`` into ``node_to``.

    The trailing dims of the two nodes are paired (-1 with -1, -2 with -2,
    ...) for as many dims as the shorter compute trace has. For every
    paired dim that is not excluded, a source link is recorded through
    ``self._add_source`` and each mark of ``node_from`` that ``node_to``
    does not have yet is appended to ``node_to``'s compute entry.

    Args:
        node_from (node): node whose marks are read
        node_to (node): node whose marks are extended in place
        exclude (list, optional): dims of ``node_to`` (may be negative) to
            leave untouched. Defaults to None, meaning no dim is excluded.
    """
    if exclude is None:
        excluded = set()
    else:
        # Normalise to non-negative positions so -1 and n-1 compare equal
        excluded = {self._transform_index(node_to, d) for d in exclude}

    node_from_compute = self._find_compute_trace_from_node(node_from)
    node_to_compute = self._find_compute_trace_from_node(node_to)
    shared_dims = min(len(node_from_compute), len(node_to_compute))

    for i in range(-1, -shared_dims - 1, -1):
        if self._transform_index(node_to, i) in excluded:
            continue
        self._add_source(node_from, i, node_to, i)
        target = node_to_compute[i]
        for j in node_from_compute[i]:
            if j not in target:
                target.append(j)
|
||||
|
||||
def _mark_idx_equal(self, node1, dim1, node2, dim2):
|
||||
"""
|
||||
Mark 2 index to be equal.
|
||||
|
||||
Args:
|
||||
idx1 (int): index count.
|
||||
idx2 (int): index count.
|
||||
"""
|
||||
self.idx_trace_equal.append((idx1, idx2))
|
||||
"""
|
||||
# node1_idx = _find_idx_by_name(node1.name, self.nodes_list)
|
||||
# node2_idx = _find_idx_by_name(node2.name, self.nodes_list)
|
||||
# if node1_idx > node2_idx:
|
||||
# self._add_source(node2, dim2, node1, dim1)
|
||||
# else:
|
||||
# self._add_source(node1, dim1, node2, dim2)
|
||||
|
||||
def _mark_computation(self, node, idx, dim):
|
||||
"""
|
||||
|
@ -189,16 +254,14 @@ class IndexTracer(object):
|
|||
idx (int): node index
|
||||
dim (list or int): dims to be marked as computed
|
||||
"""
|
||||
input_node_idx_trace = self._find_idx_trace_from_node(node)
|
||||
if isinstance(dim, int):
|
||||
dim = [dim]
|
||||
dims = list(range(len(_get_node_shape(node))))
|
||||
for d in dim:
|
||||
cur_idx = input_node_idx_trace[d]
|
||||
if cur_idx not in self.idx_trace_list[idx]['compute']:
|
||||
self.idx_trace_list[idx]['compute'][cur_idx] = [idx]
|
||||
else:
|
||||
self.idx_trace_list[idx]['compute'][cur_idx].append(idx)
|
||||
|
||||
cur_dim = dims[d]
|
||||
if idx not in self.idx_trace_list[idx]['compute'][cur_dim]:
|
||||
self.idx_trace_list[idx]['compute'][cur_dim].append(idx)
|
||||
|
||||
def _find_trace_from_node(self, node):
|
||||
"""
|
||||
Find node idx and compute trace by the node.
|
||||
|
@ -211,7 +274,7 @@ class IndexTracer(object):
|
|||
"""
|
||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||
node_dict = self.idx_trace_list[node_idx]
|
||||
return node_dict['idx'], node_dict['compute']
|
||||
return node_dict
|
||||
|
||||
def _find_idx_trace_from_node(self, node):
|
||||
"""
|
||||
|
@ -237,19 +300,23 @@ class IndexTracer(object):
|
|||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||
return self.idx_trace_list[node_idx]['compute']
|
||||
|
||||
def _assign_index_as_input(self, node, node_idx):
|
||||
def _assign_index_as_input(self, node, node_idx, input_node=None):
|
||||
"""
|
||||
Assign node's trace as its input node.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list)
|
||||
"""
|
||||
if input_node == None:
|
||||
input_node = node.args[0]
|
||||
input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list)
|
||||
input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx']
|
||||
|
||||
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
||||
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
||||
|
||||
self._inherit_all_computation(input_node, node)
|
||||
|
||||
def _assign_all_index(self, node, node_idx):
|
||||
"""
|
||||
|
@ -275,15 +342,12 @@ class IndexTracer(object):
|
|||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
input_node = node.args[0]
|
||||
tranpose_dim = node.args[1:]
|
||||
input_node_idx_trace = self._find_idx_trace_from_node(node.args[0])
|
||||
|
||||
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
||||
new_idx_trace[tranpose_dim[0]] = input_node_idx_trace[tranpose_dim[1]]
|
||||
new_idx_trace[tranpose_dim[1]] = input_node_idx_trace[tranpose_dim[0]]
|
||||
|
||||
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
||||
self._inherit_computation(node.args[0], node)
|
||||
self._assign_index_as_input(node, node_idx, input_node)
|
||||
self._inherit_index(input_node, tranpose_dim[1], node, tranpose_dim[0])
|
||||
self._inherit_index(input_node, tranpose_dim[0], node, tranpose_dim[1])
|
||||
|
||||
def _assign_permute_index(self, node, node_idx):
|
||||
"""
|
||||
|
@ -296,14 +360,11 @@ class IndexTracer(object):
|
|||
node_idx (int)
|
||||
"""
|
||||
permute_dim = node.args[1:]
|
||||
input_node_idx_trace = self._find_idx_trace_from_node(node.args[0])
|
||||
input_node = node.args[0]
|
||||
|
||||
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
||||
self._assign_index_as_input(node, node_idx, input_node)
|
||||
for idx, d in enumerate(permute_dim):
|
||||
new_idx_trace[idx] = input_node_idx_trace[d]
|
||||
|
||||
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
||||
self._inherit_computation(node.args[0], node)
|
||||
self._inherit_index(input_node, d, node, idx)
|
||||
|
||||
def _assign_linear_index(self, node, node_idx):
|
||||
"""
|
||||
|
@ -321,20 +382,15 @@ class IndexTracer(object):
|
|||
bias = None
|
||||
else:
|
||||
input_node, weight, bias = node.args
|
||||
input_node_idx_trace = self._find_idx_trace_from_node(input_node)
|
||||
weight_idx_trace = self._find_idx_trace_from_node(weight)
|
||||
|
||||
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
||||
new_idx_trace[-1] = weight_idx_trace[1]
|
||||
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
||||
self._assign_index_as_input(node, node_idx)
|
||||
self._inherit_index(weight, 1, node, -1)
|
||||
|
||||
self._inherit_computation(input_node, node)
|
||||
self._mark_computation(node, node_idx, [-1])
|
||||
self._mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0])
|
||||
self._mark_idx_equal(input_node, -1, weight, 0)
|
||||
|
||||
if bias:
|
||||
bias_idx_trace = self._find_idx_trace_from_node(bias)
|
||||
self._mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
|
||||
self._mark_idx_equal(input_node, -1, bias, 0)
|
||||
|
||||
def _assign_matmul_index(self, node, node_idx):
|
||||
"""
|
||||
|
@ -348,18 +404,14 @@ class IndexTracer(object):
|
|||
node_idx (int)
|
||||
"""
|
||||
matmul_left, matmul_right = node.args
|
||||
matmul_left_idx_trace = self._find_idx_trace_from_node(matmul_left)
|
||||
matmul_right_idx_trace = self._find_idx_trace_from_node(matmul_right)
|
||||
|
||||
assert(len(matmul_left_idx_trace) == len(matmul_right_idx_trace))
|
||||
new_idx_trace = copy.deepcopy(matmul_left_idx_trace)
|
||||
new_idx_trace[-1] = matmul_right_idx_trace[-1]
|
||||
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
||||
assert(len(_get_node_shape(matmul_left)) == len(_get_node_shape(matmul_right)))
|
||||
self._assign_index_as_input(node, node_idx, matmul_left)
|
||||
self._inherit_index(matmul_right, -1, node, -1)
|
||||
|
||||
self._inherit_computation(matmul_left, node)
|
||||
self._inherit_computation(matmul_right, node)
|
||||
self._mark_computation_from_node(matmul_right, node, [-1, -2])
|
||||
self._mark_computation(node, node_idx, [-1])
|
||||
self._mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])
|
||||
self._mark_idx_equal(matmul_left, -1, matmul_right, -2)
|
||||
|
||||
def _assign_layernorm_index(self, node, idx):
|
||||
"""
|
||||
|
@ -372,7 +424,6 @@ class IndexTracer(object):
|
|||
node_idx (int)
|
||||
"""
|
||||
self._assign_index_as_input(node, idx)
|
||||
self._inherit_computation(node.args[0], node)
|
||||
self._mark_computation(node, idx, [-1, -2])
|
||||
|
||||
def _assign_elementwise_index(self, node, idx):
|
||||
|
@ -386,9 +437,59 @@ class IndexTracer(object):
|
|||
node_idx (int)
|
||||
"""
|
||||
self._assign_index_as_input(node, idx)
|
||||
nodes_in = []
|
||||
for node_in in node.args:
|
||||
if type(node_in) not in (int, float):
|
||||
self._inherit_computation(node_in, node)
|
||||
if type(node_in) == type(node):
|
||||
nodes_in.append(node_in)
|
||||
self._mark_computation_from_node(node_in, node)
|
||||
assert len(nodes_in) <= 2
|
||||
if len(nodes_in) == 2:
|
||||
node_in0_shape = _get_node_shape(nodes_in[0])
|
||||
node_in1_shape = _get_node_shape(nodes_in[1])
|
||||
for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1):
|
||||
if node_in0_shape[i] == node_in1_shape[i]:
|
||||
self._mark_idx_equal(nodes_in[0], i, nodes_in[1], i)
|
||||
|
||||
def _assgin_no_change_index(self, node, idx):
    """Assign index for an op whose output keeps the trace of its input.

    The node first takes its trace from its input through
    ``self._assign_index_as_input``. Computation marks are then merged in
    from every argument whose type is the same as the node's own type.

    Args:
        node (node)
        idx (int): position of the node
    """
    self._assign_index_as_input(node, idx)
    node_type = type(node)
    # Lazy on purpose: each argument is checked right before it is processed
    same_type_args = (arg for arg in node.args if type(arg) == node_type)
    for arg in same_type_args:
        self._mark_computation_from_node(arg, node)
|
||||
|
||||
def _assign_einsum_index(self, node, idx):
    """Assign index for einsum op.

    ``node.args[0]`` is the equation string (spaces are ignored) of the form
    ``"<operand labels>,<operand labels>,...-><output labels>"`` and
    ``node.args[1:]`` are the operand nodes.

    1. every output dim inherits its trace from the operand dim that carries
       the same label; when several operands carry it, each one is applied
       in turn, so the last one ends up in effect
    2. every label that occurs on the left side only is marked as computed,
       using its position in the first operand that contains it

    Args:
        node (node)
        idx (int): position of the node
    """
    equation = node.args[0]
    operands = node.args[1:]

    equation = equation.replace(" ", "")
    lhs, output_labels = equation.split("->")
    operand_labels = lhs.split(",")

    # Labels present on the left side but absent from the output
    summed_labels = set("".join(operand_labels)) - set(output_labels)

    for out_dim, label in enumerate(output_labels):
        for operand_pos, labels in enumerate(operand_labels):
            if label not in labels:
                continue
            self._inherit_index(operands[operand_pos], labels.index(label), node, out_dim)

    for label in summed_labels:
        first_holder = next((labels for labels in operand_labels if label in labels), None)
        if first_holder is not None:
            self._mark_computation(node, idx, first_holder.index(label))
|
||||
|
||||
def _assign_softmax_index(self, node, idx):
|
||||
"""
|
||||
|
@ -401,7 +502,6 @@ class IndexTracer(object):
|
|||
node_idx (int)
|
||||
"""
|
||||
self._assign_index_as_input(node, idx)
|
||||
self._inherit_computation(node.args[0], node)
|
||||
self._mark_computation(node, idx, [node.kwargs['dim']])
|
||||
|
||||
def _assign_unsqueeze_index(self, node, node_idx):
|
||||
|
@ -412,10 +512,12 @@ class IndexTracer(object):
|
|||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
"""
|
||||
self._del_dim(node_idx, -1)
|
||||
self._assign_index_as_input(node, node_idx)
|
||||
self._inherit_computation(node.args[0], node)
|
||||
self.idx_trace_list[node_idx]['idx'].insert(node.args[1], self._add_index())
|
||||
self.idx_trace_list[node_idx]['compute'].insert(node.args[1], [])
|
||||
self.idx_trace_list[node_idx]['source'].insert(node.args[1], [])
|
||||
|
||||
def _assign_dropout_index(self, node, node_idx):
|
||||
"""
|
||||
|
@ -427,7 +529,6 @@ class IndexTracer(object):
|
|||
node_idx (int)
|
||||
"""
|
||||
self._assign_index_as_input(node, node_idx)
|
||||
|
||||
|
||||
def _assign_ones_like_index(self, node, node_idx):
|
||||
"""
|
||||
|
@ -439,17 +540,6 @@ class IndexTracer(object):
|
|||
node_idx (int)
|
||||
"""
|
||||
self._assign_all_index(node, node_idx)
|
||||
|
||||
def _assign_to_index(self, node, node_idx):
|
||||
"""
|
||||
Assign index for to op.
|
||||
1. assign new index for all dim
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._assign_index_as_input(node, node_idx)
|
||||
|
||||
def _assign_view_reshape_index(self, node, node_idx):
|
||||
"""
|
||||
|
@ -494,26 +584,26 @@ class IndexTracer(object):
|
|||
dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
|
||||
dim_to = [dim_equal.index(False)]
|
||||
dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
|
||||
self._add_dim(node_idx, -1)
|
||||
elif len_diff == -1:
|
||||
# dim expand
|
||||
dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
|
||||
dim_from = [dim_equal.index(False)]
|
||||
dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
|
||||
self._del_dim(node_idx, -1)
|
||||
else:
|
||||
raise NotImplementedError("shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented")
|
||||
|
||||
# get new index
|
||||
origin_trace = self._find_idx_trace_from_node(origin_node)
|
||||
new_trace = copy.deepcopy(origin_trace)
|
||||
self._assign_index_as_input(node, node_idx, origin_node)
|
||||
dim_from.reverse()
|
||||
for i in dim_from:
|
||||
new_trace.pop(i)
|
||||
self._del_dim(node_idx, i)
|
||||
for i in dim_to:
|
||||
new_trace.insert(i, self._add_index())
|
||||
self.idx_trace_list[node_idx]['idx'] = new_trace
|
||||
self._add_dim(node_idx, i)
|
||||
|
||||
# inherit computation
|
||||
self._inherit_computation(origin_node, node)
|
||||
compute_log = self._find_compute_trace_from_node(origin_node)
|
||||
for i in dim_from:
|
||||
if origin_trace[i] in compute_log:
|
||||
|
@ -524,15 +614,10 @@ class IndexTracer(object):
|
|||
# log view, not used now
|
||||
view_dict = {"idx_from": [origin_trace[i] for i in dim_from],
|
||||
"dim_from": dim_from,
|
||||
"idx_to": [new_trace[i] for i in dim_to],
|
||||
"idx_to": [self.idx_trace_list[node_idx]['idx'][i] for i in dim_to],
|
||||
"dim_to": dim_to}
|
||||
self.idx_view_list.append(view_dict)
|
||||
|
||||
def _remove_duplicate_compute(self):
|
||||
for i in self.idx_trace_list:
|
||||
for k, v in i['compute'].items():
|
||||
i['compute'][k] = list(set(v))
|
||||
|
||||
|
||||
def _merge_equal_idx(self):
|
||||
idx_equal = copy.deepcopy(self.idx_trace_equal)
|
||||
idx_equal.reverse()
|
||||
|
@ -556,8 +641,8 @@ class IndexTracer(object):
|
|||
self._assign_view_reshape_index(node, idx)
|
||||
elif 'unsqueeze' in node.name:
|
||||
self._assign_unsqueeze_index(node, idx)
|
||||
elif 'to' in node.name:
|
||||
self._assign_to_index(node, idx)
|
||||
elif any(i in node.name for i in ['to', 'contiguous']):
|
||||
self._assgin_no_change_index(node, idx)
|
||||
else:
|
||||
raise NotImplementedError(node.name, "method not implemented yet!")
|
||||
elif node.op == 'call_function':
|
||||
|
@ -573,6 +658,8 @@ class IndexTracer(object):
|
|||
self._assign_ones_like_index(node, idx)
|
||||
elif 'dropout' in node.name:
|
||||
self._assign_dropout_index(node, idx)
|
||||
elif 'einsum' in node.name:
|
||||
self._assign_einsum_index(node, idx)
|
||||
elif 'getattr' in node.name:
|
||||
continue # get attr like shape
|
||||
elif 'getitem' in node.name:
|
||||
|
@ -590,10 +677,20 @@ class IndexTracer(object):
|
|||
continue
|
||||
else:
|
||||
raise NotImplementedError(node.op, "op not implemented yet!")
|
||||
|
||||
self._remove_duplicate_compute()
|
||||
self._merge_equal_idx()
|
||||
|
||||
# self._merge_equal_idx()
|
||||
|
||||
def check_index(self, trace_idx, start_idx, end_idx):
    """Check that ``trace_idx`` is only computed inside a node region.

    For every node position in ``[start_idx, end_idx]`` whose ``'compute'``
    record contains ``trace_idx``, each value stored under
    ``compute[trace_idx]`` must itself lie in ``[start_idx, end_idx]``.

    Args:
        trace_idx: index looked up in each node's ``'compute'`` record
        start_idx (int): first node position of the region (inclusive)
        end_idx (int): last node position of the region (inclusive)

    Returns:
        bool: False as soon as one stored value falls outside the region,
        True otherwise.
    """
    # NOTE(review): `_init_idx_trace_list` creates 'compute' as a list with one
    # entry per dim, while this lookup tests membership of ``trace_idx`` and
    # then subscripts by it -- confirm this is the intended access pattern.
    for i in range(start_idx, end_idx + 1):
        cur_compute = self.idx_trace_list[i]['compute']
        if trace_idx not in cur_compute:
            continue
        if any(j < start_idx or j > end_idx for j in cur_compute[trace_idx]):
            return False
    return True
|
||||
|
||||
class MemoryEstimator(object):
|
||||
def __init__(self) -> None:
|
||||
|
@ -897,6 +994,8 @@ class ChunkRegionSearch(object):
|
|||
self._is_not_compute(after_trace, (start_idx, end_idx), i) and
|
||||
self.node_list[end_idx].meta['tensor_meta'].shape[i] != 1):
|
||||
continue
|
||||
if not self.index_tracer.check_index(before_trace['idx'][i], start_idx, end_idx):
|
||||
continue
|
||||
flow_flag, chunk_info = self._detect_flow(before_trace, after_trace, start_idx, end_idx, i)
|
||||
if flow_flag == None:
|
||||
continue
|
||||
|
@ -910,7 +1009,10 @@ class ChunkRegionSearch(object):
|
|||
input_trace = []
|
||||
for i, n in enumerate(self.node_list):
|
||||
if len(n.args) > 0 and n.op != 'output':
|
||||
input_idx = _find_idx_by_name(n.args[0].name, self.node_list)
|
||||
if isinstance(n.args[0], str):
|
||||
input_idx = _find_idx_by_name(n.args[1].name, self.node_list)
|
||||
else:
|
||||
input_idx = _find_idx_by_name(n.args[0].name, self.node_list)
|
||||
input_trace.append(output_trace[input_idx])
|
||||
else:
|
||||
input_trace.append(None)
|
||||
|
@ -930,7 +1032,7 @@ class ChunkRegionSearch(object):
|
|||
if len(free_dim) > 0:
|
||||
free_dim = [free_dim[0]]
|
||||
chunk_info = [chunk_info[0]]
|
||||
possible_chunk_region.append({'region': (start_idx, end_idx), 'dim': free_dim, 'chunk_info': chunk_info})
|
||||
possible_chunk_region.append({'region': (start_idx, end_idx), 'dim': free_dim, 'chunk_info': chunk_info})
|
||||
return possible_chunk_region
|
||||
|
||||
def _search_best_chunk_region(self, possible_chunk_regions):
|
||||
|
@ -1130,6 +1232,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
|||
|
||||
if node_idx in chunk_starts:
|
||||
within_chunk_region = True
|
||||
region_idx = chunk_starts.index(node_idx)
|
||||
|
||||
# add for loop
|
||||
chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]]
|
||||
|
@ -1150,7 +1253,6 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
|||
if node_idx in chunk_ends:
|
||||
body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list, chunk_dims[region_idx]))
|
||||
within_chunk_region = False
|
||||
region_idx += 1
|
||||
|
||||
node_idx += 1
|
||||
|
||||
|
|
Loading…
Reference in New Issue