code style

pull/2364/head
oahzxl 2022-12-27 14:49:52 +08:00
parent 6be89a3b82
commit a2b4755ce9
1 changed files with 8 additions and 4 deletions

View File

@ -988,7 +988,9 @@ class IndexTracer(object):
def _reassgin_reshape_size(self, chunk_info):
chunk_region = chunk_info["region"]
reshape_size = {}
chunk_shape = _get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
chunk_shape = _get_node_shape(chunk_info["outputs"][0])[
chunk_info["outputs_dim"]
]
for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]:
if any(i in node.name for i in ["reshape", "view"]):
reshape_args = node.args[1:]
@ -999,7 +1001,9 @@ class IndexTracer(object):
if reshape_arg_dim in reshape_log["dim_to"]:
continue
if reshape_arg_dim == chunk_dim:
reshape_size[node.name][reshape_arg.name] = "min(chunk_size, %d - chunk_idx)" % chunk_shape
reshape_size[node.name][reshape_arg.name] = (
"min(chunk_size, %d - chunk_idx)" % chunk_shape
)
chunk_info["reshape_size"] = reshape_size
return chunk_info
@ -1498,7 +1502,7 @@ class ChunkSelector(object):
else:
gap = 1
while r >= l + gap:
mid = int(l + (r - l)/2)
mid = int(l + (r - l) / 2)
chunk_info["chunk_size"] = mid
cur_chunk_infos = chunk_infos + [chunk_info]
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
@ -1938,7 +1942,7 @@ def emit_code_with_chunk(
chunk_inputs[region_idx],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx],
chunk_size=chunk_search[region_idx]["chunk_size"]
chunk_search[region_idx]["chunk_size"],
)
)