mirror of https://github.com/hpcaitech/ColossalAI
code style
parent
c3a2bf48b4
commit
8a989a0d89
|
@ -98,6 +98,39 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
|
|||
return context
|
||||
|
||||
|
||||
def _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body):
|
||||
if "ones_like" in node.name:
|
||||
meta_node = search_chunk.trace_index.node_list[node_idx]
|
||||
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
|
||||
if get_node_shape(meta_node)[chunk_dim] != 1:
|
||||
source_node = meta_node.args[0].args[0]
|
||||
if (
|
||||
source_node not in chunk_infos[region_idx]["node_chunk_dim"]
|
||||
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"]
|
||||
is None
|
||||
):
|
||||
chunk_slice = _gen_chunk_slice_dim(
|
||||
chunk_dim, "chunk_idx", get_node_shape(node)
|
||||
)
|
||||
body[-1] = _replace_name(
|
||||
body[-1], node.args[0].name, node.args[0].name + chunk_slice
|
||||
)
|
||||
return body
|
||||
|
||||
|
||||
def _replace_input_var(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body):
|
||||
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
||||
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
|
||||
if idx == node_idx:
|
||||
chunk_slice = _gen_chunk_slice_dim(
|
||||
dim[0], "chunk_idx", get_node_shape(input_node)
|
||||
)
|
||||
body[-1] = _replace_name(
|
||||
body[-1], input_node.name, input_node.name + chunk_slice
|
||||
)
|
||||
return body
|
||||
|
||||
|
||||
def emit_code_with_chunk(
|
||||
body,
|
||||
nodes,
|
||||
|
@ -156,36 +189,14 @@ def emit_code_with_chunk(
|
|||
if within_chunk_region:
|
||||
emit_node_func(node, body)
|
||||
# replace input var with chunk var
|
||||
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
||||
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
|
||||
if idx == node_idx:
|
||||
chunk_slice = _gen_chunk_slice_dim(
|
||||
dim[0], "chunk_idx", get_node_shape(input_node)
|
||||
)
|
||||
body[-1] = _replace_name(
|
||||
body[-1], input_node.name, input_node.name + chunk_slice
|
||||
)
|
||||
body = _replace_input_var(
|
||||
chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body
|
||||
)
|
||||
# ones like
|
||||
if "ones_like" in node.name:
|
||||
meta_node = search_chunk.trace_index.node_list[node_idx]
|
||||
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][
|
||||
"chunk_dim"
|
||||
]
|
||||
if get_node_shape(meta_node)[chunk_dim] != 1:
|
||||
source_node = meta_node.args[0].args[0]
|
||||
if (
|
||||
source_node not in chunk_infos[region_idx]["node_chunk_dim"]
|
||||
or chunk_infos[region_idx]["node_chunk_dim"][source_node][
|
||||
"chunk_dim"
|
||||
]
|
||||
is None
|
||||
):
|
||||
chunk_slice = _gen_chunk_slice_dim(
|
||||
chunk_dim, "chunk_idx", get_node_shape(node)
|
||||
)
|
||||
body[-1] = _replace_name(
|
||||
body[-1], node.args[0].name, node.args[0].name + chunk_slice
|
||||
)
|
||||
body = _replace_ones_like(
|
||||
search_chunk, chunk_infos, region_idx, node_idx, node, body
|
||||
)
|
||||
# reassgin reshape size
|
||||
body[-1] = _replace_reshape_size(
|
||||
body[-1], node.name, chunk_infos[region_idx]["reshape_size"]
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue