mirror of https://github.com/hpcaitech/ColossalAI

commit 7330d90745 (parent d9ca2f898d): add possible region search

chunk_codegen.py (116 changed lines)
@@ -356,7 +356,17 @@ class NodeIndexTracer(object):
                      "idx_to": [new_trace[i] for i in dim_to],
                      "dim_to": dim_to}
         self.idx_view_list.append(view_dict)
 
+    def _merge_equal_idx(self):
+        idx_equal = copy.deepcopy(self.idx_trace_equal)
+        idx_equal.reverse()
+        for idx in idx_equal:
+            merge_to = min(idx)
+            merge_from = max(idx)
+            for trace in self.idx_trace_list:
+                if merge_from in trace['idx']:
+                    trace['idx'] = [merge_to if i == merge_from else i for i in trace['idx']]
+
     def trace_node_idx(self):
         for idx, node in enumerate(self.nodes_list):
             if node.op == 'placeholder':
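The new `_merge_equal_idx` pass folds every pair of dimension-index ids recorded in `idx_trace_equal` into the smaller id across all node traces. A minimal standalone sketch of that merge on toy data (only the 'idx' field of a trace is modelled here; the real traces carry more state):

    # Toy data: each trace lists the index id carried by each tensor dimension.
    idx_trace_list = [
        {"idx": [0, 1, 2]},        # some traced node
        {"idx": [0, 3]},           # a reshaped node that introduced id 3
    ]
    idx_trace_equal = [(1, 3)]     # id 3 was later found to equal id 1

    # Pairs are applied newest-first, always folding the larger id into the smaller one.
    for pair in reversed(idx_trace_equal):
        merge_to, merge_from = min(pair), max(pair)
        for trace in idx_trace_list:
            trace["idx"] = [merge_to if i == merge_from else i for i in trace["idx"]]

    print(idx_trace_list)          # [{'idx': [0, 1, 2]}, {'idx': [0, 1]}]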
@@ -396,6 +406,7 @@ class NodeIndexTracer(object):
                 continue
             else:
                 raise NotImplementedError(node.op, "op not implemented yet!")
+        self._merge_equal_idx()
 
 
 class MemoryEstimator(object):
@@ -433,6 +444,8 @@ class MemoryEstimator(object):
         for i in range(len(out_node)):
             if out_node[i][0] > 0:
                 delete_node.append(out_node[i][1][0])
+            elif nodes_to_delete[i].op == 'placeholder':
+                delete_node.append(nodes_to_delete[i].name)
         return delete_size, delete_node
 
     def _get_delete_node_size(self, user, user_to_last_uses):
@@ -516,8 +529,9 @@ class MemoryEstimator(object):
         active_node_list_log = []
         not_contiguous_list = []
         node_list = list(gm.graph.nodes)
-        user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
-        _delete_free_var_from_last_use(user_to_last_uses)
+        user_to_last_uses = self._get_last_usr(node_list)
+        user_to_last_uses_no_free_var = self._get_last_usr(node_list)
+        _delete_free_var_from_last_use(user_to_last_uses_no_free_var)
 
         use_chunk = all(i is not None for i in [start_nodes, end_nodes, chunk_dims, chunk_sizes])
         chunk_within = False
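Splitting the bookkeeping into `user_to_last_uses` and `user_to_last_uses_no_free_var` follows the usual torch.fx liveness pattern: each node is mapped to the values whose last use it is, and graph inputs (placeholders, i.e. free variables) are pruned so their memory is never counted as released mid-graph. A rough sketch of that pattern, assuming this is what `_get_last_usr` and `_delete_free_var_from_last_use` implement:

    # Sketch only; helper names and exact behaviour are assumptions.
    def get_last_uses(node_list):
        node_to_last_use = {}      # value -> the node that consumes it last
        user_to_last_uses = {}     # node  -> values that can be freed after it runs
        for node in reversed(node_list):
            for inp in node.all_input_nodes:
                if inp not in node_to_last_use:
                    node_to_last_use[inp] = node
                    user_to_last_uses.setdefault(node, []).append(inp)
        return user_to_last_uses

    def drop_free_vars(user_to_last_uses):
        # Graph inputs live for the whole forward pass, so they should not be
        # reported as memory released by whichever node happens to use them last.
        for user in user_to_last_uses:
            user_to_last_uses[user] = [n for n in user_to_last_uses[user] if n.op != 'placeholder']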
@@ -535,6 +549,7 @@ class MemoryEstimator(object):
             if node.op == 'placeholder':
                 act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024 ** 2)
                 act_memory_peak_log.append(act_memory)
+                active_node_list.append(node.name)
             # skip output
             elif node.op == 'output':
                 continue
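The running `act_memory` is kept in MB (hence the `/ (1024 ** 2)`), and inside a chunk region each activation is charged at `chunk_ratio` of its full size. Assuming `chunk_ratio` is the fraction of the chunked dimension materialised per step (it is defined outside this hunk), a quick worked example:

    # Worked example with assumed numbers.
    numel = 1 * 512 * 1024                 # float32 placeholder of shape (1, 512, 1024)
    full_mb = numel * 4 / (1024 ** 2)      # 4 bytes per element -> 2.0 MB
    chunk_ratio = 64 / 512                 # chunk size 64 along a dimension of size 512
    print(full_mb * chunk_ratio)           # 0.25 MB charged while inside the chunk region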
@@ -549,10 +564,10 @@ class MemoryEstimator(object):
                 act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
                 if chunk_within:
                     act_memory -= self._get_chunk_delete_node_size(
-                        node, user_to_last_uses, chunk_ratio, node_list,
+                        node, user_to_last_uses_no_free_var, chunk_ratio, node_list,
                         start_nodes[chunk_region_idx], end_nodes[chunk_region_idx]) / (1024 ** 2)
                 else:
-                    act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
+                    act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var) / (1024 ** 2)
 
                 # log active node
                 self._add_active_node(node, active_node_list)
@@ -572,8 +587,92 @@ class MemoryEstimator(object):
             self._print_mem_log(act_memory_peak_log, node_list, "peak")
             self._print_mem_log(act_memory_after_node_log, node_list, "after")
 
-        param_memory = parameter_size(gm)
-        return act_memory + param_memory, param_memory
+        # param_memory = parameter_size(gm)
+        # all_memory = act_memory + param_memory
+        return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
+
+
+class ChunkRegionSearch(object):
+    def __init__(self, gm) -> None:
+        self.gm = gm
+        self.node_list = list(gm.graph.nodes)
+        self.memory_estimator = MemoryEstimator()
+        self.index_tracer = NodeIndexTracer(gm)
+        self.index_tracer.trace_node_idx()
+
+    def _find_peak_node(self, mem_peak):
+        max_value = max(mem_peak)
+        max_idx = [mem_peak.index(max_value)]
+        return max_idx
+
+    def _get_free_var(self):
+        free_var_idx = []
+        for idx, n in enumerate(self.node_list):
+            if n.op == 'placeholder':
+                free_var_idx.append(idx)
+        return free_var_idx
+
+    def _get_min_free_var(self, active_node_list, free_vars):
+        min_len = 999
+        for idx, n in enumerate(active_node_list):
+            if idx in free_vars:
+                continue
+            if len(n) < min_len:
+                min_len = len(n)
+        return min_len
+
+    def _search_max_chunk_region(self, active_node, peak_node):
+        free_vars = self._get_free_var()
+        min_var = self._get_min_free_var(active_node, free_vars)
+
+        # from peak_node to free_var
+        chunk_region_start = None
+        for i in range(peak_node, -1, -1):
+            if len(active_node[i]) == min_var:
+                chunk_region_start = i + 1
+                break
+            if i in free_vars or i == 0:
+                raise RuntimeError()
+        # from peak_node to len-2
+        chunk_region_end = None
+        for i in range(peak_node, len(active_node) - 1):
+            if len(active_node[i]) == min_var:
+                chunk_region_end = i - 1
+                break
+            if i in free_vars or i == 0:
+                raise RuntimeError()
+        return chunk_region_start, chunk_region_end
+
+    def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
+        possible_chunk_region = []
+        for before_idx in range(max_chunk_region[0], peak_node):
+            for after_idx in range(peak_node, max_chunk_region[1]):
+                # skip non compute nodes
+                if any(op in ['placeholder', 'get_attr', 'output'] for op in
+                       [self.node_list[before_idx].op, self.node_list[after_idx].op]):
+                    continue
+                if any(any(i in name for i in ['getitem', 'getattr']) for name in
+                       [self.node_list[before_idx].name, self.node_list[after_idx].name]):
+                    continue
+
+                # select free dim
+                before_trace = self.index_tracer.idx_trace_list[before_idx]
+                after_trace = self.index_tracer.idx_trace_list[after_idx]
+                free_dim = []
+                for i in range(min(len(before_trace['idx']), len(after_trace['idx']))):
+                    if (before_trace['idx'][i] == after_trace['idx'][i] and
+                            before_trace['idx'][i] not in before_trace['compute'] and
+                            after_trace['idx'][i] not in after_trace['compute']):
+                        free_dim.append(i)
+                possible_chunk_region.append({'region': (before_idx, after_idx), 'dim': free_dim})
+        return possible_chunk_region
+
+    def search_region(self):
+        mem_peak, mem_after, active_node = self.memory_estimator.estimate_chunk_inference_mem(self.gm)
+        peak_nodes = self._find_peak_node(mem_peak)
+        for idx, peak_node in enumerate(peak_nodes):
+            max_chunk_region = self._search_max_chunk_region(active_node, peak_node)
+            possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
 
 
 def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
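A toy walk-through of what `ChunkRegionSearch` computes, with illustrative data rather than the library's structures: the peak node is the step with the highest estimated memory, the maximal region is widened around it until the live-tensor set shrinks back to its minimum outside the graph inputs, and a dimension is only a chunk candidate when both region endpoints carry the same index id on it and neither lists that id under 'compute':

    # Illustrative only: per-node live-tensor sets and peak memory (MB).
    active_node = [["x"],
                   ["x", "a"],
                   ["x", "a", "b"],
                   ["x", "a", "b", "c"],   # memory peak
                   ["x", "a", "c"],
                   ["x", "c"],
                   ["x", "out"]]
    mem_peak = [2, 6, 14, 30, 18, 8, 4]
    free_vars = [0]                                    # node 0 is a placeholder

    peak_node = mem_peak.index(max(mem_peak))          # -> 3
    min_var = min(len(s) for i, s in enumerate(active_node) if i not in free_vars)

    start = next(i + 1 for i in range(peak_node, -1, -1) if len(active_node[i]) == min_var)
    end = next(i - 1 for i in range(peak_node, len(active_node) - 1) if len(active_node[i]) == min_var)
    print(start, end)                                  # -> 2 4, the widest candidate region

    # Free-dim selection between two candidate endpoints of that region.
    before_trace = {"idx": [0, 1, 2], "compute": [2]}
    after_trace = {"idx": [0, 1, 2], "compute": [2]}
    free_dim = [i for i in range(min(len(before_trace["idx"]), len(after_trace["idx"])))
                if before_trace["idx"][i] == after_trace["idx"][i]
                and before_trace["idx"][i] not in before_trace["compute"]
                and after_trace["idx"][i] not in after_trace["compute"]]
    print(free_dim)                                    # -> [0, 1]; dim 2 is computed over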
@@ -696,6 +795,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
 
     node_index_tracer = NodeIndexTracer(meta_graph)
     node_index_tracer.trace_node_idx()
 
+    chunk_region_search = ChunkRegionSearch(meta_graph)
+    chunk_region_search.search_region()
+
     # find the input and output var names for each offload region
     for idx, (start, end) in enumerate(chunk_regions):
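Downstream, each selected (region, dim) pair is meant to become a loop that processes the chunked dimension slice by slice. The emitted code goes through `_gen_chunk_slice_dim` and the emit helpers, whose bodies are not part of this diff; the sketch below only illustrates the execution pattern such a loop amounts to:

    import torch

    def run_chunked(fn, x, chunk_dim, chunk_size):
        # Run fn on slices of x along chunk_dim and stitch the results back
        # together, so only one slice's intermediates are alive at a time.
        outs = []
        for start in range(0, x.shape[chunk_dim], chunk_size):
            sl = [slice(None)] * x.dim()
            sl[chunk_dim] = slice(start, start + chunk_size)
            outs.append(fn(x[tuple(sl)]))
        return torch.cat(outs, dim=chunk_dim)

    # e.g. run_chunked(lambda t: t.softmax(-1), torch.randn(4, 4096, 64), chunk_dim=1, chunk_size=512)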