mirror of https://github.com/hpcaitech/ColossalAI

code style
parent 8f5a0edfab
commit 378a49dc6c

chunk_codegen.py | 101

--- a/chunk_codegen.py
+++ b/chunk_codegen.py
@@ -982,24 +982,24 @@ class IndexTracer(object):
 
         # reassgin reshape size, some size may have changed due to chunk
         chunk_info = self._reassgin_reshape_size(chunk_info)
 
         return chunk_info
 
     def _reassgin_reshape_size(self, chunk_info):
-        chunk_region = chunk_info['region']
+        chunk_region = chunk_info["region"]
         reshape_size = {}
-        for node in self.node_list[chunk_region[0]: chunk_region[1] + 1]:
-            if any(i in node.name for i in ['reshape', 'view']):
+        for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]:
+            if any(i in node.name for i in ["reshape", "view"]):
                 reshape_args = node.args[1:]
                 reshape_log = self.idx_view_list[node]
-                chunk_dim = chunk_info['node_chunk_dim'][node]['chunk_dim']
+                chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
                 reshape_size[node.name] = {}
                 for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
-                    if reshape_arg_dim in reshape_log['dim_to']:
+                    if reshape_arg_dim in reshape_log["dim_to"]:
                         continue
                     if reshape_arg_dim == chunk_dim:
                         reshape_size[node.name][reshape_arg.name] = "chunk_size"
-        chunk_info['reshape_size'] = reshape_size
+        chunk_info["reshape_size"] = reshape_size
         return chunk_info
 
     def _get_reorder_map(self, chunk_info):
@@ -1213,7 +1213,7 @@ class MemoryEstimator(object):
         if node not in chunk_node_dim:
             return 1.0
         node_shape = _get_node_shape(node)
-        chunk_dim = chunk_node_dim[node]['chunk_dim']
+        chunk_dim = chunk_node_dim[node]["chunk_dim"]
         if chunk_dim is None:
             return 1.0
         else:
@@ -1381,7 +1381,9 @@ class MemoryEstimator(object):
             # self._print_mem_log(act_memory_peak_log, node_list, "peak")
             # self._print_mem_log(act_memory_after_node_log, node_list, "after")
             self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
-            self._print_compute_op_mem_log(act_memory_after_node_log, node_list, "after")
+            self._print_compute_op_mem_log(
+                act_memory_after_node_log, node_list, "after"
+            )
 
         # param_memory = parameter_size(gm)
         # all_memory = act_memory + param_memory
@@ -1389,30 +1391,41 @@ class MemoryEstimator(object):
 
 
 class ChunkSelector(object):
-    def __init__(self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge):
+    def __init__(
+        self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge
+    ):
         self.index_tracer = index_tracer
         self.memory_estimator = memory_estimator
-        assert stratge in ['min_memory', 'fit_memory']
+        assert stratge in ["min_memory", "fit_memory"]
         self.stratge = stratge
         self.max_memory = 600  # MB
 
-    def _select_best_chunk_region(self, possible_chunk_regions,
-                                  chunk_infos, peak_node, max_chunk_region, mem_peak):
-        if self.stratge == 'min_memory':
-            best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos)
-        elif self.stratge == 'fit_memory':
-            best_region = self._select_fit_memory_chunk_region(
-                possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak)
+    def _select_best_chunk_region(
+        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
+    ):
+        if self.stratge == "min_memory":
+            best_region = self._select_min_memory_chunk_region(
+                possible_chunk_regions, chunk_infos
+            )
+        elif self.stratge == "fit_memory":
+            best_region = self._select_fit_memory_chunk_region(
+                possible_chunk_regions,
+                chunk_infos,
+                peak_node,
+                max_chunk_region,
+                mem_peak,
+            )
         else:
             raise RuntimeError()
         return best_region
 
-    def _select_fit_memory_chunk_region(self, possible_chunk_regions,
-                                        chunk_infos, peak_node, max_chunk_region, mem_peak):
+    def _select_fit_memory_chunk_region(
+        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
+    ):
         # stop chunk if max memory satisfy memory limit
         if max(mem_peak) < self.max_memory:
             return None
 
         # remove illegal regions
         illegal_regions = []
         for i in possible_chunk_regions:
@@ -1421,38 +1434,45 @@ class ChunkSelector(object):
         for i in illegal_regions:
             if i in possible_chunk_regions:
                 possible_chunk_regions.remove(i)
 
         # get mem for chunk region
         regions_dict = []
         for region in possible_chunk_regions:
             cur_chunk_infos = chunk_infos + [region]
             cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
-                self.index_tracer.node_list, cur_chunk_infos)[0]
-            cur_chunk_region_peak = cur_mem_peak[max_chunk_region[0]: max_chunk_region[1] + 1]
+                self.index_tracer.node_list, cur_chunk_infos
+            )[0]
+            cur_chunk_region_peak = cur_mem_peak[
+                max_chunk_region[0] : max_chunk_region[1] + 1
+            ]
             cur_chunk_region_max_peak = max(cur_chunk_region_peak)
             if cur_chunk_region_max_peak < self.max_memory:
-                regions_dict.append({
-                    "chunk_info": region,
-                    "chunk_max_mem": cur_chunk_region_max_peak,
-                    "chunk_len": self._get_compute_node_num(region['region'][0], region['region'][1]),
-                })
+                regions_dict.append(
+                    {
+                        "chunk_info": region,
+                        "chunk_max_mem": cur_chunk_region_max_peak,
+                        "chunk_len": self._get_compute_node_num(
+                            region["region"][0], region["region"][1]
+                        ),
+                    }
+                )
         # no region found
         if len(regions_dict) == 0:
             return None
 
         # select the min chunk len
         chunk_len = [i["chunk_len"] for i in regions_dict]
         best_region_idx = chunk_len.index(min(chunk_len))
         best_region = regions_dict[best_region_idx]["chunk_info"]
         return best_region
 
     def _get_compute_node_num(self, start, end):
         count = 0
-        for i in self.index_tracer.node_list[start: end+1]:
+        for i in self.index_tracer.node_list[start : end + 1]:
             if _is_non_compute_node(i):
                 count += 1
         return count
 
     def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
         max_region_range = 0
         best_region = None
@@ -1490,7 +1510,9 @@ class ChunkRegionSearch(object):
         self.index_tracer = IndexTracer(list(gm.graph.nodes))
         self.index_tracer.trace_index()
         self.memory_estimator = MemoryEstimator(self.index_tracer)
-        self.chunk_selector = ChunkSelector(self.index_tracer, self.memory_estimator, stratge="fit_memory")
+        self.chunk_selector = ChunkSelector(
+            self.index_tracer, self.memory_estimator, stratge="fit_memory"
+        )
 
     def _find_peak_node(self, mem_peak):
         max_value = max(mem_peak)
@@ -1808,10 +1830,11 @@ def _replace_name(context, name_from, name_to):
 def _replace_reshape_size(context, node_name, reshape_size_dict):
     if node_name not in reshape_size_dict:
         return context
     for size_name, size_value in reshape_size_dict[node_name].items():
         context = context.replace(size_name, size_value)
     return context
 
+
 def emit_code_with_chunk(
     body,
     ckpt_func,
@@ -1883,7 +1906,9 @@ def emit_code_with_chunk(
                         body[-1] = _replace_name(
                             body[-1], input_node.name, input_node.name + chunk_slice
                         )
-            body[-1] = _replace_reshape_size(body[-1], node.name, chunk_search[region_idx]['reshape_size'])
+            body[-1] = _replace_reshape_size(
+                body[-1], node.name, chunk_search[region_idx]["reshape_size"]
+            )
             body[-1] = " " + body[-1]
             delete_unused_value_func(node, body, chunk_inputs_names)
         else:
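The ChunkSelector hunks above implement the "fit_memory" strategy: each candidate chunk region is scored by re-estimating peak activation memory with that region chunked, regions whose peak stays under self.max_memory (600 MB) are kept, and the shortest surviving region wins. Below is a minimal standalone sketch of that selection rule, assuming plain dicts in place of the real IndexTracer/MemoryEstimator objects; estimate_peak_mem and the toy numbers are hypothetical stand-ins, and the region length here is a simple node count rather than the _get_compute_node_num / _is_non_compute_node filter used in the diff.

from typing import Callable, Dict, List, Optional


def select_fit_memory_region(
    candidates: List[Dict],
    estimate_peak_mem: Callable[[Dict], float],
    max_memory_mb: float = 600,
) -> Optional[Dict]:
    # Keep candidates whose estimated peak memory fits the budget, then return
    # the one spanning the fewest nodes (mirrors the "min chunk_len" pick above).
    fitting = []
    for region in candidates:
        peak = estimate_peak_mem(region)  # peak memory (MB) if this region is chunked
        if peak < max_memory_mb:
            start, end = region["region"]
            fitting.append(
                {"chunk_info": region, "chunk_max_mem": peak, "chunk_len": end - start + 1}
            )
    if not fitting:
        return None  # nothing fits: the caller stops chunking or falls back
    return min(fitting, key=lambda r: r["chunk_len"])["chunk_info"]


# Toy usage with made-up peak-memory estimates (MB):
regions = [{"region": (10, 40)}, {"region": (18, 30)}, {"region": (5, 50)}]
peaks = {(10, 40): 550, (18, 30): 580, (5, 50): 700}
print(select_fit_memory_region(regions, lambda r: peaks[r["region"]]))
# {'region': (18, 30)}: the shortest region that keeps peak memory under 600 MB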
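The _reassgin_reshape_size and _replace_reshape_size hunks handle reshape/view nodes: for each such node, the size argument that lands on the chunked dimension is recorded as "chunk_size", and when code is emitted that symbol is swapped into the generated line by plain string replacement. A small self-contained illustration of that replacement step, with made-up node and argument names:

def replace_reshape_size(context: str, node_name: str, reshape_size_dict: dict) -> str:
    # Same shape as _replace_reshape_size in the diff: substitute every recorded
    # size symbol in the emitted source line with its replacement expression.
    if node_name not in reshape_size_dict:
        return context
    for size_name, size_value in reshape_size_dict[node_name].items():
        context = context.replace(size_name, size_value)
    return context


# Emitted line for a hypothetical `view` node whose dim 1 is the chunked dimension;
# size_2 is the size argument recorded for that dimension.
line = "view_3 = add_2.view(size_1, size_2, 64)"
reshape_size = {"view_3": {"size_2": "chunk_size"}}
print(replace_reshape_size(line, "view_3", reshape_size))
# view_3 = add_2.view(size_1, chunk_size, 64)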