mirror of https://github.com/hpcaitech/ColossalAI
support ones_like, add prompt if fit mode search fail
parent
80efd70c72
commit
5f24f4fd55
|
@ -1406,9 +1406,9 @@ class MemoryEstimator(object):
|
||||||
# self._print_mem_log(act_memory_peak_log, node_list, "peak")
|
# self._print_mem_log(act_memory_peak_log, node_list, "peak")
|
||||||
# self._print_mem_log(act_memory_after_node_log, node_list, "after")
|
# self._print_mem_log(act_memory_after_node_log, node_list, "after")
|
||||||
self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
|
self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
|
||||||
self._print_compute_op_mem_log(
|
# self._print_compute_op_mem_log(
|
||||||
act_memory_after_node_log, node_list, "after"
|
# act_memory_after_node_log, node_list, "after"
|
||||||
)
|
# )
|
||||||
|
|
||||||
# param_memory = parameter_size(gm)
|
# param_memory = parameter_size(gm)
|
||||||
# all_memory = act_memory + param_memory
|
# all_memory = act_memory + param_memory
|
||||||
|
@ -1465,6 +1465,9 @@ class ChunkSelector(object):
|
||||||
if i in possible_chunk_regions:
|
if i in possible_chunk_regions:
|
||||||
possible_chunk_regions.remove(i)
|
possible_chunk_regions.remove(i)
|
||||||
|
|
||||||
|
if len(possible_chunk_regions) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
# get mem for chunk region
|
# get mem for chunk region
|
||||||
regions_dict = []
|
regions_dict = []
|
||||||
for region in possible_chunk_regions:
|
for region in possible_chunk_regions:
|
||||||
|
@ -1492,7 +1495,7 @@ class ChunkSelector(object):
|
||||||
)
|
)
|
||||||
# no region found
|
# no region found
|
||||||
if len(regions_dict) == 0:
|
if len(regions_dict) == 0:
|
||||||
return None
|
raise RuntimeError("Search failed. Try a larger memory threshold.")
|
||||||
|
|
||||||
# select the min chunk len
|
# select the min chunk len
|
||||||
chunk_len = [i["chunk_len"] for i in regions_dict]
|
chunk_len = [i["chunk_len"] for i in regions_dict]
|
||||||
|
@ -1995,6 +1998,14 @@ def emit_code_with_chunk(
|
||||||
body[-1] = _replace_name(
|
body[-1] = _replace_name(
|
||||||
body[-1], input_node.name, input_node.name + chunk_slice
|
body[-1], input_node.name, input_node.name + chunk_slice
|
||||||
)
|
)
|
||||||
|
# ones like
|
||||||
|
if "ones_like" in node.name:
|
||||||
|
chunk_slice = _gen_chunk_slice_dim(
|
||||||
|
chunk_search[region_idx]["node_chunk_dim"][chunk_region_search.index_tracer.node_list[node_idx]]["chunk_dim"], "chunk_idx", _get_node_shape(node)
|
||||||
|
)
|
||||||
|
body[-1] = _replace_name(
|
||||||
|
body[-1], node.args[0].name, node.args[0].name + chunk_slice
|
||||||
|
)
|
||||||
body[-1] = _replace_reshape_size(
|
body[-1] = _replace_reshape_size(
|
||||||
body[-1], node.name, chunk_search[region_idx]["reshape_size"]
|
body[-1], node.name, chunk_search[region_idx]["reshape_size"]
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue