diff --git a/chunk_codegen.py b/chunk_codegen.py index 64bff4a80..b5bb8f185 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -180,7 +180,7 @@ class FlowTracer(object): "args": {}, } flow_block = False - + # TODO don't allow multi outputs now if len(outputs) > 1: flow_block = True @@ -200,7 +200,7 @@ class FlowTracer(object): main_flow_var = i # if mix flow is a broadcast in chunk dim, # TODO: need to move that flow out of the chunk - mix_flow_node_dim = index_tracer._get_node_chunk_dim( + mix_flow_node_dim = index_tracer.get_node_chunk_dim( self.node_list[end_idx], end_dim, node ) if mix_flow_node_dim is None: @@ -223,7 +223,7 @@ class FlowTracer(object): if flow_block: flow_block = True return flow_block, chunk_info - + inputs_dim = [] remove_inputs = [] for input_node in chunk_info["inputs"]: @@ -234,7 +234,7 @@ class FlowTracer(object): user_idx = _find_idx_by_name(user.name, self.node_list) dim = None if start_dim <= user_idx < end_idx: - dim = index_tracer._get_node_chunk_dim( + dim = index_tracer.get_node_chunk_dim( self.node_list[end_idx], end_dim, input_node ) elif user_idx == end_idx: @@ -300,10 +300,10 @@ class IndexTracer(object): self.idx_trace_list[idx]["compute"].pop(dim_idx) self.idx_trace_list[idx]["source"].pop(dim_idx) - def _add_dim(self, idx, dim_idx): - self.idx_trace_list[idx]["idx"].insert(dim_idx, self._add_index()) - self.idx_trace_list[idx]["compute"].insert(dim_idx, []) - self.idx_trace_list[idx]["source"].insert(dim_idx, {}) + def _add_dim(self, node_idx, dim_idx): + self.idx_trace_list[node_idx]["idx"].insert(dim_idx, self._add_index()) + self.idx_trace_list[node_idx]["compute"].insert(dim_idx, []) + self.idx_trace_list[node_idx]["source"].insert(dim_idx, {}) def _transform_index(self, node, node_dim): node_idx = self._find_idx_trace_from_node(node) @@ -659,9 +659,7 @@ class IndexTracer(object): """ self._del_dim(node_idx, -1) self._assign_index_as_input(node, node_idx) - self.idx_trace_list[node_idx]["idx"].insert(node.args[1], self._add_index()) - self.idx_trace_list[node_idx]["compute"].insert(node.args[1], []) - self.idx_trace_list[node_idx]["source"].insert(node.args[1], []) + self._add_dim(node_idx, node.args[1]) def _assign_dropout_index(self, node, node_idx): """ @@ -879,7 +877,7 @@ class IndexTracer(object): return False return True - def _get_node_chunk_dim(self, node_from, node_from_dim, node_to): + def get_node_chunk_dim(self, node_from, node_from_dim, node_to): node_from_source = self._find_source_trace_from_node(node_from) dim_source = node_from_source[node_from_dim] node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list) @@ -888,6 +886,44 @@ class IndexTracer(object): return v return None + def _find_inherit_dim(self, input_node, input_dim, node): + input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list) + node_idx = _find_idx_by_name(node.name, self.nodes_list) + node_trace_source = self._find_source_trace_from_node(node) + for node_dim in range(len(_get_node_shape(node))): + if ( + input_node_idx in node_trace_source[node_dim] + and node_trace_source[node_dim][input_node_idx] == input_dim + ): + return {node_idx: node_dim} + return {} + + def check_index_duplicate(self, chunk_infos): + input_dim_after_node = {} + for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): + for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): + input_dim_after_node.update( + self._find_inherit_dim(input_node, v, self.nodes_list[k]) + ) + + for node in self.nodes_list[ + chunk_infos["region"][0] : 
chunk_infos["region"][1] + 1 + ]: + if _is_non_compute_node_except_placeholder(node): + continue + count = 0 + node_trace_source = self._find_source_trace_from_node(node) + for node_dim in range(len(_get_node_shape(node))): + dim_source = node_trace_source[node_dim] + for k, v in dim_source.items(): + if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]: + if k in input_dim_after_node and input_dim_after_node[k] == v: + count += 1 + break + if count > 1: + return False + return True + class MemoryEstimator(object): def __init__(self) -> None: @@ -1160,7 +1196,7 @@ class ChunkRegionSearch(object): min_len = len(n) return min_len - def _search_max_chunk_region(self, active_node, peak_node): + def _search_max_chunk_region(self, active_node, peak_node, chunk_regions): free_vars = self._get_free_var() min_var = self._get_min_free_var(active_node, free_vars) @@ -1180,6 +1216,21 @@ class ChunkRegionSearch(object): break if i in free_vars or i == 0: raise RuntimeError() + + for i in chunk_regions: + region = i["region"] + if chunk_region_start >= region[0] and chunk_region_end <= region[1]: + return None + elif ( + region[0] <= chunk_region_start <= region[1] + and chunk_region_end > region[1] + ): + chunk_region_start = region[1] + 1 + elif ( + region[0] <= chunk_region_end <= region[1] + and chunk_region_start < region[0] + ): + chunk_region_end = region[0] - 1 return chunk_region_start, chunk_region_end def _is_not_compute(self, trace, chunk_range, dim_idx): @@ -1192,24 +1243,6 @@ class ChunkRegionSearch(object): return True return False - def _check_duplicate_map(self, chunk_infos): - dim_map = [(i["inputs_dim"], i["outputs_dim"]) for i in chunk_infos] - remove_list = [] - for idx1, (input_dim1, output_dim1) in enumerate(dim_map): - for idx2, (input_dim2, output_dim2) in enumerate(dim_map): - if idx1 == idx2: - continue - # it means an index create 2 copy of itself - # eg. 
a = torch.matmul(x, x.transpose(-1, -2))
-                # TODO: currently remove it, deal with this in future
-                if input_dim1 == input_dim2 and output_dim1 != output_dim2:
-                    remove_list.append(chunk_infos[idx1])
-                    remove_list.append(chunk_infos[idx2])
-        for i in remove_list:
-            if i in chunk_infos:
-                chunk_infos.remove(i)
-        return chunk_infos
-
     def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
         start_traces = input_trace[start_idx]
         end_trace = output_trace[end_idx]
@@ -1246,8 +1279,10 @@ class ChunkRegionSearch(object):
             )
             if flow_block:
                 continue
+            # check index compute: skip regions that duplicate the chunk index
+            if not self.index_tracer.check_index_duplicate(chunk_info):
+                continue
             chunk_infos.append(chunk_info)
-        chunk_infos = self._check_duplicate_map(chunk_infos)
         return chunk_infos

     def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
@@ -1288,9 +1323,13 @@ class ChunkRegionSearch(object):
                 max_region_range = i["region"][1] - i["region"][0]
         return best_regions

-    def _step_search(self, mem_peak, active_node):
+    def _step_search(self, mem_peak, active_node, chunk_regions):
         peak_node = self._find_peak_node(mem_peak)
-        max_chunk_region = self._search_max_chunk_region(active_node, peak_node)
+        max_chunk_region = self._search_max_chunk_region(
+            active_node, peak_node, chunk_regions
+        )
+        if max_chunk_region is None:
+            return None
         possible_chunk_regions = self._search_possible_chunk_regions(
             max_chunk_region, peak_node
         )
@@ -1313,7 +1352,7 @@ class ChunkRegionSearch(object):
         mem_peak = init_mem_peak

         while True:
-            chunk_region = self._step_search(mem_peak, active_node)
+            chunk_region = self._step_search(mem_peak, active_node, chunk_regions)
             if chunk_region is None:
                 break
diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py
index 99700e1af..ae4653d65 100644
--- a/chunk_codegen_run.py
+++ b/chunk_codegen_run.py
@@ -46,8 +46,8 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
     non_fx_out = model(node, pair)
     fx_out = gm(node, pair)

-    assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't comply with original output"
-    assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-4), "fx_out doesn't comply with original output"
+    assert torch.allclose(non_fx_out[0], fx_out[0], atol=1e-4), "fx_out doesn't match original output, mean diff is %.2e" % torch.mean(torch.abs(non_fx_out[0] - fx_out[0]))
+    assert torch.allclose(non_fx_out[1], fx_out[1], atol=1e-4), "fx_out doesn't match original output, mean diff is %.2e" % torch.mean(torch.abs(non_fx_out[1] - fx_out[1]))

     # test backward
     # loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
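
Note: the new check_index_duplicate guard (which replaces _check_duplicate_map) rejects a region when one input index is the source of two different dimensions of a node inside the region, exactly the matmul(x, x^T) case cited in the removed comment. A minimal standalone sketch of why such a region cannot be chunked along that index; variable names and shapes here are illustrative, not part of the patch:

    import torch

    x = torch.randn(4, 8, 16)                 # (batch, seq, hidden)
    a = torch.matmul(x, x.transpose(-1, -2))  # (batch, seq, seq): seq appears twice

    # Chunking along seq reconstructs only the first seq dim of `a`, and only
    # because the second operand keeps the *full* x. Inside a chunked region,
    # x itself would be sliced too, so the second occurrence of the index
    # would be computed from the wrong slice.
    chunks = [torch.matmul(c, x.transpose(-1, -2)) for c in x.chunk(2, dim=1)]
    assert torch.allclose(torch.cat(chunks, dim=1), a)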
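
The other behavioral change is in _search_max_chunk_region: a candidate region is now clipped against already-chosen chunk regions, and _step_search returns None once a candidate is fully covered. A rough standalone sketch of that interval rule; the function name and tuple format are hypothetical:

    def trim_against_existing(start, end, existing_regions):
        # Clip [start, end] against already-chunked [r0, r1] intervals.
        for r0, r1 in existing_regions:
            if start >= r0 and end <= r1:
                return None          # fully inside an existing region: give up
            if r0 <= start <= r1 < end:
                start = r1 + 1       # overlap at the head: clip the start
            elif start < r0 <= end <= r1:
                end = r0 - 1         # overlap at the tail: clip the end
        return start, end

    assert trim_against_existing(3, 10, [(1, 5)]) == (6, 10)
    assert trim_against_existing(3, 10, [(8, 12)]) == (3, 7)
    assert trim_against_existing(4, 5, [(1, 9)]) is None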