|
|
|
@ -20,11 +20,22 @@ from .search_chunk import SearchChunk
|
|
|
|
|
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): |
|
|
|
|
def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str: |
|
|
|
|
""" |
|
|
|
|
Generate chunk slice string, e.g. [:, :, chunk_indice_name:chunk_indice_name + chunk_size, :] |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
chunk_dim (int): index of the dimension in shape that is sliced |
|
|
|
|
chunk_indice_name (str): chunk indice name |
|
|
|
|
shape (List): node shape |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
new_shape (str): the generated slice string |
|
|
|
|
""" |
|
|
|
|
new_shape = "[" |
|
|
|
|
for idx, i in enumerate(shape): |
|
|
|
|
for idx, _ in enumerate(shape): |
|
|
|
|
if idx == chunk_dim: |
|
|
|
|
new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name) |
|
|
|
|
new_shape += "%s:%s + chunk_size" % (chunk_indice_name, chunk_indice_name) |
|
|
|
|
else: |
|
|
|
|
new_shape += ":" |
|
|
|
|
new_shape += ", " |
|
|
|
@ -32,7 +43,26 @@ def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
|
|
|
|
|
return new_shape |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2): |
|
|
|
|
def _gen_loop_start( |
|
|
|
|
chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2 |
|
|
|
|
) -> str: |
|
|
|
|
""" |
|
|
|
|
Generate chunk loop start |
|
|
|
|
|
|
|
|
|
eg. chunk_result = torch.empty([100, 100], dtype=input_node.dtype, device=input_node.device) |
|
|
|
|
chunk_size = 32 |
|
|
|
|
for chunk_idx in range(0, 100, 32): |
|
|
|
|
...... |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
chunk_input (List[Node]): chunk input node |
|
|
|
|
chunk_output (Node): chunk output node |
|
|
|
|
chunk_ouput_dim (int): chunk output node chunk dim |
|
|
|
|
chunk_size (int): chunk size. Defaults to 2. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
context (str): generated str |
|
|
|
|
""" |
|
|
|
|
input_node = chunk_input[0] |
|
|
|
|
out_shape = get_node_shape(chunk_output) |
|
|
|
|
out_str = str(list(out_shape)) |
|
|
|
@ -45,8 +75,28 @@ def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _gen_loop_end( |
|
|
|
|
chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list |
|
|
|
|
): |
|
|
|
|
chunk_inputs: List[Node], |
|
|
|
|
chunk_non_compute_inputs: List[Node], |
|
|
|
|
chunk_outputs: Node, |
|
|
|
|
chunk_outputs_dim: int, |
|
|
|
|
node_list: List[Node], |
|
|
|
|
) -> str: |
|
|
|
|
""" |
|
|
|
|
Generate chunk loop end |
|
|
|
|
|
|
|
|
|
eg. chunk_result[chunk_idx:chunk_idx + chunk_size] = output_node |
|
|
|
|
output_node = chunk_result; xx = None; xx = None |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
chunk_inputs (List[Node]): chunk input node |
|
|
|
|
chunk_non_compute_inputs (List[Node]): input node without chunk |
|
|
|
|
chunk_outputs (Node): chunk output node |
|
|
|
|
chunk_outputs_dim (int): chunk output node chunk dim |
|
|
|
|
node_list (List) |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
context (str): generated str |
|
|
|
|
""" |
|
|
|
|
chunk_outputs_name = chunk_outputs.name |
|
|
|
|
chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list) |
|
|
|
|
chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape |
|
|
|
@ -76,7 +126,10 @@ def _gen_loop_end(
|
|
|
|
|
return context |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _replace_name(context, name_from, name_to): |
|
|
|
|
def _replace_name(context: str, name_from: str, name_to: str) -> str: |
|
|
|
|
""" |
|
|
|
|
replace node name |
|
|
|
|
""" |
|
|
|
|
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")] |
|
|
|
|
for p in patterns: |
|
|
|
|
source = p[0] + name_from + p[1] |
|
|
|
@ -86,7 +139,10 @@ def _replace_name(context, name_from, name_to):
|
|
|
|
|
return context |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _replace_reshape_size(context, node_name, reshape_size_dict): |
|
|
|
|
def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict) -> str: |
|
|
|
|
""" |
|
|
|
|
replace reshape size, some may have changed due to chunk |
|
|
|
|
""" |
|
|
|
|
if node_name not in reshape_size_dict: |
|
|
|
|
return context |
|
|
|
|
for size_name, size_value in reshape_size_dict[node_name].items(): |
|
|
|
@ -94,7 +150,17 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
|
|
|
|
|
return context |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _replace_ones_like(search_chunk: SearchChunk, chunk_infos, region_idx, node_idx, node, body): |
|
|
|
|
def _replace_ones_like( |
|
|
|
|
search_chunk: SearchChunk, |
|
|
|
|
chunk_infos: List[Dict], |
|
|
|
|
region_idx: int, |
|
|
|
|
node_idx: int, |
|
|
|
|
node: Node, |
|
|
|
|
body: List[str], |
|
|
|
|
) -> List[str]: |
|
|
|
|
""" |
|
|
|
|
Add a chunk slice for ops that create new tensors, such as ones_like |
|
|
|
|
""" |
|
|
|
|
if "ones_like" in node.name: |
|
|
|
|
meta_node = search_chunk.trace_indice.node_list[node_idx] |
|
|
|
|
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"] |
|
|
|
@ -114,7 +180,16 @@ def _replace_ones_like(search_chunk: SearchChunk, chunk_infos, region_idx, node_
|
|
|
|
|
return body |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _replace_input_var(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body): |
|
|
|
|
def _replace_input_node( |
|
|
|
|
chunk_inputs: List[Node], |
|
|
|
|
region_idx: int, |
|
|
|
|
chunk_inputs_dim: Dict, |
|
|
|
|
node_idx: int, |
|
|
|
|
body: List[str], |
|
|
|
|
) -> List[str]: |
|
|
|
|
""" |
|
|
|
|
add chunk slice for input nodes |
|
|
|
|
""" |
|
|
|
|
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): |
|
|
|
|
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): |
|
|
|
|
if idx == node_idx: |
|
|
|
@ -138,7 +213,7 @@ def emit_code_with_chunk(
|
|
|
|
|
""" |
|
|
|
|
Emit code with chunk according to chunk_infos. |
|
|
|
|
|
|
|
|
|
It will generate a for loop in chunk regions, and |
|
|
|
|
It will generate a for loop in chunk regions, and |
|
|
|
|
replace inputs and outputs of regions with chunked variables. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
@ -193,7 +268,7 @@ def emit_code_with_chunk(
|
|
|
|
|
if within_chunk_region: |
|
|
|
|
emit_node_func(node, body) |
|
|
|
|
# replace input var with chunk var |
|
|
|
|
body = _replace_input_var( |
|
|
|
|
body = _replace_input_node( |
|
|
|
|
chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body |
|
|
|
|
) |
|
|
|
|
# ones like |
|
|
|
|