mirror of https://github.com/hpcaitech/ColossalAI
code style
parent
c3a2bf48b4
commit
8a989a0d89
|
@ -98,6 +98,39 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body):
|
||||||
|
if "ones_like" in node.name:
|
||||||
|
meta_node = search_chunk.trace_index.node_list[node_idx]
|
||||||
|
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
|
||||||
|
if get_node_shape(meta_node)[chunk_dim] != 1:
|
||||||
|
source_node = meta_node.args[0].args[0]
|
||||||
|
if (
|
||||||
|
source_node not in chunk_infos[region_idx]["node_chunk_dim"]
|
||||||
|
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"]
|
||||||
|
is None
|
||||||
|
):
|
||||||
|
chunk_slice = _gen_chunk_slice_dim(
|
||||||
|
chunk_dim, "chunk_idx", get_node_shape(node)
|
||||||
|
)
|
||||||
|
body[-1] = _replace_name(
|
||||||
|
body[-1], node.args[0].name, node.args[0].name + chunk_slice
|
||||||
|
)
|
||||||
|
return body
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_input_var(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body):
|
||||||
|
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
||||||
|
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
|
||||||
|
if idx == node_idx:
|
||||||
|
chunk_slice = _gen_chunk_slice_dim(
|
||||||
|
dim[0], "chunk_idx", get_node_shape(input_node)
|
||||||
|
)
|
||||||
|
body[-1] = _replace_name(
|
||||||
|
body[-1], input_node.name, input_node.name + chunk_slice
|
||||||
|
)
|
||||||
|
return body
|
||||||
|
|
||||||
|
|
||||||
def emit_code_with_chunk(
|
def emit_code_with_chunk(
|
||||||
body,
|
body,
|
||||||
nodes,
|
nodes,
|
||||||
|
@ -156,36 +189,14 @@ def emit_code_with_chunk(
|
||||||
if within_chunk_region:
|
if within_chunk_region:
|
||||||
emit_node_func(node, body)
|
emit_node_func(node, body)
|
||||||
# replace input var with chunk var
|
# replace input var with chunk var
|
||||||
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
body = _replace_input_var(
|
||||||
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
|
chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body
|
||||||
if idx == node_idx:
|
)
|
||||||
chunk_slice = _gen_chunk_slice_dim(
|
|
||||||
dim[0], "chunk_idx", get_node_shape(input_node)
|
|
||||||
)
|
|
||||||
body[-1] = _replace_name(
|
|
||||||
body[-1], input_node.name, input_node.name + chunk_slice
|
|
||||||
)
|
|
||||||
# ones like
|
# ones like
|
||||||
if "ones_like" in node.name:
|
body = _replace_ones_like(
|
||||||
meta_node = search_chunk.trace_index.node_list[node_idx]
|
search_chunk, chunk_infos, region_idx, node_idx, node, body
|
||||||
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][
|
)
|
||||||
"chunk_dim"
|
# reassgin reshape size
|
||||||
]
|
|
||||||
if get_node_shape(meta_node)[chunk_dim] != 1:
|
|
||||||
source_node = meta_node.args[0].args[0]
|
|
||||||
if (
|
|
||||||
source_node not in chunk_infos[region_idx]["node_chunk_dim"]
|
|
||||||
or chunk_infos[region_idx]["node_chunk_dim"][source_node][
|
|
||||||
"chunk_dim"
|
|
||||||
]
|
|
||||||
is None
|
|
||||||
):
|
|
||||||
chunk_slice = _gen_chunk_slice_dim(
|
|
||||||
chunk_dim, "chunk_idx", get_node_shape(node)
|
|
||||||
)
|
|
||||||
body[-1] = _replace_name(
|
|
||||||
body[-1], node.args[0].name, node.args[0].name + chunk_slice
|
|
||||||
)
|
|
||||||
body[-1] = _replace_reshape_size(
|
body[-1] = _replace_reshape_size(
|
||||||
body[-1], node.name, chunk_infos[region_idx]["reshape_size"]
|
body[-1], node.name, chunk_infos[region_idx]["reshape_size"]
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue