diff --git a/chunk_codegen.py b/chunk_codegen.py
index ce7d84917..fc3c88cf9 100644
--- a/chunk_codegen.py
+++ b/chunk_codegen.py
@@ -16,16 +16,31 @@ def _delete_free_var_from_last_use(user_to_last_uses):
             if n.op == 'placeholder':
                 user_to_last_uses[key].remove(n)
+

 def _get_node_shape(node):
     if hasattr(node.meta['tensor_meta'], "shape"):
         return node.meta['tensor_meta'].shape
     return None
+
+
+def _is_non_compute_node(node):
+    if any(i in node.op for i in ['placeholder', 'get_attr', 'output']) or \
+            any(i in node.name for i in ['getitem', 'getattr']):
+        return True
+    return False
+
+
+def _is_non_compute_node_except_placeholder(node):
+    if any(i in node.op for i in ['get_attr', 'output']) or \
+            any(i in node.name for i in ['getitem', 'getattr']):
+        return True
+    return False


 class FlowTracer(object):

     def __init__(self, gm) -> None:
         self.gm = gm
-        self.nodes_list = list(gm.graph.nodes)
+        self.node_list = list(gm.graph.nodes)
         self.flow_trace = {}

     def _add_trace(self, name):
@@ -49,7 +64,7 @@ class FlowTracer(object):
         raise RuntimeError("node not found")

     def _init_trace(self):
-        for i in self.nodes_list:
+        for i in self.node_list:
             if i.op == 'placeholder':
                 self._add_trace(i.name)
                 self._add_node(i.name, i)
@@ -67,7 +82,7 @@ class FlowTracer(object):
         return False

     def _find_flow_for_node(self, node):
-        if type(self.nodes_list[0]) != type(node):
+        if type(self.node_list[0]) != type(node):
             return None
         if self._is_non_compute_node_except_placeholder(node):
             return None
@@ -117,7 +132,7 @@ class FlowTracer(object):
         # init trace
         self._init_trace()

-        for node in self.nodes_list:
+        for node in self.node_list:
             # skip if non compute node
             if all(type(arg) != type(node) or self._is_non_compute_node_except_placeholder(arg) for arg in node.args) \
                     or self._is_non_compute_node(node):
@@ -135,6 +150,41 @@ class FlowTracer(object):
             else:
                 self._add_outside_depend(node_domin_flow, node, arg, node_input_flow)
         return self.flow_trace
+
+    def _detect_flow(self, start_idx, start_dim, end_idx, end_dim):
+        inputs, outputs = _find_chunk_input_and_output_nodes(self.node_list[start_idx:end_idx + 1])
+        chunk_info = {'region': (start_idx, end_idx),
+                      'inputs': inputs, 'inputs_dim': start_dim,
+                      'outputs': outputs, 'outputs_dim': end_dim,
+                      'args': {}}
+        flow_flag = False
+
+        for idx in range(start_idx, end_idx + 1):
+            node = self.node_list[idx]
+            mix_flow_var = self.get_flow_mix(node)
+            if mix_flow_var is None:
+                continue
+
+            # if there is a flow mix, the op must be in [mul, add, div, matmul];
+            # element-wise ops require shapes to match in every dim
+            if any(n in node.name for n in ['mul', 'add']):
+                for i in node.args:
+                    if type(i) == type(mix_flow_var) and i != mix_flow_var:
+                        main_flow_var = i
+                # if the mix flow is a broadcast in the chunk dim (the input-side chunk dim),
+                # TODO need to move that flow out of the chunk
+                if mix_flow_var.meta['tensor_meta'].shape[start_dim] == 1:
+                    flow_flag = True
+                    for i in self.get_same_flow_node(chunk_info['inputs'], mix_flow_var):
+                        chunk_info['inputs'].remove(i)
+                # else, we need to chunk the mix var as well
+                else:
+                    # TODO chunk another value
+                    flow_flag = False
+                    break
+            else:
+                raise NotImplementedError("%s not implemented" % node.name)
+        return flow_flag, chunk_info


 class IndexTracer(object):
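Not part of the patch: a minimal sketch of the broadcast case the branch above handles, with hypothetical tensor names. When the mixed-flow tensor has size 1 in the chunk dim, it broadcasts identically over every chunk, so it can be treated as a plain outside input instead of something to split:

import torch

x = torch.randn(4, 8)     # tensor being chunked along dim 0
bias = torch.randn(1, 8)  # mixed flow: size 1 in the chunk dim -> pure broadcast

whole = x + bias
chunked = torch.cat([c + bias for c in x.chunk(2, dim=0)], dim=0)
assert torch.equal(whole, chunked)  # per-chunk result matches the unchunked one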
@@ -153,7 +203,7 @@ class IndexTracer(object):
             cur_trace = {
                 'idx': [None for _ in range(len(_get_node_shape(n)))],
                 'compute': [[] for _ in range(len(_get_node_shape(n)))],
-                'source': [[] for _ in range(len(_get_node_shape(n)))],
+                'source': [{} for _ in range(len(_get_node_shape(n)))],
             }
         else:
             cur_trace = {'idx': [], 'compute': [], 'source': []}
@@ -178,7 +228,7 @@ class IndexTracer(object):
     def _add_dim(self, idx, dim_idx):
         self.idx_trace_list[idx]['idx'].insert(dim_idx, self._add_index())
         self.idx_trace_list[idx]['compute'].insert(dim_idx, [])
-        self.idx_trace_list[idx]['source'].insert(dim_idx, [])
+        self.idx_trace_list[idx]['source'].insert(dim_idx, {})

     def _transform_index(self, node, node_dim):
         node_idx = self._find_idx_trace_from_node(node)
@@ -192,10 +242,7 @@ class IndexTracer(object):
         node_to_trace = self._find_trace_from_node(node_to)
         node_to_trace['idx'][node_to_dim] = node_from_trace['idx'][node_from_dim]
         node_to_trace['compute'][node_to_dim] = copy.deepcopy(node_from_trace['compute'][node_from_dim])
-        node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list)
-        node_to_trace['source'][node_to_dim] = []
-        node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim})
-        node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim])
+        self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)

     def _inherit_all_computation(self, node_from, node_to):
         node_from_compute = self._find_compute_trace_from_node(node_from)
@@ -205,14 +252,16 @@ class IndexTracer(object):
             self._add_source(node_from, i, node_to, i)
             node_to_compute[i] = copy.deepcopy(node_from_compute[i])

-    def _add_source(self, node_from, node_from_dim, node_to, node_to_dim):
+    def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
         node_from_dim = self._transform_index(node_from, node_from_dim)
         node_from_trace = self._find_trace_from_node(node_from)
         node_to_dim = self._transform_index(node_to, node_to_dim)
         node_to_trace = self._find_trace_from_node(node_to)
         node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list)
-        node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim})
-        node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim])
+        if init:
+            node_to_trace['source'][node_to_dim] = {}
+        node_to_trace['source'][node_to_dim][node_from_idx] = node_from_dim
+        node_to_trace['source'][node_to_dim].update(node_from_trace['source'][node_from_dim])

     def _mark_computation_from_node(self, node_from, node_to, exclude=None):
         if exclude == None:
@@ -485,11 +534,11 @@ class IndexTracer(object):
                     source_idx = left_str.index(right_indice)
                     self._inherit_index(input_nodes[left_idx], source_idx, node, right_idx)

-        for i in sum_index:
-            for left_idx, left_str in enumerate(left):
-                if i in left_str:
-                    self._mark_computation(node, idx, left_str.index(i))
-                    break
+        # for i in sum_index:
+        #     for left_idx, left_str in enumerate(left):
+        #         if i in left_str:
+        #             self._mark_computation(node, idx, left_str.index(i))
+        #             break

     def _assign_softmax_index(self, node, idx):
         """
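Not part of the patch: a hypothetical idx_trace_list entry to make the list-to-dict switch above concrete. Each dim's 'source' is now a single dict from source node position (in node_list) to source dim, which _add_source fills and check_index_source below can sort and walk:

trace = {
    'idx': [3, 7],                      # trace index assigned to each dim
    'compute': [[], [12]],              # node positions where each dim was computed over
    'source': [{5: 0}, {5: 1, 2: 1}],   # per dim: {source node position: source dim}
}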
@@ -679,18 +728,57 @@ class IndexTracer(object):
             raise NotImplementedError(node.op, "op not implemented yet!")
         # self._merge_equal_idx()

-    def check_index(self, trace_idx, start_idx, end_idx):
-        for i in range(start_idx, end_idx + 1):
-            cur_idx = self.idx_trace_list[i]['idx']
-            cur_compute = self.idx_trace_list[i]['compute']
-            if trace_idx in cur_compute:
-                for j in cur_compute[trace_idx]:
-                    if j < start_idx or j > end_idx:
-                        return False
-            # same_idx = [1 if j == trace_idx else 0 for j in cur_idx]
-            # if sum(same_idx) > 1:
-            #     return False
+    def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
+        """
+        Check two given indexes: one should be the source of the other.
+        Args:
+            start_dim(int): chunk dim of the start node
+            start_node(node): start node
+            start_idx(int): position of the start node in node_list
+            end_dim(int): chunk dim of the end node
+            end_node(node): end node
+
+        Returns:
+            bool: True if the check passes
+        """
+        start_node_idx = _find_idx_by_name(start_node.name, self.nodes_list)
+        end_node_trace = self._find_trace_from_node(end_node)
+        end_node_trace_source = end_node_trace['source'][end_dim]
+        sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
+        for node_idx, node_dim in sorted_source:
+            if node_idx == start_node_idx and node_dim == start_dim:
+                return True
+            # it means we met a node outside the region, and it is not an input node
+            if node_idx < start_idx:
+                return False
+        return False
+
+    def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
+        """
+        Check that the given dim has not been computed over inside the region.
+        Args:
+            start_idx(int): position of the start node in node_list
+            end_dim(int): chunk dim of the end node
+            end_node(node): end node
+            end_idx(int): position of the end node in node_list
+
+        Returns:
+            bool: True if the check passes
+        """
+        end_node_trace = self._find_trace_from_node(end_node)
+        end_node_compute = end_node_trace['compute'][end_dim]
+        if any(start_idx <= i <= end_idx for i in end_node_compute):
+            return False
         return True
+        # end_node_trace_source = end_node_trace['source'][end_dim]
+        # for node_idx, node_dim in end_node_trace_source.items():
+        #     if node_idx < start_node_idx or node_idx > end_node_idx:
+        #         continue
+        #     compute_list = self.idx_trace_list[node_idx]['compute'][node_dim]
+        #     if any(start_node_idx <= i <= end_node_idx for i in compute_list):
+        #         return False
+        # return True
+

 class MemoryEstimator(object):

     def __init__(self) -> None:
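Not part of the patch: the source walk in check_index_source restated on plain data, with hypothetical values. Sources are visited from the highest node position down; an exact (node, dim) hit passes, and dropping below start_idx means the dim came from outside the candidate region:

source = {9: 1, 6: 0, 3: 1}   # end node's chunk dim: {source node position: source dim}
start_node_idx, start_dim, start_idx = 6, 0, 4

found = False
for node_idx, node_dim in sorted(source.items(), reverse=True):
    if node_idx == start_node_idx and node_dim == start_dim:
        found = True   # the start dim really is a source of the end dim
        break
    if node_idx < start_idx:
        break          # left the region without a hit: check fails
print(found)           # True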
@@ -951,88 +1038,81 @@ class ChunkRegionSearch(object):
             return True
         return False

-    def _detect_flow(self, before_trace, after_trace, start_idx, end_idx, dim_idx):
-        inputs, outputs = _find_input_and_output_nodes(self.node_list[start_idx:end_idx + 1])
-        chunk_info = {'inputs': inputs, 'outputs': outputs}
-        flow_flag = False
-
-        for idx in range(start_idx, end_idx + 1):
-            node = self.node_list[idx]
-            mix_flow_var = self.flow_tracer.get_flow_mix(node)
-            if mix_flow_var is None:
-                continue
-
-            # if there is a flow mix, op must be in [mul, add, div, matmul]
-            # element-wise op requires dim to be equal in every dim
-            if any(n in node.name for n in ['mul', 'add']):
-                for i in node.args:
-                    if type(i) == type(mix_flow_var) and i != mix_flow_var:
-                        main_flow_var = i
-                # if mix flow is a broadcast in chunk dim,
-                # TODO need to move that flow out of the chunk
-                if mix_flow_var.meta['tensor_meta'].shape[dim_idx] == 1:
-                    flow_flag = True
-                    for i in self.flow_tracer.get_same_flow_node(chunk_info['inputs'], mix_flow_var):
-                        chunk_info['inputs'].remove(i)
-                # else, we need to chunk mix var as well
-                else:
-                    # TODO chunk another value
-                    flow_flag = False
-                    break
-            else:
-                raise NotImplementedError("%s not implemented" % node.name)
-        return flow_flag, chunk_info
+    def _check_duplicate_map(self, chunk_infos):
+        dim_map = [(i['inputs_dim'], i['outputs_dim']) for i in chunk_infos]
+        remove_list = []
+        for idx1, (input_dim1, output_dim1) in enumerate(dim_map):
+            for idx2, (input_dim2, output_dim2) in enumerate(dim_map):
+                if idx1 == idx2:
+                    continue
+                # it means an index creates 2 copies of itself,
+                # e.g. a = torch.matmul(x, x.transpose(-1, -2))
+                # TODO currently remove it, deal with this in the future
+                if input_dim1 == input_dim2 and output_dim1 != output_dim2:
+                    remove_list.append(chunk_infos[idx1])
+                    remove_list.append(chunk_infos[idx2])
+        for i in remove_list:
+            if i in chunk_infos:
+                chunk_infos.remove(i)
+        return chunk_infos

     def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
-        before_trace = input_trace[start_idx]
-        after_trace = output_trace[end_idx]
-        free_dim = []
+        start_traces = input_trace[start_idx]
+        end_trace = output_trace[end_idx]
+        end_node = self.node_list[end_idx]
         chunk_infos = []
-        for i in range(min(len(before_trace['idx']), len(after_trace['idx']))):
-            if not (before_trace['idx'][i] == after_trace['idx'][i] and
-                    self._is_not_compute(before_trace, (start_idx, end_idx), i) and
-                    self._is_not_compute(after_trace, (start_idx, end_idx), i) and
-                    self.node_list[end_idx].meta['tensor_meta'].shape[i] != 1):
+        for end_dim, end_trace_idx in enumerate(end_trace['idx']):
+            if len(start_traces) > 1:
+                # TODO implement multi input chunk
                 continue
-            if not self.index_tracer.check_index(before_trace['idx'][i], start_idx, end_idx):
-                continue
-            flow_flag, chunk_info = self._detect_flow(before_trace, after_trace, start_idx, end_idx, i)
-            if flow_flag == None:
-                continue
-            chunk_infos.append(chunk_info)
-            free_dim.append(i)
-        return free_dim, chunk_infos
+            for start_node, start_trace in start_traces.items():
+                for start_dim, start_trace_idx in enumerate(start_trace['idx']):
+                    # must be the same trace idx
+                    if start_trace_idx != end_trace_idx:
+                        continue
+                    # dim size cannot be 1
+                    if _get_node_shape(end_node)[end_dim] == 1 or \
+                            _get_node_shape(start_node)[start_dim] == 1:
+                        continue
+                    # check index source alignment
+                    if not self.index_tracer.check_index_source(
+                            start_dim, start_node, start_idx, end_dim, end_node):
+                        continue
+                    # check index compute
+                    if not self.index_tracer.check_index_compute(
+                            start_idx, end_dim, end_node, end_idx):
+                        continue
+                    # detect flow meet
+                    flow_flag, chunk_info = self.flow_tracer._detect_flow(
+                        start_idx, start_dim, end_idx, end_dim)
+                    if flow_flag:
+                        continue
+                    chunk_infos.append(chunk_info)
+        chunk_infos = self._check_duplicate_map(chunk_infos)
+        return chunk_infos

     def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
         possible_chunk_region = []
         output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
-        input_trace = []
-        for i, n in enumerate(self.node_list):
-            if len(n.args) > 0 and n.op != 'output':
-                if isinstance(n.args[0], str):
-                    input_idx = _find_idx_by_name(n.args[1].name, self.node_list)
-                else:
-                    input_idx = _find_idx_by_name(n.args[0].name, self.node_list)
-                input_trace.append(output_trace[input_idx])
-            else:
-                input_trace.append(None)
+        input_trace = []    # trace of every node's input nodes
+        for _, n in enumerate(self.node_list):
+            cur_trace = {}
+            for arg in n.args:
+                if type(arg) == type(n) and not _is_non_compute_node_except_placeholder(arg):
+                    cur_trace[arg] = self.index_tracer._find_trace_from_node(arg)
+            input_trace.append(cur_trace)

-        for start_idx in range(max_chunk_region[0], peak_node):
+        for start_idx in range(max_chunk_region[0], peak_node + 1):
             for end_idx in range(peak_node, max_chunk_region[1] + 1):
                 # skip non compute nodes
-                if any(op in ['placeholder', 'get_attr', 'output'] for op in
-                       [self.node_list[start_idx].op, self.node_list[end_idx].op]):
-                    continue
-                if any(any(i in name for i in ['getitem', 'getattr']) for name in
-                       [self.node_list[start_idx].name, self.node_list[end_idx].name]):
+                if _is_non_compute_node(self.node_list[start_idx]) or \
+                        _is_non_compute_node(self.node_list[end_idx]):
                     continue

                 # select free dim
-                free_dim, chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx)
-                if len(free_dim) > 0:
-                    free_dim = [free_dim[0]]
-                    chunk_info = [chunk_info[0]]
-                    possible_chunk_region.append({'region': (start_idx, end_idx), 'dim': free_dim, 'chunk_info': chunk_info})
+                chunk_info = self._find_free_dim(input_trace, output_trace, start_idx, end_idx)
+                if len(chunk_info) > 0:
+                    possible_chunk_region.extend(chunk_info)
         return possible_chunk_region

     def _search_best_chunk_region(self, possible_chunk_regions):
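Not part of the patch: minimal stand-in chunk infos (only the two dim fields matter) showing what _check_duplicate_map drops. One input dim mapping to two different output dims means an index produced two copies of itself, as in the matmul example in the comment above:

infos = [
    {'inputs_dim': 1, 'outputs_dim': 1},   # consistent mapping: kept
    {'inputs_dim': 2, 'outputs_dim': 2},   # same input dim as below...
    {'inputs_dim': 2, 'outputs_dim': 3},   # ...different output dim: both dropped
]
# _check_duplicate_map(infos) -> [{'inputs_dim': 1, 'outputs_dim': 1}]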
@@ -1044,7 +1124,8 @@ class ChunkRegionSearch(object):
                 max_region_range = i['region'][1] - i['region'][0]
         return best_regions

-    def _step_search(self, peak_node, active_node):
+    def _step_search(self, mem_peak, active_node):
+        peak_node = self._find_peak_node(mem_peak)
         max_chunk_region = self._search_max_chunk_region(active_node, peak_node)
         possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
         best_chunk_region = self._search_best_chunk_region(possible_chunk_regions)
@@ -1062,19 +1143,16 @@ class ChunkRegionSearch(object):
         mem_peak = init_mem_peak

         while True:
-            peak_node = self._find_peak_node(mem_peak)
-            chunk_region = self._step_search(peak_node, active_node)
-            if chunk_region is None or len(chunk_region['dim']) == 0:
+            chunk_region = self._step_search(mem_peak, active_node)
+            if chunk_region is None:
                 break
            chunk_regions.append(chunk_region)

             mem_peak, _, active_node = self.memory_estimator.estimate_chunk_inference_mem(
                 self.gm, [i['region'][0] for i in chunk_regions],
-                [i['region'][1] for i in chunk_regions], [i['dim'][0] for i in chunk_regions], [1] * len(chunk_regions))
-
+                [i['region'][1] for i in chunk_regions], [i['inputs_dim'] for i in chunk_regions], [1] * len(chunk_regions))
             if self._stop_search(init_mem_peak, mem_peak):
                 break
-
         return chunk_regions

@@ -1164,6 +1242,35 @@ def _find_input_and_output_nodes(nodes: List[Node]):
     return input_nodes, output_nodes


+def _find_chunk_input_and_output_nodes(nodes: List[Node]):
+    """
+    Find the input and output nodes of a chunk region, skipping non-compute nodes.
+    Input nodes are nodes outside the list whose outputs are used inside it.
+    Output nodes are nodes inside the list whose outputs are used outside it.
+    """
+    input_nodes = []
+    output_nodes = []
+
+    # if a node has an input node which is not in the node list,
+    # we treat that input node as an input of the chunk region
+    for node in nodes:
+        for input_node in node._input_nodes.keys():
+            if input_node not in nodes and input_node not in input_nodes \
+                    and not _is_non_compute_node_except_placeholder(input_node):
+                input_nodes.append(input_node)
+
+    # if a node has a user node which is not in the node list,
+    # we treat that user node as receiving the current node's output
+    # TODO it is unsafe to remove non compute nodes here
+    for node in nodes:
+        for output_node in node.users.keys():
+            if output_node not in nodes and node not in output_nodes \
+                    and not _is_non_compute_node_except_placeholder(output_node):
+                output_nodes.append(node)
+
+    return input_nodes, output_nodes
+
+
 def _find_idx_by_name(name, nodes_list):
     for idx, node in enumerate(nodes_list):
         if node.name == name:
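Not part of the patch: a hedged usage sketch of _find_chunk_input_and_output_nodes on a toy torch.fx graph. The module and region slice are hypothetical, and the helper is assumed importable from chunk_codegen:

import torch
from torch.fx import symbolic_trace

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = torch.nn.Linear(8, 8)
        self.lin2 = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.lin2(torch.relu(self.lin1(x)))

nodes = list(symbolic_trace(Toy()).graph.nodes)   # [x, lin1, relu, lin2, output]
inputs, outputs = _find_chunk_input_and_output_nodes(nodes[1:3])  # region [lin1, relu]
# inputs  -> [x]     produced outside the region, used inside it
# outputs -> [relu]  produced inside the region, used by lin2 outside it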