mirror of https://github.com/hpcaitech/ColossalAI
adapt codegen to prepose node
parent
522f017418
commit
d309e9338b
|
@ -1198,7 +1198,7 @@ class FlowTracer(object):
|
|||
chunk_node_list.remove(n)
|
||||
non_chunk_inputs = _find_chunk_all_input_nodes(chunk_node_list)
|
||||
for i in non_chunk_inputs:
|
||||
if i not in chunk_info["inputs"] and i not in prepose_nodes:
|
||||
if i not in chunk_info["inputs"]:
|
||||
chunk_info["inputs_non_chunk"].append(i)
|
||||
|
||||
return chunk_info
|
||||
|
@ -1425,6 +1425,7 @@ class MemoryEstimator(object):
|
|||
) / (1024**2)
|
||||
|
||||
# determine chunk ratio for current node
|
||||
# TODO: adapt to prepose node memory
|
||||
if chunk_within:
|
||||
chunk_ratio = self._get_chunk_ratio(
|
||||
node,
|
||||
|
@ -1602,7 +1603,6 @@ class ChunkRegionSearch(object):
|
|||
chunk_infos = []
|
||||
for end_dim, end_trace_idx in enumerate(end_trace["idx"]):
|
||||
if len(start_traces) > 1:
|
||||
# TODO: implement multi input chunk
|
||||
continue
|
||||
for start_node, start_trace in start_traces.items():
|
||||
for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
|
||||
|
@ -1831,7 +1831,6 @@ def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
|
|||
|
||||
# if a node has a user node which is not in the node list
|
||||
# we treat that user node as the node receiving the current node output
|
||||
# TODO: it is unsafe to remove non compute node here
|
||||
for node in nodes:
|
||||
for output_node in node.users.keys():
|
||||
if (
|
||||
|
@ -1900,6 +1899,8 @@ def emit_code_with_chunk(
|
|||
|
||||
chunk_outputs = [i["outputs"][0] for i in chunk_search]
|
||||
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search]
|
||||
|
||||
chunk_prepose_nodes = [i["args"]["prepose_nodes"] for i in chunk_search]
|
||||
|
||||
node_idx = 0
|
||||
region_idx = 0
|
||||
|
@ -1911,7 +1912,11 @@ def emit_code_with_chunk(
|
|||
if node_idx in chunk_starts:
|
||||
within_chunk_region = True
|
||||
region_idx = chunk_starts.index(node_idx)
|
||||
|
||||
# add prepose nodes
|
||||
for i in chunk_prepose_nodes[region_idx]:
|
||||
prepose_node = node_list[_find_idx_by_name(i.name, node_list)]
|
||||
emit_node_func(prepose_node, body)
|
||||
delete_unused_value_func(prepose_node, body, chunk_inputs_names)
|
||||
# add for loop
|
||||
body.append(
|
||||
_gen_loop_start(
|
||||
|
@ -1922,20 +1927,22 @@ def emit_code_with_chunk(
|
|||
)
|
||||
|
||||
if within_chunk_region:
|
||||
emit_node_func(node, body)
|
||||
# replace input var with chunk var
|
||||
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
||||
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
|
||||
if idx == node_idx:
|
||||
chunk_slice = _gen_chunk_slice_dim(
|
||||
dim, "chunk_idx", _get_node_shape(input_node)
|
||||
)
|
||||
body[-1] = _replace_name(
|
||||
body[-1], input_node.name, input_node.name + chunk_slice
|
||||
)
|
||||
body[-1] = " " + body[-1]
|
||||
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||
|
||||
if any(node.name == i.name for i in chunk_prepose_nodes[region_idx]):
|
||||
pass
|
||||
else:
|
||||
emit_node_func(node, body)
|
||||
# replace input var with chunk var
|
||||
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
||||
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
|
||||
if idx == node_idx:
|
||||
chunk_slice = _gen_chunk_slice_dim(
|
||||
dim, "chunk_idx", _get_node_shape(input_node)
|
||||
)
|
||||
body[-1] = _replace_name(
|
||||
body[-1], input_node.name, input_node.name + chunk_slice
|
||||
)
|
||||
body[-1] = " " + body[-1]
|
||||
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||
else:
|
||||
emit_node_func(node, body)
|
||||
if node_idx not in chunk_inputs:
|
||||
|
|
Loading…
Reference in New Issue