mirror of https://github.com/hpcaitech/ColossalAI
work with outerproductmean and msa
parent
5de9e46381
commit
31a2c5d09f
246
chunk_codegen.py
246
chunk_codegen.py
|
@ -134,7 +134,7 @@ class FlowTracer(object):
|
||||||
return name, i
|
return name, i
|
||||||
raise RuntimeError("invalid node")
|
raise RuntimeError("invalid node")
|
||||||
|
|
||||||
def get_flow_mix(self, node):
|
def _get_flow_mix_node(self, node):
|
||||||
if self._is_non_compute_node(node):
|
if self._is_non_compute_node(node):
|
||||||
return None
|
return None
|
||||||
_, node_trace = self.find_node_flow(node)
|
_, node_trace = self.find_node_flow(node)
|
||||||
|
@ -145,7 +145,7 @@ class FlowTracer(object):
|
||||||
vars = list(node_trace["outside_depend"][0].values())[0]
|
vars = list(node_trace["outside_depend"][0].values())[0]
|
||||||
return vars
|
return vars
|
||||||
|
|
||||||
def get_same_flow_node(self, node_list, node):
|
def _get_same_flow_node(self, node_list, node):
|
||||||
name, _ = self.find_node_flow(node)
|
name, _ = self.find_node_flow(node)
|
||||||
result = []
|
result = []
|
||||||
for i in self.flow_trace[name]:
|
for i in self.flow_trace[name]:
|
||||||
|
@ -181,13 +181,14 @@ class FlowTracer(object):
|
||||||
)
|
)
|
||||||
return self.flow_trace
|
return self.flow_trace
|
||||||
|
|
||||||
def _detect_flow(self, start_idx, start_dim, end_idx, end_dim):
|
def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer):
|
||||||
inputs, outputs = _find_chunk_input_and_output_nodes(
|
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
|
||||||
self.node_list[start_idx : end_idx + 1]
|
self.node_list[start_idx : end_idx + 1]
|
||||||
)
|
)
|
||||||
chunk_info = {
|
chunk_info = {
|
||||||
"region": (start_idx, end_idx),
|
"region": (start_idx, end_idx),
|
||||||
"inputs": inputs,
|
"inputs": inputs,
|
||||||
|
"inputs_non_chunk": [],
|
||||||
"inputs_dim": start_dim,
|
"inputs_dim": start_dim,
|
||||||
"outputs": outputs,
|
"outputs": outputs,
|
||||||
"outputs_dim": end_dim,
|
"outputs_dim": end_dim,
|
||||||
|
@ -197,31 +198,71 @@ class FlowTracer(object):
|
||||||
|
|
||||||
for idx in range(start_idx, end_idx + 1):
|
for idx in range(start_idx, end_idx + 1):
|
||||||
node = self.node_list[idx]
|
node = self.node_list[idx]
|
||||||
mix_flow_var = self.get_flow_mix(node)
|
mix_flow_node = self._get_flow_mix_node(node)
|
||||||
if mix_flow_var is None:
|
if mix_flow_node is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# if there is a flow mix, op must be in [mul, add, div, matmul]
|
# if there is a flow mix, op must be in [mul, add, matmul]
|
||||||
# element-wise op requires dim to be equal in every dim
|
# element-wise op requires dim to be equal in every dim
|
||||||
if any(n in node.name for n in ["mul", "add"]):
|
if any(n in node.name for n in ["mul", "add"]):
|
||||||
for i in node.args:
|
for i in node.args:
|
||||||
if type(i) == type(mix_flow_var) and i != mix_flow_var:
|
if type(i) == type(mix_flow_node) and i != mix_flow_node:
|
||||||
main_flow_var = i
|
main_flow_var = i
|
||||||
# if mix flow is a broadcast in chunk dim,
|
# if mix flow is a broadcast in chunk dim,
|
||||||
# TODO need to move that flow out of the chunk
|
# TODO need to move that flow out of the chunk
|
||||||
if mix_flow_var.meta["tensor_meta"].shape[dim_idx] == 1:
|
mix_flow_node_dim = index_tracer._get_node_chunk_dim(
|
||||||
|
self.node_list[end_idx], end_dim, node
|
||||||
|
)
|
||||||
|
if mix_flow_node_dim is None:
|
||||||
flow_flag = True
|
flow_flag = True
|
||||||
for i in self.get_same_flow_node(
|
break
|
||||||
chunk_info["inputs"], mix_flow_var
|
if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1:
|
||||||
|
flow_flag = False
|
||||||
|
for i in self._get_same_flow_node(
|
||||||
|
chunk_info["inputs"], mix_flow_node
|
||||||
):
|
):
|
||||||
chunk_info["inputs"].remove(i)
|
chunk_info["inputs"].remove(i)
|
||||||
# else, we need to chunk mix var as well
|
# else, we need to chunk mix var as well
|
||||||
else:
|
else:
|
||||||
# TODO chunk another value
|
# TODO chunk another value
|
||||||
flow_flag = False
|
flow_flag = True
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("%s not implemented" % node.name)
|
raise NotImplementedError("%s not implemented" % node.name)
|
||||||
|
|
||||||
|
inputs_dim = []
|
||||||
|
remove_inputs = []
|
||||||
|
for input_node in chunk_info['inputs']:
|
||||||
|
input_dict = {}
|
||||||
|
for user in input_node.users.keys():
|
||||||
|
if _is_non_compute_node(user):
|
||||||
|
continue
|
||||||
|
user_idx = _find_idx_by_name(user.name, self.node_list)
|
||||||
|
dim = None
|
||||||
|
if start_dim <= user_idx < end_idx:
|
||||||
|
dim = index_tracer._get_node_chunk_dim(
|
||||||
|
self.node_list[end_idx], end_dim, input_node
|
||||||
|
)
|
||||||
|
elif user_idx == end_idx:
|
||||||
|
dim = end_dim
|
||||||
|
# n has relation with chunk dim
|
||||||
|
if dim is not None and _get_node_shape(user)[dim] != 1:
|
||||||
|
input_dict[user_idx] = dim
|
||||||
|
if len(input_dict) == 0:
|
||||||
|
remove_inputs.append(input_node)
|
||||||
|
else:
|
||||||
|
inputs_dim.append(input_dict)
|
||||||
|
chunk_info['inputs_dim'] = inputs_dim
|
||||||
|
for i in remove_inputs:
|
||||||
|
if i in chunk_info['inputs']:
|
||||||
|
chunk_info['inputs'].remove(i)
|
||||||
|
|
||||||
|
# we need to log input nodes to avoid deleteing them in the loop
|
||||||
|
non_chunk_inputs = _find_chunk_all_input_nodes(self.node_list[start_idx : end_idx + 1])
|
||||||
|
for i in non_chunk_inputs:
|
||||||
|
if i not in chunk_info['inputs']:
|
||||||
|
chunk_info["inputs_non_chunk"].append(i)
|
||||||
|
|
||||||
return flow_flag, chunk_info
|
return flow_flag, chunk_info
|
||||||
|
|
||||||
|
|
||||||
|
@ -367,6 +408,20 @@ class IndexTracer(object):
|
||||||
node_dict = self.idx_trace_list[node_idx]
|
node_dict = self.idx_trace_list[node_idx]
|
||||||
return node_dict
|
return node_dict
|
||||||
|
|
||||||
|
def _find_source_trace_from_node(self, node):
|
||||||
|
"""
|
||||||
|
Find node source trace by the node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
Returns:
|
||||||
|
idx (list): idx of the node
|
||||||
|
compute (list): computed idx of the node.
|
||||||
|
"""
|
||||||
|
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||||
|
node_dict = self.idx_trace_list[node_idx]
|
||||||
|
return node_dict["source"]
|
||||||
|
|
||||||
def _find_idx_trace_from_node(self, node):
|
def _find_idx_trace_from_node(self, node):
|
||||||
"""
|
"""
|
||||||
Find node idx trace by the node.
|
Find node idx trace by the node.
|
||||||
|
@ -836,6 +891,15 @@ class IndexTracer(object):
|
||||||
# return False
|
# return False
|
||||||
# return True
|
# return True
|
||||||
|
|
||||||
|
def _get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
||||||
|
node_from_source = self._find_source_trace_from_node(node_from)
|
||||||
|
dim_source = node_from_source[node_from_dim]
|
||||||
|
node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list)
|
||||||
|
for k, v in dim_source.items():
|
||||||
|
if k == node_to_idx:
|
||||||
|
return v
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class MemoryEstimator(object):
|
class MemoryEstimator(object):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
@ -931,8 +995,10 @@ class MemoryEstimator(object):
|
||||||
return mem
|
return mem
|
||||||
|
|
||||||
def _get_chunk_ratio(self, node, chunk_dim, chunk_size):
|
def _get_chunk_ratio(self, node, chunk_dim, chunk_size):
|
||||||
|
sorted_dim = sorted(chunk_dim, key=lambda x: list(x.keys())[0])
|
||||||
|
dim = list(sorted_dim[-1].values())[0]
|
||||||
shape = node.meta["tensor_meta"].shape
|
shape = node.meta["tensor_meta"].shape
|
||||||
chunk_ratio = float(chunk_size) / shape[chunk_dim]
|
chunk_ratio = float(chunk_size) / shape[dim]
|
||||||
return chunk_ratio
|
return chunk_ratio
|
||||||
|
|
||||||
def _get_chunk_delete_node_size(
|
def _get_chunk_delete_node_size(
|
||||||
|
@ -1157,6 +1223,8 @@ class ChunkRegionSearch(object):
|
||||||
return chunk_infos
|
return chunk_infos
|
||||||
|
|
||||||
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
|
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
|
||||||
|
if start_idx == 71 and end_idx == 126:
|
||||||
|
print(1)
|
||||||
start_traces = input_trace[start_idx]
|
start_traces = input_trace[start_idx]
|
||||||
end_trace = output_trace[end_idx]
|
end_trace = output_trace[end_idx]
|
||||||
end_node = self.node_list[end_idx]
|
end_node = self.node_list[end_idx]
|
||||||
|
@ -1188,7 +1256,7 @@ class ChunkRegionSearch(object):
|
||||||
continue
|
continue
|
||||||
# detect flow meet
|
# detect flow meet
|
||||||
flow_flag, chunk_info = self.flow_tracer._detect_flow(
|
flow_flag, chunk_info = self.flow_tracer._detect_flow(
|
||||||
start_idx, start_dim, end_idx, end_dim
|
start_idx, start_dim, end_idx, end_dim, self.index_tracer
|
||||||
)
|
)
|
||||||
if flow_flag:
|
if flow_flag:
|
||||||
continue
|
continue
|
||||||
|
@ -1301,56 +1369,53 @@ def _get_first_non_single_dim(shape):
|
||||||
raise RuntimeError("can not get first non single dim for shape", shape)
|
raise RuntimeError("can not get first non single dim for shape", shape)
|
||||||
|
|
||||||
|
|
||||||
def _gen_loop_start(chunk_input_meta, chunk_output, chunk_dim, chunk_size=2):
|
def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2):
|
||||||
if len(chunk_input_meta) == 1:
|
input_node = chunk_input[0]
|
||||||
node = chunk_input_meta[0]
|
|
||||||
node_shape = node.meta["tensor_meta"].shape
|
out_shape = _get_node_shape(chunk_output)
|
||||||
free_shape = [
|
out_str = str(list(out_shape))
|
||||||
node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape))
|
|
||||||
]
|
|
||||||
chunk_dim = _get_first_non_single_dim(free_shape)
|
|
||||||
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
|
|
||||||
out_shape = str(list(chunk_output.meta["tensor_meta"].shape))
|
|
||||||
|
|
||||||
context = (
|
context = (
|
||||||
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range"
|
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range"
|
||||||
% (out_shape, node.name, node.name, chunk_size)
|
% (out_str, input_node.name, input_node.name, chunk_size)
|
||||||
)
|
|
||||||
context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
|
|
||||||
context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"input with size %d not implemented" % len(chunk_input_meta)
|
|
||||||
)
|
)
|
||||||
|
context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
|
||||||
|
|
||||||
|
# node = chunk_input[0]
|
||||||
|
# node_shape = node.meta["tensor_meta"].shape
|
||||||
|
# free_shape = [
|
||||||
|
# node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape))
|
||||||
|
# ]
|
||||||
|
# chunk_dim = _get_first_non_single_dim(free_shape)
|
||||||
|
# chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
|
||||||
|
# out_shape = str(list(chunk_output.meta["tensor_meta"].shape))
|
||||||
|
|
||||||
|
# context = (
|
||||||
|
# "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range"
|
||||||
|
# % (out_shape, node.name, node.name, chunk_size)
|
||||||
|
# )
|
||||||
|
# context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
|
||||||
|
# context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice)
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
def _gen_loop_end(chunk_outputs, chunk_inputs, node_list, chunk_dim):
|
def _gen_loop_end(chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list):
|
||||||
chunk_inputs_name = chunk_inputs[0].name
|
|
||||||
chunk_outputs_name = chunk_outputs.name
|
chunk_outputs_name = chunk_outputs.name
|
||||||
chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list)
|
chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list)
|
||||||
chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
|
chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
|
||||||
free_shape = [
|
chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
|
||||||
chunk_output_shape[i] if i in chunk_dim else 1
|
|
||||||
for i in range(len(chunk_output_shape))
|
|
||||||
]
|
|
||||||
chunk_dim = _get_first_non_single_dim(free_shape)
|
|
||||||
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", chunk_output_shape)
|
|
||||||
context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name)
|
context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name)
|
||||||
|
context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")
|
||||||
context += (
|
|
||||||
chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
|
|
||||||
)
|
|
||||||
|
|
||||||
# determine if its the last use for chunk input
|
# determine if its the last use for chunk input
|
||||||
users_name = list(chunk_inputs[0].users.keys())
|
for chunk_input in (chunk_inputs + chunk_non_compute_inputs):
|
||||||
if all(
|
if all(
|
||||||
[
|
[
|
||||||
_find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
|
_find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
|
||||||
for user in users_name
|
for user in chunk_input.users.keys()
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
context += "; %s = None" % chunk_inputs_name
|
context += "; %s = None" % chunk_input.name
|
||||||
|
|
||||||
context += "\n"
|
context += "\n"
|
||||||
return context
|
return context
|
||||||
|
@ -1382,7 +1447,24 @@ def _find_input_and_output_nodes(nodes: List[Node]):
|
||||||
return input_nodes, output_nodes
|
return input_nodes, output_nodes
|
||||||
|
|
||||||
|
|
||||||
def _find_chunk_input_and_output_nodes(nodes: List[Node]):
|
def _find_chunk_all_input_nodes(nodes: List[Node]):
|
||||||
|
"""
|
||||||
|
Find non-compute input and output node names.
|
||||||
|
input nodes are nodes used in the list
|
||||||
|
output nodes are nodes will use nodes in the list
|
||||||
|
"""
|
||||||
|
input_nodes = []
|
||||||
|
for node in nodes:
|
||||||
|
for input_node in node._input_nodes.keys():
|
||||||
|
if (
|
||||||
|
input_node not in nodes
|
||||||
|
and input_node not in input_nodes
|
||||||
|
):
|
||||||
|
input_nodes.append(input_node)
|
||||||
|
return input_nodes
|
||||||
|
|
||||||
|
|
||||||
|
def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
|
||||||
"""
|
"""
|
||||||
Find non-compute input and output node names.
|
Find non-compute input and output node names.
|
||||||
input nodes are nodes used in the list
|
input nodes are nodes used in the list
|
||||||
|
@ -1410,7 +1492,7 @@ def _find_chunk_input_and_output_nodes(nodes: List[Node]):
|
||||||
if (
|
if (
|
||||||
output_node not in nodes
|
output_node not in nodes
|
||||||
and node not in output_nodes
|
and node not in output_nodes
|
||||||
and not _is_non_compute_node_except_placeholder(input_node)
|
and not _is_non_compute_node_except_placeholder(output_node)
|
||||||
):
|
):
|
||||||
output_nodes.append(node)
|
output_nodes.append(node)
|
||||||
|
|
||||||
|
@ -1454,44 +1536,34 @@ def emit_code_with_chunk(
|
||||||
emit_node_func: function to emit node
|
emit_node_func: function to emit node
|
||||||
delete_unused_value_func: function to remove the unused value
|
delete_unused_value_func: function to remove the unused value
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# find the offload regions
|
|
||||||
chunk_region_search = ChunkRegionSearch(meta_graph)
|
|
||||||
chunk_search = chunk_region_search.search_region()
|
|
||||||
chunk_regions = [i["region"] for i in chunk_search]
|
|
||||||
chunk_dims = [i["dim"] for i in chunk_search]
|
|
||||||
chunk_infos = [i["chunk_info"] for i in chunk_search]
|
|
||||||
|
|
||||||
chunk_starts = [item[0] for item in chunk_regions]
|
|
||||||
chunk_ends = [item[1] for item in chunk_regions]
|
|
||||||
chunk_inputs = [[j["inputs"][0] for j in i] for i in chunk_infos]
|
|
||||||
chunk_outputs = [[j["outputs"][0] for j in i] for i in chunk_infos]
|
|
||||||
within_chunk_region = False
|
|
||||||
|
|
||||||
node_list = list(nodes)
|
node_list = list(nodes)
|
||||||
|
|
||||||
# find the input and output var names for each offload region
|
# find the chunk regions
|
||||||
# for idx, (start, end) in enumerate(chunk_regions):
|
chunk_region_search = ChunkRegionSearch(meta_graph)
|
||||||
# offload_node_list = node_list[start:end + 1]
|
chunk_search = chunk_region_search.search_region()
|
||||||
# inputs, outputs = _find_input_and_output_nodes(offload_node_list)
|
|
||||||
# chunk_inputs.append(inputs)
|
|
||||||
# chunk_outputs.append(outputs)
|
|
||||||
|
|
||||||
|
chunk_regions = [i["region"] for i in chunk_search]
|
||||||
|
chunk_starts = [i[0] for i in chunk_regions]
|
||||||
|
chunk_ends = [i[1] for i in chunk_regions]
|
||||||
|
|
||||||
|
chunk_inputs = [i["inputs"] for i in chunk_search]
|
||||||
|
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_search]
|
||||||
|
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_search]
|
||||||
chunk_inputs_idx = [
|
chunk_inputs_idx = [
|
||||||
[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs
|
[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs
|
||||||
]
|
]
|
||||||
chunk_outputs_idx = [
|
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
|
||||||
[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs
|
|
||||||
]
|
chunk_outputs = [i["outputs"][0] for i in chunk_search]
|
||||||
chunk_inputs_names = []
|
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search]
|
||||||
for i in chunk_inputs:
|
chunk_outputs_idx = [
|
||||||
for j in i:
|
_find_idx_by_name(i.name, node_list) for i in chunk_outputs
|
||||||
chunk_inputs_names.append(j.name)
|
]
|
||||||
|
|
||||||
# this flag is to prevent repeated insert of save tensors
|
|
||||||
# hooks definition in ckpt_func
|
|
||||||
node_idx = 0
|
node_idx = 0
|
||||||
region_idx = 0
|
region_idx = 0
|
||||||
|
within_chunk_region = False
|
||||||
|
|
||||||
while node_idx < len(node_list):
|
while node_idx < len(node_list):
|
||||||
node = node_list[node_idx]
|
node = node_list[node_idx]
|
||||||
|
|
||||||
|
@ -1500,20 +1572,23 @@ def emit_code_with_chunk(
|
||||||
region_idx = chunk_starts.index(node_idx)
|
region_idx = chunk_starts.index(node_idx)
|
||||||
|
|
||||||
# add for loop
|
# add for loop
|
||||||
chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]]
|
|
||||||
body.append(
|
body.append(
|
||||||
_gen_loop_start(
|
_gen_loop_start(
|
||||||
chunk_input_meta,
|
chunk_inputs[region_idx],
|
||||||
node_list[chunk_ends[region_idx]],
|
chunk_outputs[region_idx],
|
||||||
chunk_dims[region_idx],
|
chunk_outputs_dim[region_idx],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if within_chunk_region:
|
if within_chunk_region:
|
||||||
emit_node_func(node, body)
|
emit_node_func(node, body)
|
||||||
# replace input var with chunk var
|
# replace input var with chunk var
|
||||||
|
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
||||||
|
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
|
||||||
|
if idx == node_idx:
|
||||||
|
chunk_slice = _gen_chunk_slice_dim(dim, "chunk_idx", _get_node_shape(input_node))
|
||||||
body[-1] = _replace_name(
|
body[-1] = _replace_name(
|
||||||
body[-1], chunk_inputs[region_idx][0].name, "chunk_tensor"
|
body[-1], input_node.name, input_node.name + chunk_slice
|
||||||
)
|
)
|
||||||
body[-1] = " " + body[-1]
|
body[-1] = " " + body[-1]
|
||||||
delete_unused_value_func(node, body, chunk_inputs_names)
|
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||||
|
@ -1526,7 +1601,10 @@ def emit_code_with_chunk(
|
||||||
if node_idx in chunk_ends:
|
if node_idx in chunk_ends:
|
||||||
body.append(
|
body.append(
|
||||||
_gen_loop_end(
|
_gen_loop_end(
|
||||||
node, chunk_inputs[region_idx], node_list, chunk_dims[region_idx]
|
chunk_inputs[region_idx],
|
||||||
|
chunk_inputs_non_chunk[region_idx],
|
||||||
|
chunk_outputs[region_idx],
|
||||||
|
chunk_outputs_dim[region_idx], node_list
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
within_chunk_region = False
|
within_chunk_region = False
|
||||||
|
|
Loading…
Reference in New Issue