support ones_like, add prompt if fit mode search fail

pull/2364/head
oahzxl 2022-12-31 16:29:43 +08:00
parent 80efd70c72
commit 5f24f4fd55
1 changed files with 15 additions and 4 deletions

View File

@ -1406,9 +1406,9 @@ class MemoryEstimator(object):
# self._print_mem_log(act_memory_peak_log, node_list, "peak")
# self._print_mem_log(act_memory_after_node_log, node_list, "after")
self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
self._print_compute_op_mem_log(
act_memory_after_node_log, node_list, "after"
)
# self._print_compute_op_mem_log(
# act_memory_after_node_log, node_list, "after"
# )
# param_memory = parameter_size(gm)
# all_memory = act_memory + param_memory
@ -1465,6 +1465,9 @@ class ChunkSelector(object):
if i in possible_chunk_regions:
possible_chunk_regions.remove(i)
if len(possible_chunk_regions) == 0:
return None
# get mem for chunk region
regions_dict = []
for region in possible_chunk_regions:
@ -1492,7 +1495,7 @@ class ChunkSelector(object):
)
# no region found
if len(regions_dict) == 0:
return None
raise RuntimeError("Search failed. Try a larger memory threshold.")
# select the min chunk len
chunk_len = [i["chunk_len"] for i in regions_dict]
@ -1995,6 +1998,14 @@ def emit_code_with_chunk(
body[-1] = _replace_name(
body[-1], input_node.name, input_node.name + chunk_slice
)
# ones like
if "ones_like" in node.name:
chunk_slice = _gen_chunk_slice_dim(
chunk_search[region_idx]["node_chunk_dim"][chunk_region_search.index_tracer.node_list[node_idx]]["chunk_dim"], "chunk_idx", _get_node_shape(node)
)
body[-1] = _replace_name(
body[-1], node.args[0].name, node.args[0].name + chunk_slice
)
body[-1] = _replace_reshape_size(
body[-1], node.name, chunk_search[region_idx]["reshape_size"]
)