mirror of https://github.com/hpcaitech/ColossalAI
code style
parent
8511d900a8
commit
98f9728e29
|
@ -194,7 +194,7 @@ class FlowTracer(object):
|
||||||
if type(i) == type(mix_flow_node) and i != mix_flow_node:
|
if type(i) == type(mix_flow_node) and i != mix_flow_node:
|
||||||
main_flow_var = i
|
main_flow_var = i
|
||||||
# if mix flow is a broadcast in chunk dim,
|
# if mix flow is a broadcast in chunk dim,
|
||||||
# TODO need to move that flow out of the chunk
|
# TODO: need to move that flow out of the chunk
|
||||||
mix_flow_node_dim = index_tracer._get_node_chunk_dim(
|
mix_flow_node_dim = index_tracer._get_node_chunk_dim(
|
||||||
self.node_list[end_idx], end_dim, node
|
self.node_list[end_idx], end_dim, node
|
||||||
)
|
)
|
||||||
|
@ -1200,7 +1200,7 @@ class ChunkRegionSearch(object):
|
||||||
continue
|
continue
|
||||||
# it means an index create 2 copy of itself
|
# it means an index create 2 copy of itself
|
||||||
# eg. a = torch.matmul(x, x.transpose(-1, -2))
|
# eg. a = torch.matmul(x, x.transpose(-1, -2))
|
||||||
# TODO currently remove it, deal with this in future
|
# TODO: currently remove it, deal with this in future
|
||||||
if input_dim1 == input_dim2 and output_dim1 != output_dim2:
|
if input_dim1 == input_dim2 and output_dim1 != output_dim2:
|
||||||
remove_list.append(chunk_infos[idx1])
|
remove_list.append(chunk_infos[idx1])
|
||||||
remove_list.append(chunk_infos[idx2])
|
remove_list.append(chunk_infos[idx2])
|
||||||
|
@ -1216,7 +1216,7 @@ class ChunkRegionSearch(object):
|
||||||
chunk_infos = []
|
chunk_infos = []
|
||||||
for end_dim, end_trace_idx in enumerate(end_trace["idx"]):
|
for end_dim, end_trace_idx in enumerate(end_trace["idx"]):
|
||||||
if len(start_traces) > 1:
|
if len(start_traces) > 1:
|
||||||
# TODO implement multi input chunk
|
# TODO: implement multi input chunk
|
||||||
continue
|
continue
|
||||||
for start_node, start_trace in start_traces.items():
|
for start_node, start_trace in start_traces.items():
|
||||||
for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
|
for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
|
||||||
|
@ -1421,7 +1421,7 @@ def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
|
||||||
|
|
||||||
# if a node has a user node which is not in the node list
|
# if a node has a user node which is not in the node list
|
||||||
# we treat that user node as the node receiving the current node output
|
# we treat that user node as the node receiving the current node output
|
||||||
# TODO it is unsafe to remove non compute node here
|
# TODO: it is unsafe to remove non compute node here
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
for output_node in node.users.keys():
|
for output_node in node.users.keys():
|
||||||
if (
|
if (
|
||||||
|
|
Loading…
Reference in New Issue