code style

pull/2364/head
oahzxl 2022-12-12 18:15:47 +08:00
parent 8511d900a8
commit 98f9728e29
1 changed files with 4 additions and 4 deletions

View File

@ -194,7 +194,7 @@ class FlowTracer(object):
if type(i) == type(mix_flow_node) and i != mix_flow_node: if type(i) == type(mix_flow_node) and i != mix_flow_node:
main_flow_var = i main_flow_var = i
# if mix flow is a broadcast in chunk dim, # if mix flow is a broadcast in chunk dim,
# TODO need to move that flow out of the chunk # TODO: need to move that flow out of the chunk
mix_flow_node_dim = index_tracer._get_node_chunk_dim( mix_flow_node_dim = index_tracer._get_node_chunk_dim(
self.node_list[end_idx], end_dim, node self.node_list[end_idx], end_dim, node
) )
@ -1200,7 +1200,7 @@ class ChunkRegionSearch(object):
continue continue
# it means an index create 2 copy of itself # it means an index create 2 copy of itself
# eg. a = torch.matmul(x, x.transpose(-1, -2)) # eg. a = torch.matmul(x, x.transpose(-1, -2))
# TODO currently remove it, deal with this in future # TODO: currently remove it, deal with this in future
if input_dim1 == input_dim2 and output_dim1 != output_dim2: if input_dim1 == input_dim2 and output_dim1 != output_dim2:
remove_list.append(chunk_infos[idx1]) remove_list.append(chunk_infos[idx1])
remove_list.append(chunk_infos[idx2]) remove_list.append(chunk_infos[idx2])
@ -1216,7 +1216,7 @@ class ChunkRegionSearch(object):
chunk_infos = [] chunk_infos = []
for end_dim, end_trace_idx in enumerate(end_trace["idx"]): for end_dim, end_trace_idx in enumerate(end_trace["idx"]):
if len(start_traces) > 1: if len(start_traces) > 1:
# TODO implement multi input chunk # TODO: implement multi input chunk
continue continue
for start_node, start_trace in start_traces.items(): for start_node, start_trace in start_traces.items():
for start_dim, start_trace_idx in enumerate(start_trace["idx"]): for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
@ -1421,7 +1421,7 @@ def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
# if a node has a user node which is not in the node list # if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output # we treat that user node as the node receiving the current node output
# TODO it is unsafe to remove non compute node here # TODO: it is unsafe to remove non compute node here
for node in nodes: for node in nodes:
for output_node in node.users.keys(): for output_node in node.users.keys():
if ( if (