diff --git a/chunk_codegen.py b/chunk_codegen.py index 0f97f94a9..1e8305ba3 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -17,6 +17,121 @@ def _delete_free_var_from_last_use(user_to_last_uses): user_to_last_uses[key].remove(n) +class FlowTracer(object): + def __init__(self, gm) -> None: + self.gm = gm + self.nodes_list = list(gm.graph.nodes) + self.flow_trace = {} + + def _add_trace(self, name): + self.flow_trace[name] = [] + + def _add_node(self, trace_name, node): + self.flow_trace[trace_name].append({'node': node, 'inside_depend': [], 'outside_depend': []}) + + def _add_inside_depend(self, flow_name, node, inside_depend_node): + for i in self.flow_trace[flow_name]: + if i['node'] == node: + i['inside_depend'].append(inside_depend_node) + return + raise RuntimeError("node not found") + + def _add_outside_depend(self, flow_name, node, outside_depend_node, outside_depend_trace): + for i in self.flow_trace[flow_name]: + if i['node'] == node: + i['outside_depend'].append({outside_depend_trace: outside_depend_node}) + return + raise RuntimeError("node not found") + + def _init_trace(self): + for i in self.nodes_list: + if i.op == 'placeholder': + self._add_trace(i.name) + self._add_node(i.name, i) + + def _is_non_compute_node(self, node): + if any(i in node.op for i in ['placeholder', 'get_attr', 'output']) or \ + any(i in node.name for i in ['getitem', 'getattr']): + return True + return False + + def _is_non_compute_node_except_placeholder(self, node): + if any(i in node.op for i in ['get_attr', 'output']) or \ + any(i in node.name for i in ['getitem', 'getattr']): + return True + return False + + def _find_flow_for_node(self, node): + if type(self.nodes_list[0]) != type(node): + return None + if self._is_non_compute_node_except_placeholder(node): + return None + for name, trace in self.flow_trace.items(): + for i in trace: + if node == i['node']: + return name + if any(i in node.name for i in ["ones_like"]): + self._add_trace(node.name) + self._add_node(node.name, node) + return node.name + raise RuntimeError("node not found") + + def _find_first_valid_flow(self, flow): + for i in flow: + if i is not None: + return i + raise RuntimeError("invalid flow") + + def find_node_flow(self, node): + for name, trace in self.flow_trace.items(): + for i in trace: + if node == i['node']: + return name, i + raise RuntimeError("invalid node") + + def get_flow_mix(self, node): + if self._is_non_compute_node(node): + return None + _, node_trace = self.find_node_flow(node) + if len(node_trace['outside_depend']) == 0: + return None + elif len(node_trace['outside_depend']) > 1: + raise NotImplementedError + vars = list(node_trace['outside_depend'][0].values())[0] + return vars + + def get_same_flow_node(self, node_list, node): + name, _ = self.find_node_flow(node) + result = [] + for i in self.flow_trace[name]: + if i['node'] in node_list: + result.append(i['node']) + return result + + def trace_flow(self): + # init trace + self._init_trace() + + for node in self.nodes_list: + # skip if non compute node + if all(type(arg) != type(node) or self._is_non_compute_node_except_placeholder(arg) for arg in node.args) \ + or self._is_non_compute_node(node): + continue + + node_input_flows = [self._find_flow_for_node(arg) for arg in node.args] + + node_domin_flow = self._find_first_valid_flow(node_input_flows) + self._add_node(node_domin_flow, node) + for node_input_flow, arg in zip(node_input_flows, node.args): + if node_input_flow is None: + continue + elif node_input_flow == node_domin_flow: + self._add_inside_depend(node_domin_flow, node, arg) + else: + self._add_outside_depend(node_domin_flow, node, arg, node_input_flow) + return self.flow_trace + + class IndexTracer(object): def __init__(self, gm) -> None: self.gm = gm @@ -428,7 +543,7 @@ class IndexTracer(object): if merge_from in trace['idx']: trace['idx'] = [merge_to if i == merge_from else i for i in trace['idx']] - def trace_node_idx(self): + def trace_index(self): for idx, node in enumerate(self.nodes_list): if node.op == 'placeholder': self._assign_all_index(node, idx) @@ -684,7 +799,9 @@ class ChunkRegionSearch(object): self.node_list = list(gm.graph.nodes) self.memory_estimator = MemoryEstimator() self.index_tracer = IndexTracer(gm) - self.index_tracer.trace_node_idx() + self.index_tracer.trace_index() + self.flow_tracer = FlowTracer(gm) + self.flow_tracer.trace_flow() def _find_peak_node(self, mem_peak): max_value = max(mem_peak) @@ -729,7 +846,7 @@ class ChunkRegionSearch(object): raise RuntimeError() return chunk_region_start, chunk_region_end - def _not_compute(self, trace, chunk_range, dim_idx): + def _is_not_compute(self, trace, chunk_range, dim_idx): if trace['idx'][dim_idx] not in trace['compute']: return True if trace['idx'][dim_idx] in trace['compute'] and \ @@ -737,6 +854,56 @@ class ChunkRegionSearch(object): return True return False + def _detect_flow(self, before_trace, after_trace, start_idx, end_idx, dim_idx): + inputs, outputs = _find_input_and_output_nodes(self.node_list[start_idx:end_idx + 1]) + chunk_info = {'inputs': inputs, 'outputs': outputs} + flow_flag = False + + for idx in range(start_idx, end_idx + 1): + node = self.node_list[idx] + mix_flow_var = self.flow_tracer.get_flow_mix(node) + if mix_flow_var is None: + continue + + # if there is a flow mix, op must be in [mul, add, div, matmul] + # element-wise op requires dim to be equal in every dim + if any(n in node.name for n in ['mul', 'add']): + for i in node.args: + if type(i) == type(mix_flow_var) and i != mix_flow_var: + main_flow_var = i + # if mix flow is a broadcast in chunk dim, + # TODO need to move that flow out of the chunk + if mix_flow_var.meta['tensor_meta'].shape[dim_idx] == 1: + flow_flag = True + for i in self.flow_tracer.get_same_flow_node(chunk_info['inputs'], mix_flow_var): + chunk_info['inputs'].remove(i) + # else, we need to chunk mix var as well + else: + # TODO chunk another value + flow_flag = False + break + else: + raise NotImplementedError("%s not implemented" % node.name) + return flow_flag, chunk_info + + def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): + before_trace = input_trace[start_idx] + after_trace = output_trace[end_idx] + free_dim = [] + chunk_infos = [] + for i in range(min(len(before_trace['idx']), len(after_trace['idx']))): + if not (before_trace['idx'][i] == after_trace['idx'][i] and + self._is_not_compute(before_trace, (start_idx, end_idx), i) and + self._is_not_compute(after_trace, (start_idx, end_idx), i) and + self.node_list[end_idx].meta['tensor_meta'].shape[i] != 1): + continue + flow_flag, chunk_info = self._detect_flow(before_trace, after_trace, start_idx, end_idx, i) + if flow_flag == None: + continue + chunk_infos.append(chunk_info) + free_dim.append(i) + return free_dim, chunk_infos + def _search_possible_chunk_regions(self, max_chunk_region, peak_node): possible_chunk_region = [] output_trace = copy.deepcopy(self.index_tracer.idx_trace_list) @@ -748,27 +915,22 @@ class ChunkRegionSearch(object): else: input_trace.append(None) - for before_idx in range(max_chunk_region[0], peak_node): - for after_idx in range(peak_node, max_chunk_region[1] + 1): + for start_idx in range(max_chunk_region[0], peak_node): + for end_idx in range(peak_node, max_chunk_region[1] + 1): # skip non compute nodes if any(op in ['placeholder', 'get_attr', 'output'] for op in - [self.node_list[before_idx].op, self.node_list[after_idx].op]): + [self.node_list[start_idx].op, self.node_list[end_idx].op]): continue if any(any(i in name for i in ['getitem', 'getattr']) for name in - [self.node_list[before_idx].name, self.node_list[after_idx].name]): + [self.node_list[start_idx].name, self.node_list[end_idx].name]): continue # select free dim - before_trace = input_trace[before_idx] - after_trace = output_trace[after_idx] - free_dim = [] - for i in range(min(len(before_trace['idx']), len(after_trace['idx']))): - if (before_trace['idx'][i] == after_trace['idx'][i] and - self._not_compute(before_trace, (before_idx, after_idx), i) and - self._not_compute(after_trace, (before_idx, after_idx), i) and - self.node_list[after_idx].meta['tensor_meta'].shape[i] != 1): - free_dim.append(i) - possible_chunk_region.append({'region': (before_idx, after_idx), 'dim': free_dim}) + free_dim, chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx) + if len(free_dim) > 0: + free_dim = [free_dim[0]] + chunk_info = [chunk_info[0]] + possible_chunk_region.append({'region': (start_idx, end_idx), 'dim': free_dim, 'chunk_info': chunk_info}) return possible_chunk_region def _search_best_chunk_region(self, possible_chunk_regions): @@ -935,21 +1097,23 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v chunk_search = chunk_region_search.search_region() chunk_regions = [i['region'] for i in chunk_search] chunk_dims = [i['dim'] for i in chunk_search] + chunk_infos = [i['chunk_info'] for i in chunk_search] chunk_starts = [item[0] for item in chunk_regions] chunk_ends = [item[1] for item in chunk_regions] - chunk_inputs = [] - chunk_outputs = [] + chunk_inputs = [[j['inputs'][0] for j in i] for i in chunk_infos] + chunk_outputs = [[j['outputs'][0] for j in i] for i in chunk_infos] within_chunk_region = False node_list = list(nodes) # find the input and output var names for each offload region - for idx, (start, end) in enumerate(chunk_regions): - offload_node_list = node_list[start:end + 1] - inputs, outputs = _find_input_and_output_nodes(offload_node_list) - chunk_inputs.append(inputs) - chunk_outputs.append(outputs) + # for idx, (start, end) in enumerate(chunk_regions): + # offload_node_list = node_list[start:end + 1] + # inputs, outputs = _find_input_and_output_nodes(offload_node_list) + # chunk_inputs.append(inputs) + # chunk_outputs.append(outputs) + chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs] chunk_outputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs] chunk_inputs_names = []