mirror of https://github.com/hpcaitech/ColossalAI
finish region search loop
parent
7330d90745
commit
3b7d671206
152
chunk_codegen.py
152
chunk_codegen.py
|
@ -21,7 +21,7 @@ class NodeIndexTracer(object):
|
|||
def __init__(self, gm) -> None:
|
||||
self.gm = gm
|
||||
self.nodes_list = list(gm.graph.nodes)
|
||||
self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))]
|
||||
self.idx_trace_list = [{'idx': [], 'compute': {}} for _ in range(len(self.nodes_list))]
|
||||
self.idx_trace_equal = []
|
||||
self.idx_view_list = []
|
||||
self.idx_count = -1
|
||||
|
@ -48,9 +48,12 @@ class NodeIndexTracer(object):
|
|||
"""
|
||||
_, compute_from = self._find_trace_from_node(node_from)
|
||||
idx_to, compute_to = self._find_trace_from_node(node_to)
|
||||
for i in compute_from:
|
||||
if i in idx_to and i not in compute_to:
|
||||
compute_to.append(i)
|
||||
for k, v in compute_from.items():
|
||||
if k in idx_to:
|
||||
if k in compute_to:
|
||||
compute_to[k].extend(v)
|
||||
else:
|
||||
compute_to[k] = copy.deepcopy(v)
|
||||
|
||||
def _mark_idx_equal(self, idx1, idx2):
|
||||
"""
|
||||
|
@ -77,7 +80,9 @@ class NodeIndexTracer(object):
|
|||
for d in dim:
|
||||
cur_idx = input_node_idx_trace[d]
|
||||
if cur_idx not in self.idx_trace_list[idx]['compute']:
|
||||
self.idx_trace_list[idx]['compute'].append(cur_idx)
|
||||
self.idx_trace_list[idx]['compute'][cur_idx] = [idx]
|
||||
else:
|
||||
self.idx_trace_list[idx]['compute'][cur_idx].append(idx)
|
||||
|
||||
def _find_trace_from_node(self, node):
|
||||
"""
|
||||
|
@ -357,6 +362,11 @@ class NodeIndexTracer(object):
|
|||
"dim_to": dim_to}
|
||||
self.idx_view_list.append(view_dict)
|
||||
|
||||
def _remove_duplicate_compute(self):
|
||||
for i in self.idx_trace_list:
|
||||
for k, v in i['compute'].items():
|
||||
i['compute'][k] = list(set(v))
|
||||
|
||||
def _merge_equal_idx(self):
|
||||
idx_equal = copy.deepcopy(self.idx_trace_equal)
|
||||
idx_equal.reverse()
|
||||
|
@ -406,6 +416,8 @@ class NodeIndexTracer(object):
|
|||
continue
|
||||
else:
|
||||
raise NotImplementedError(node.op, "op not implemented yet!")
|
||||
|
||||
self._remove_duplicate_compute()
|
||||
self._merge_equal_idx()
|
||||
|
||||
|
||||
|
@ -521,6 +533,19 @@ class MemoryEstimator(object):
|
|||
print("")
|
||||
print("\n")
|
||||
|
||||
def _print_compute_op_mem_log(self, log, nodes, title=None):
|
||||
if title:
|
||||
print(title)
|
||||
for idx, (l, n) in enumerate(zip(log, nodes)):
|
||||
if n.op in ['placeholder', 'get_attr', 'output']:
|
||||
continue
|
||||
if any(i in n.name for i in ['getitem', 'getattr']):
|
||||
continue
|
||||
print("%s:%.2f \t" % (n.name, l), end='')
|
||||
if (idx + 1) % 3 == 0:
|
||||
print("")
|
||||
print("\n")
|
||||
|
||||
def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes=None, end_nodes=None, chunk_dims=None, chunk_sizes=None):
|
||||
act_memory = 0.0
|
||||
act_memory_peak_log = []
|
||||
|
@ -584,8 +609,10 @@ class MemoryEstimator(object):
|
|||
active_node_list_log.append(copy.deepcopy(active_node_list))
|
||||
|
||||
print("with chunk" if use_chunk else "without chunk")
|
||||
self._print_mem_log(act_memory_peak_log, node_list, "peak")
|
||||
self._print_mem_log(act_memory_after_node_log, node_list, "after")
|
||||
# self._print_mem_log(act_memory_peak_log, node_list, "peak")
|
||||
# self._print_mem_log(act_memory_after_node_log, node_list, "after")
|
||||
self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
|
||||
self._print_compute_op_mem_log(act_memory_after_node_log, node_list, "after")
|
||||
|
||||
# param_memory = parameter_size(gm)
|
||||
# all_memory = act_memory + param_memory
|
||||
|
@ -602,7 +629,7 @@ class ChunkRegionSearch(object):
|
|||
|
||||
def _find_peak_node(self, mem_peak):
|
||||
max_value = max(mem_peak)
|
||||
max_idx = [mem_peak.index(max_value)]
|
||||
max_idx = mem_peak.index(max_value)
|
||||
return max_idx
|
||||
|
||||
def _get_free_var(self):
|
||||
|
@ -635,18 +662,35 @@ class ChunkRegionSearch(object):
|
|||
raise RuntimeError()
|
||||
# from peak_node to len-2
|
||||
chunk_region_end = None
|
||||
for i in range(peak_node, len(active_node) - 1):
|
||||
for i in range(peak_node, len(active_node)):
|
||||
if len(active_node[i]) == min_var:
|
||||
chunk_region_end = i - 1
|
||||
chunk_region_end = i
|
||||
break
|
||||
if i in free_vars or i == 0:
|
||||
raise RuntimeError()
|
||||
return chunk_region_start, chunk_region_end
|
||||
|
||||
def _not_compute(self, trace, chunk_range, dim_idx):
|
||||
if trace['idx'][dim_idx] not in trace['compute']:
|
||||
return True
|
||||
if trace['idx'][dim_idx] in trace['compute'] and \
|
||||
all(i < chunk_range[0] or i > chunk_range[1] for i in trace['compute'][trace['idx'][dim_idx]]):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
|
||||
possible_chunk_region = []
|
||||
output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
|
||||
input_trace = []
|
||||
for i, n in enumerate(self.node_list):
|
||||
if len(n.args) > 0 and n.op != 'output':
|
||||
input_idx = _find_idx_by_name(n.args[0].name, self.node_list)
|
||||
input_trace.append(output_trace[input_idx])
|
||||
else:
|
||||
input_trace.append(None)
|
||||
|
||||
for before_idx in range(max_chunk_region[0], peak_node):
|
||||
for after_idx in range(peak_node, max_chunk_region[1]):
|
||||
for after_idx in range(peak_node, max_chunk_region[1] + 1):
|
||||
# skip non compute nodes
|
||||
if any(op in ['placeholder', 'get_attr', 'output'] for op in
|
||||
[self.node_list[before_idx].op, self.node_list[after_idx].op]):
|
||||
|
@ -656,23 +700,59 @@ class ChunkRegionSearch(object):
|
|||
continue
|
||||
|
||||
# select free dim
|
||||
before_trace = self.index_tracer.idx_trace_list[before_idx]
|
||||
after_trace = self.index_tracer.idx_trace_list[after_idx]
|
||||
before_trace = input_trace[before_idx]
|
||||
after_trace = output_trace[after_idx]
|
||||
free_dim = []
|
||||
for i in range(min(len(before_trace['idx']), len(after_trace['idx']))):
|
||||
if (before_trace['idx'][i] == after_trace['idx'][i] and
|
||||
before_trace['idx'][i] not in before_trace['compute'] and
|
||||
after_trace['idx'][i] not in after_trace['compute']):
|
||||
self._not_compute(before_trace, (before_idx, after_idx), i) and
|
||||
self._not_compute(after_trace, (before_idx, after_idx), i) and
|
||||
self.node_list[after_idx].meta['tensor_meta'].shape[i] != 1):
|
||||
free_dim.append(i)
|
||||
possible_chunk_region.append({'region': (before_idx, after_idx), 'dim': free_dim})
|
||||
return possible_chunk_region
|
||||
|
||||
def _search_best_chunk_region(self, possible_chunk_regions):
|
||||
max_region_range = 0
|
||||
best_regions = None
|
||||
for i in possible_chunk_regions:
|
||||
if i['region'][1] - i['region'][0] > max_region_range:
|
||||
best_regions = i
|
||||
max_region_range = i['region'][1] - i['region'][0]
|
||||
return best_regions
|
||||
|
||||
def _step_search(self, peak_node, active_node):
|
||||
max_chunk_region = self._search_max_chunk_region(active_node, peak_node)
|
||||
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
|
||||
best_chunk_region = self._search_best_chunk_region(possible_chunk_regions)
|
||||
return best_chunk_region
|
||||
|
||||
def _stop_search(self, init_mem_peak, mem_peak):
|
||||
sorted_init_mem_peak = sorted(init_mem_peak)
|
||||
if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def search_region(self):
|
||||
mem_peak, mem_after, active_node = self.memory_estimator.estimate_chunk_inference_mem(self.gm)
|
||||
peak_nodes = self._find_peak_node(mem_peak)
|
||||
for idx, peak_node in enumerate(peak_nodes):
|
||||
max_chunk_region = self._search_max_chunk_region(active_node, peak_node)
|
||||
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
|
||||
chunk_regions = []
|
||||
init_mem_peak, _, active_node = self.memory_estimator.estimate_chunk_inference_mem(self.gm)
|
||||
mem_peak = init_mem_peak
|
||||
|
||||
while True:
|
||||
peak_node = self._find_peak_node(mem_peak)
|
||||
chunk_region = self._step_search(peak_node, active_node)
|
||||
if chunk_region is None or len(chunk_region['dim']) == 0:
|
||||
break
|
||||
|
||||
chunk_regions.append(chunk_region)
|
||||
mem_peak, _, active_node = self.memory_estimator.estimate_chunk_inference_mem(
|
||||
self.gm, [i['region'][0] for i in chunk_regions],
|
||||
[i['region'][1] for i in chunk_regions], [i['dim'][0] for i in chunk_regions], [1] * len(chunk_regions))
|
||||
|
||||
if self._stop_search(init_mem_peak, mem_peak):
|
||||
break
|
||||
|
||||
return chunk_regions
|
||||
|
||||
|
||||
def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
|
||||
|
@ -696,11 +776,12 @@ def _get_first_non_single_dim(shape):
|
|||
raise RuntimeError("can not get first non single dim for shape", shape)
|
||||
|
||||
|
||||
def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2):
|
||||
def _gen_loop_start(chunk_input_meta, chunk_output, chunk_dim, chunk_size=2):
|
||||
if len(chunk_input_meta) == 1:
|
||||
node = chunk_input_meta[0]
|
||||
node_shape = node.meta['tensor_meta'].shape
|
||||
chunk_dim = _get_first_non_single_dim(node_shape)
|
||||
free_shape = [node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape))]
|
||||
chunk_dim = _get_first_non_single_dim(free_shape)
|
||||
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
|
||||
out_shape = str(list(chunk_output.meta['tensor_meta'].shape))
|
||||
|
||||
|
@ -713,12 +794,13 @@ def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2):
|
|||
return context
|
||||
|
||||
|
||||
def _gen_loop_end(chunk_outputs, chunk_inputs, node_list):
|
||||
def _gen_loop_end(chunk_outputs, chunk_inputs, node_list, chunk_dim):
|
||||
chunk_inputs_name = chunk_inputs[0].name
|
||||
chunk_outputs_name = chunk_outputs.name
|
||||
chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list)
|
||||
chunk_output_shape = chunk_outputs.meta['tensor_meta'].shape
|
||||
chunk_dim = _get_first_non_single_dim(chunk_output_shape)
|
||||
free_shape = [chunk_output_shape[i] if i in chunk_dim else 1 for i in range(len(chunk_output_shape))]
|
||||
chunk_dim = _get_first_non_single_dim(free_shape)
|
||||
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", chunk_output_shape)
|
||||
context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name)
|
||||
|
||||
|
@ -780,7 +862,11 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
|||
"""
|
||||
|
||||
# find the offload regions
|
||||
chunk_regions = [(58, 62)]
|
||||
chunk_region_search = ChunkRegionSearch(meta_graph)
|
||||
chunk_search = chunk_region_search.search_region()
|
||||
chunk_regions = [i['region'] for i in chunk_search]
|
||||
chunk_dims = [i['dim'] for i in chunk_search]
|
||||
|
||||
chunk_starts = [item[0] for item in chunk_regions]
|
||||
chunk_ends = [item[1] for item in chunk_regions]
|
||||
chunk_inputs = []
|
||||
|
@ -789,16 +875,6 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
|||
|
||||
node_list = list(nodes)
|
||||
|
||||
memory_estimator = MemoryEstimator()
|
||||
memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
|
||||
memory_estimator.estimate_chunk_inference_mem(meta_graph)
|
||||
|
||||
node_index_tracer = NodeIndexTracer(meta_graph)
|
||||
node_index_tracer.trace_node_idx()
|
||||
|
||||
chunk_region_search = ChunkRegionSearch(meta_graph)
|
||||
chunk_region_search.search_region()
|
||||
|
||||
# find the input and output var names for each offload region
|
||||
for idx, (start, end) in enumerate(chunk_regions):
|
||||
offload_node_list = node_list[start:end + 1]
|
||||
|
@ -824,13 +900,13 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
|||
|
||||
# add for loop
|
||||
chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]]
|
||||
body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]]))
|
||||
body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]], chunk_dims[region_idx]))
|
||||
|
||||
if within_chunk_region:
|
||||
emit_node_func(node, body)
|
||||
# replace input var with chunk var
|
||||
if node_idx in chunk_starts:
|
||||
body[-1] = body[-1].replace("("+ chunk_inputs[region_idx][0].name +")", '(chunk_tensor)')
|
||||
body[-1] = body[-1].replace(chunk_inputs[region_idx][0].name, 'chunk_tensor')
|
||||
body[-1] = ' ' + body[-1]
|
||||
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||
|
||||
|
@ -840,7 +916,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
|||
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||
|
||||
if node_idx in chunk_ends:
|
||||
body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list))
|
||||
body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list, chunk_dims[region_idx]))
|
||||
within_chunk_region = False
|
||||
region_idx += 1
|
||||
|
||||
|
|
|
@ -45,8 +45,8 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
|
|||
with torch.no_grad():
|
||||
non_fx_out = model(node, pair)
|
||||
fx_out = gm(node, pair)
|
||||
assert torch.equal(non_fx_out[0], fx_out[0]), "fx_out doesn't comply with original output"
|
||||
assert torch.equal(non_fx_out[1], fx_out[1]), "fx_out doesn't comply with original output"
|
||||
assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-6), "fx_out doesn't comply with original output"
|
||||
assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-6), "fx_out doesn't comply with original output"
|
||||
|
||||
# test barckward
|
||||
# loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
|
||||
|
|
Loading…
Reference in New Issue