diff --git a/chunk_codegen.py b/chunk_codegen.py index ba83f7fec..47cda0f8e 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -21,7 +21,7 @@ class NodeIndexTracer(object): def __init__(self, gm) -> None: self.gm = gm self.nodes_list = list(gm.graph.nodes) - self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))] + self.idx_trace_list = [{'idx': [], 'compute': {}} for _ in range(len(self.nodes_list))] self.idx_trace_equal = [] self.idx_view_list = [] self.idx_count = -1 @@ -48,9 +48,12 @@ class NodeIndexTracer(object): """ _, compute_from = self._find_trace_from_node(node_from) idx_to, compute_to = self._find_trace_from_node(node_to) - for i in compute_from: - if i in idx_to and i not in compute_to: - compute_to.append(i) + for k, v in compute_from.items(): + if k in idx_to: + if k in compute_to: + compute_to[k].extend(v) + else: + compute_to[k] = copy.deepcopy(v) def _mark_idx_equal(self, idx1, idx2): """ @@ -77,7 +80,9 @@ class NodeIndexTracer(object): for d in dim: cur_idx = input_node_idx_trace[d] if cur_idx not in self.idx_trace_list[idx]['compute']: - self.idx_trace_list[idx]['compute'].append(cur_idx) + self.idx_trace_list[idx]['compute'][cur_idx] = [idx] + else: + self.idx_trace_list[idx]['compute'][cur_idx].append(idx) def _find_trace_from_node(self, node): """ @@ -357,6 +362,11 @@ class NodeIndexTracer(object): "dim_to": dim_to} self.idx_view_list.append(view_dict) + def _remove_duplicate_compute(self): + for i in self.idx_trace_list: + for k, v in i['compute'].items(): + i['compute'][k] = list(set(v)) + def _merge_equal_idx(self): idx_equal = copy.deepcopy(self.idx_trace_equal) idx_equal.reverse() @@ -406,6 +416,8 @@ class NodeIndexTracer(object): continue else: raise NotImplementedError(node.op, "op not implemented yet!") + + self._remove_duplicate_compute() self._merge_equal_idx() @@ -521,6 +533,19 @@ class MemoryEstimator(object): print("") print("\n") + def _print_compute_op_mem_log(self, log, nodes, title=None): + if title: + print(title) + for idx, (l, n) in enumerate(zip(log, nodes)): + if n.op in ['placeholder', 'get_attr', 'output']: + continue + if any(i in n.name for i in ['getitem', 'getattr']): + continue + print("%s:%.2f \t" % (n.name, l), end='') + if (idx + 1) % 3 == 0: + print("") + print("\n") + def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes=None, end_nodes=None, chunk_dims=None, chunk_sizes=None): act_memory = 0.0 act_memory_peak_log = [] @@ -584,8 +609,10 @@ class MemoryEstimator(object): active_node_list_log.append(copy.deepcopy(active_node_list)) print("with chunk" if use_chunk else "without chunk") - self._print_mem_log(act_memory_peak_log, node_list, "peak") - self._print_mem_log(act_memory_after_node_log, node_list, "after") + # self._print_mem_log(act_memory_peak_log, node_list, "peak") + # self._print_mem_log(act_memory_after_node_log, node_list, "after") + self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak") + self._print_compute_op_mem_log(act_memory_after_node_log, node_list, "after") # param_memory = parameter_size(gm) # all_memory = act_memory + param_memory @@ -602,7 +629,7 @@ class ChunkRegionSearch(object): def _find_peak_node(self, mem_peak): max_value = max(mem_peak) - max_idx = [mem_peak.index(max_value)] + max_idx = mem_peak.index(max_value) return max_idx def _get_free_var(self): @@ -635,18 +662,35 @@ class ChunkRegionSearch(object): raise RuntimeError() # from peak_node to len-2 chunk_region_end = None - for i in range(peak_node, len(active_node) - 1): + for i in range(peak_node, len(active_node)): if len(active_node[i]) == min_var: - chunk_region_end = i - 1 + chunk_region_end = i break if i in free_vars or i == 0: raise RuntimeError() return chunk_region_start, chunk_region_end + def _not_compute(self, trace, chunk_range, dim_idx): + if trace['idx'][dim_idx] not in trace['compute']: + return True + if trace['idx'][dim_idx] in trace['compute'] and \ + all(i < chunk_range[0] or i > chunk_range[1] for i in trace['compute'][trace['idx'][dim_idx]]): + return True + return False + def _search_possible_chunk_regions(self, max_chunk_region, peak_node): possible_chunk_region = [] + output_trace = copy.deepcopy(self.index_tracer.idx_trace_list) + input_trace = [] + for i, n in enumerate(self.node_list): + if len(n.args) > 0 and n.op != 'output': + input_idx = _find_idx_by_name(n.args[0].name, self.node_list) + input_trace.append(output_trace[input_idx]) + else: + input_trace.append(None) + for before_idx in range(max_chunk_region[0], peak_node): - for after_idx in range(peak_node, max_chunk_region[1]): + for after_idx in range(peak_node, max_chunk_region[1] + 1): # skip non compute nodes if any(op in ['placeholder', 'get_attr', 'output'] for op in [self.node_list[before_idx].op, self.node_list[after_idx].op]): @@ -656,23 +700,59 @@ class ChunkRegionSearch(object): continue # select free dim - before_trace = self.index_tracer.idx_trace_list[before_idx] - after_trace = self.index_tracer.idx_trace_list[after_idx] + before_trace = input_trace[before_idx] + after_trace = output_trace[after_idx] free_dim = [] for i in range(min(len(before_trace['idx']), len(after_trace['idx']))): if (before_trace['idx'][i] == after_trace['idx'][i] and - before_trace['idx'][i] not in before_trace['compute'] and - after_trace['idx'][i] not in after_trace['compute']): + self._not_compute(before_trace, (before_idx, after_idx), i) and + self._not_compute(after_trace, (before_idx, after_idx), i) and + self.node_list[after_idx].meta['tensor_meta'].shape[i] != 1): free_dim.append(i) possible_chunk_region.append({'region': (before_idx, after_idx), 'dim': free_dim}) return possible_chunk_region + def _search_best_chunk_region(self, possible_chunk_regions): + max_region_range = 0 + best_regions = None + for i in possible_chunk_regions: + if i['region'][1] - i['region'][0] > max_region_range: + best_regions = i + max_region_range = i['region'][1] - i['region'][0] + return best_regions + + def _step_search(self, peak_node, active_node): + max_chunk_region = self._search_max_chunk_region(active_node, peak_node) + possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node) + best_chunk_region = self._search_best_chunk_region(possible_chunk_regions) + return best_chunk_region + + def _stop_search(self, init_mem_peak, mem_peak): + sorted_init_mem_peak = sorted(init_mem_peak) + if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]: + return True + return False + def search_region(self): - mem_peak, mem_after, active_node = self.memory_estimator.estimate_chunk_inference_mem(self.gm) - peak_nodes = self._find_peak_node(mem_peak) - for idx, peak_node in enumerate(peak_nodes): - max_chunk_region = self._search_max_chunk_region(active_node, peak_node) - possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node) + chunk_regions = [] + init_mem_peak, _, active_node = self.memory_estimator.estimate_chunk_inference_mem(self.gm) + mem_peak = init_mem_peak + + while True: + peak_node = self._find_peak_node(mem_peak) + chunk_region = self._step_search(peak_node, active_node) + if chunk_region is None or len(chunk_region['dim']) == 0: + break + + chunk_regions.append(chunk_region) + mem_peak, _, active_node = self.memory_estimator.estimate_chunk_inference_mem( + self.gm, [i['region'][0] for i in chunk_regions], + [i['region'][1] for i in chunk_regions], [i['dim'][0] for i in chunk_regions], [1] * len(chunk_regions)) + + if self._stop_search(init_mem_peak, mem_peak): + break + + return chunk_regions def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): @@ -696,11 +776,12 @@ def _get_first_non_single_dim(shape): raise RuntimeError("can not get first non single dim for shape", shape) -def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2): +def _gen_loop_start(chunk_input_meta, chunk_output, chunk_dim, chunk_size=2): if len(chunk_input_meta) == 1: node = chunk_input_meta[0] node_shape = node.meta['tensor_meta'].shape - chunk_dim = _get_first_non_single_dim(node_shape) + free_shape = [node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape))] + chunk_dim = _get_first_non_single_dim(free_shape) chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape) out_shape = str(list(chunk_output.meta['tensor_meta'].shape)) @@ -713,12 +794,13 @@ def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2): return context -def _gen_loop_end(chunk_outputs, chunk_inputs, node_list): +def _gen_loop_end(chunk_outputs, chunk_inputs, node_list, chunk_dim): chunk_inputs_name = chunk_inputs[0].name chunk_outputs_name = chunk_outputs.name chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list) chunk_output_shape = chunk_outputs.meta['tensor_meta'].shape - chunk_dim = _get_first_non_single_dim(chunk_output_shape) + free_shape = [chunk_output_shape[i] if i in chunk_dim else 1 for i in range(len(chunk_output_shape))] + chunk_dim = _get_first_non_single_dim(free_shape) chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", chunk_output_shape) context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name) @@ -780,7 +862,11 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v """ # find the offload regions - chunk_regions = [(58, 62)] + chunk_region_search = ChunkRegionSearch(meta_graph) + chunk_search = chunk_region_search.search_region() + chunk_regions = [i['region'] for i in chunk_search] + chunk_dims = [i['dim'] for i in chunk_search] + chunk_starts = [item[0] for item in chunk_regions] chunk_ends = [item[1] for item in chunk_regions] chunk_inputs = [] @@ -789,16 +875,6 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v node_list = list(nodes) - memory_estimator = MemoryEstimator() - memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2]) - memory_estimator.estimate_chunk_inference_mem(meta_graph) - - node_index_tracer = NodeIndexTracer(meta_graph) - node_index_tracer.trace_node_idx() - - chunk_region_search = ChunkRegionSearch(meta_graph) - chunk_region_search.search_region() - # find the input and output var names for each offload region for idx, (start, end) in enumerate(chunk_regions): offload_node_list = node_list[start:end + 1] @@ -824,13 +900,13 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v # add for loop chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]] - body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]])) + body.append(_gen_loop_start(chunk_input_meta, node_list[chunk_ends[region_idx]], chunk_dims[region_idx])) if within_chunk_region: emit_node_func(node, body) # replace input var with chunk var if node_idx in chunk_starts: - body[-1] = body[-1].replace("("+ chunk_inputs[region_idx][0].name +")", '(chunk_tensor)') + body[-1] = body[-1].replace(chunk_inputs[region_idx][0].name, 'chunk_tensor') body[-1] = ' ' + body[-1] delete_unused_value_func(node, body, chunk_inputs_names) @@ -840,7 +916,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v delete_unused_value_func(node, body, chunk_inputs_names) if node_idx in chunk_ends: - body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list)) + body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list, chunk_dims[region_idx])) within_chunk_region = False region_idx += 1 diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index 39363a80a..88c734903 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -45,8 +45,8 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): with torch.no_grad(): non_fx_out = model(node, pair) fx_out = gm(node, pair) - assert torch.equal(non_fx_out[0], fx_out[0]), "fx_out doesn't comply with original output" - assert torch.equal(non_fx_out[1], fx_out[1]), "fx_out doesn't comply with original output" + assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-6), "fx_out doesn't comply with original output" + assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-6), "fx_out doesn't comply with original output" # test barckward # loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()