mirror of https://github.com/hpcaitech/ColossalAI
Fix a bug in ones_like: don't generate a chunk slice if the chunk dim size is 1
parent
5f24f4fd55
commit
7fd3b45af2
|
@ -16,9 +16,9 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=N
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
now_mem = torch.cuda.memory_allocated() / 1024**2
|
now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||||
|
|
||||||
loop = 16
|
loop = 3
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for _ in range(loop // 4):
|
for _ in range(loop // 2 + 1):
|
||||||
if chunk_size:
|
if chunk_size:
|
||||||
model(node, pair, chunk_size)
|
model(node, pair, chunk_size)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -144,9 +144,7 @@ class IndexTracer(object):
|
||||||
node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim]
|
node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim]
|
||||||
else:
|
else:
|
||||||
if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]:
|
if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]:
|
||||||
node_to_trace_source[node_to_dim][node_from_idx].append(
|
node_to_trace_source[node_to_dim][node_from_idx].append(node_from_dim)
|
||||||
node_from_dim
|
|
||||||
)
|
|
||||||
# update inputs source
|
# update inputs source
|
||||||
for node_idx, node_dim in node_from_trace_source[node_from_dim].items():
|
for node_idx, node_dim in node_from_trace_source[node_from_dim].items():
|
||||||
if node_idx not in node_to_trace_source[node_to_dim]:
|
if node_idx not in node_to_trace_source[node_to_dim]:
|
||||||
|
@ -1472,7 +1470,9 @@ class ChunkSelector(object):
|
||||||
regions_dict = []
|
regions_dict = []
|
||||||
for region in possible_chunk_regions:
|
for region in possible_chunk_regions:
|
||||||
cur_region = region.copy()
|
cur_region = region.copy()
|
||||||
cur_node_list, cur_region = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_region)
|
cur_node_list, cur_region = self.index_tracer.tmp_reorder(
|
||||||
|
self.index_tracer.node_list, cur_region
|
||||||
|
)
|
||||||
cur_chunk_infos = chunk_infos + [cur_region]
|
cur_chunk_infos = chunk_infos + [cur_region]
|
||||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||||
cur_node_list, cur_chunk_infos
|
cur_node_list, cur_chunk_infos
|
||||||
|
@ -1490,7 +1490,7 @@ class ChunkSelector(object):
|
||||||
region["region"][0], region["region"][1]
|
region["region"][0], region["region"][1]
|
||||||
),
|
),
|
||||||
"reorder_chunk_info": cur_region,
|
"reorder_chunk_info": cur_region,
|
||||||
"reorder_node_list": cur_node_list
|
"reorder_node_list": cur_node_list,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# no region found
|
# no region found
|
||||||
|
@ -1508,7 +1508,7 @@ class ChunkSelector(object):
|
||||||
|
|
||||||
def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
|
def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
|
||||||
chunk_size = 1
|
chunk_size = 1
|
||||||
reorder_chunk_info = chunk_region_dict['reorder_chunk_info']
|
reorder_chunk_info = chunk_region_dict["reorder_chunk_info"]
|
||||||
reorder_chunk_info["chunk_size"] = chunk_size
|
reorder_chunk_info["chunk_size"] = chunk_size
|
||||||
cur_chunk_max_mem = 0
|
cur_chunk_max_mem = 0
|
||||||
# search a region
|
# search a region
|
||||||
|
@ -1517,10 +1517,13 @@ class ChunkSelector(object):
|
||||||
reorder_chunk_info["chunk_size"] = chunk_size
|
reorder_chunk_info["chunk_size"] = chunk_size
|
||||||
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
|
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
|
||||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||||
chunk_region_dict['reorder_node_list'], cur_chunk_infos
|
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
||||||
)[0]
|
)[0]
|
||||||
cur_chunk_max_mem = max(
|
cur_chunk_max_mem = max(
|
||||||
cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1]
|
cur_mem_peak[
|
||||||
|
reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1]
|
||||||
|
+ 1
|
||||||
|
]
|
||||||
)
|
)
|
||||||
# search exact size
|
# search exact size
|
||||||
chunk_info = chunk_region_dict["chunk_info"]
|
chunk_info = chunk_region_dict["chunk_info"]
|
||||||
|
@ -1534,13 +1537,13 @@ class ChunkSelector(object):
|
||||||
gap = 4
|
gap = 4
|
||||||
else:
|
else:
|
||||||
gap = 1
|
gap = 1
|
||||||
chunk_info = chunk_region_dict['reorder_chunk_info']
|
chunk_info = chunk_region_dict["reorder_chunk_info"]
|
||||||
while r >= l + gap:
|
while r >= l + gap:
|
||||||
mid = int((l + r) / 2 + 0.5)
|
mid = int((l + r) / 2 + 0.5)
|
||||||
chunk_info["chunk_size"] = mid
|
chunk_info["chunk_size"] = mid
|
||||||
cur_chunk_infos = chunk_infos + [chunk_info]
|
cur_chunk_infos = chunk_infos + [chunk_info]
|
||||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||||
chunk_region_dict['reorder_node_list'], cur_chunk_infos
|
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
||||||
)[0]
|
)[0]
|
||||||
cur_chunk_max_mem = max(
|
cur_chunk_max_mem = max(
|
||||||
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
||||||
|
@ -2000,8 +2003,18 @@ def emit_code_with_chunk(
|
||||||
)
|
)
|
||||||
# ones like
|
# ones like
|
||||||
if "ones_like" in node.name:
|
if "ones_like" in node.name:
|
||||||
|
chunk_dim = chunk_search[region_idx]["node_chunk_dim"][
|
||||||
|
chunk_region_search.index_tracer.node_list[node_idx]
|
||||||
|
]["chunk_dim"]
|
||||||
|
if (
|
||||||
|
_get_node_shape(
|
||||||
|
chunk_region_search.index_tracer.node_list[node_idx]
|
||||||
|
)[chunk_dim]
|
||||||
|
== 1
|
||||||
|
):
|
||||||
|
continue
|
||||||
chunk_slice = _gen_chunk_slice_dim(
|
chunk_slice = _gen_chunk_slice_dim(
|
||||||
chunk_search[region_idx]["node_chunk_dim"][chunk_region_search.index_tracer.node_list[node_idx]]["chunk_dim"], "chunk_idx", _get_node_shape(node)
|
chunk_dim, "chunk_idx", _get_node_shape(node)
|
||||||
)
|
)
|
||||||
body[-1] = _replace_name(
|
body[-1] = _replace_name(
|
||||||
body[-1], node.args[0].name, node.args[0].name + chunk_slice
|
body[-1], node.args[0].name, node.args[0].name + chunk_slice
|
||||||
|
|
Loading…
Reference in New Issue