diff --git a/chunk_codegen.py b/chunk_codegen.py index 88d917809..22d48f5d6 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -194,7 +194,7 @@ class FlowTracer(object): if type(i) == type(mix_flow_node) and i != mix_flow_node: main_flow_var = i # if mix flow is a broadcast in chunk dim, - # TODO need to move that flow out of the chunk + # TODO: need to move that flow out of the chunk mix_flow_node_dim = index_tracer._get_node_chunk_dim( self.node_list[end_idx], end_dim, node ) @@ -1200,7 +1200,7 @@ class ChunkRegionSearch(object): continue # it means an index create 2 copy of itself # eg. a = torch.matmul(x, x.transpose(-1, -2)) - # TODO currently remove it, deal with this in future + # TODO: currently remove it, deal with this in future if input_dim1 == input_dim2 and output_dim1 != output_dim2: remove_list.append(chunk_infos[idx1]) remove_list.append(chunk_infos[idx2]) @@ -1216,7 +1216,7 @@ class ChunkRegionSearch(object): chunk_infos = [] for end_dim, end_trace_idx in enumerate(end_trace["idx"]): if len(start_traces) > 1: - # TODO implement multi input chunk + # TODO: implement multi input chunk continue for start_node, start_trace in start_traces.items(): for start_dim, start_trace_idx in enumerate(start_trace["idx"]): @@ -1421,7 +1421,7 @@ def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]): # if a node has a user node which is not in the node list # we treat that user node as the node receiving the current node output - # TODO it is unsafe to remove non compute node here + # TODO: it is unsafe to remove non compute node here for node in nodes: for output_node in node.users.keys(): if (