code style

pull/2364/head
oahzxl 2022-12-12 17:25:38 +08:00
parent 31a2c5d09f
commit b7b67c32ad
1 changed files with 28 additions and 42 deletions

View File

@ -232,7 +232,7 @@ class FlowTracer(object):
inputs_dim = []
remove_inputs = []
for input_node in chunk_info['inputs']:
for input_node in chunk_info["inputs"]:
input_dict = {}
for user in input_node.users.keys():
if _is_non_compute_node(user):
@ -252,15 +252,17 @@ class FlowTracer(object):
remove_inputs.append(input_node)
else:
inputs_dim.append(input_dict)
chunk_info['inputs_dim'] = inputs_dim
chunk_info["inputs_dim"] = inputs_dim
for i in remove_inputs:
if i in chunk_info['inputs']:
chunk_info['inputs'].remove(i)
if i in chunk_info["inputs"]:
chunk_info["inputs"].remove(i)
# we need to log input nodes to avoid deleteing them in the loop
non_chunk_inputs = _find_chunk_all_input_nodes(self.node_list[start_idx : end_idx + 1])
non_chunk_inputs = _find_chunk_all_input_nodes(
self.node_list[start_idx : end_idx + 1]
)
for i in non_chunk_inputs:
if i not in chunk_info['inputs']:
if i not in chunk_info["inputs"]:
chunk_info["inputs_non_chunk"].append(i)
return flow_flag, chunk_info
@ -1371,44 +1373,32 @@ def _get_first_non_single_dim(shape):
def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2):
input_node = chunk_input[0]
out_shape = _get_node_shape(chunk_output)
out_str = str(list(out_shape))
context = (
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range"
% (out_str, input_node.name, input_node.name, chunk_size)
)
context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
# node = chunk_input[0]
# node_shape = node.meta["tensor_meta"].shape
# free_shape = [
# node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape))
# ]
# chunk_dim = _get_first_non_single_dim(free_shape)
# chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
# out_shape = str(list(chunk_output.meta["tensor_meta"].shape))
# context = (
# "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range"
# % (out_shape, node.name, node.name, chunk_size)
# )
# context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
# context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice)
return context
def _gen_loop_end(chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list):
def _gen_loop_end(
chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list
):
chunk_outputs_name = chunk_outputs.name
chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list)
chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
chunk_slice = _gen_chunk_slice_dim(
chunk_outputs_dim, "chunk_idx", chunk_output_shape
)
context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name)
context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")
context += (
chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
)
# determine if its the last use for chunk input
for chunk_input in (chunk_inputs + chunk_non_compute_inputs):
for chunk_input in chunk_inputs + chunk_non_compute_inputs:
if all(
[
_find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
@ -1456,10 +1446,7 @@ def _find_chunk_all_input_nodes(nodes: List[Node]):
input_nodes = []
for node in nodes:
for input_node in node._input_nodes.keys():
if (
input_node not in nodes
and input_node not in input_nodes
):
if input_node not in nodes and input_node not in input_nodes:
input_nodes.append(input_node)
return input_nodes
@ -1549,16 +1536,12 @@ def emit_code_with_chunk(
chunk_inputs = [i["inputs"] for i in chunk_search]
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_search]
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_search]
chunk_inputs_idx = [
[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
j.name for i in chunk_inputs_non_chunk for j in i
]
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
chunk_outputs = [i["outputs"][0] for i in chunk_search]
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search]
chunk_outputs_idx = [
_find_idx_by_name(i.name, node_list) for i in chunk_outputs
]
node_idx = 0
region_idx = 0
@ -1586,7 +1569,9 @@ def emit_code_with_chunk(
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
if idx == node_idx:
chunk_slice = _gen_chunk_slice_dim(dim, "chunk_idx", _get_node_shape(input_node))
chunk_slice = _gen_chunk_slice_dim(
dim, "chunk_idx", _get_node_shape(input_node)
)
body[-1] = _replace_name(
body[-1], input_node.name, input_node.name + chunk_slice
)
@ -1604,7 +1589,8 @@ def emit_code_with_chunk(
chunk_inputs[region_idx],
chunk_inputs_non_chunk[region_idx],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx], node_list
chunk_outputs_dim[region_idx],
node_list,
)
)
within_chunk_region = False