diff --git a/chunk_codegen.py b/chunk_codegen.py
index 4b8882afc..8477fe9a1 100644
--- a/chunk_codegen.py
+++ b/chunk_codegen.py
@@ -25,6 +25,7 @@ class NodeIndexTracer(object):
         self.nodes_list = list(gm.graph.nodes)
         self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))]
         self.idx_trace_equal = []
+        self.idx_view_list = []
         self.idx_count = 1
 
     def add_index(self):
@@ -35,7 +36,7 @@ class NodeIndexTracer(object):
         _, compute_from = self.find_trace_from_node(node_from)
         idx_to, compute_to = self.find_trace_from_node(node_to)
         for i in compute_from:
-            if i in idx_to:
+            if i in idx_to and i not in compute_to:
                 compute_to.append(i)
 
     def mark_idx_equal(self, idx1, idx2):
@@ -47,7 +48,8 @@ class NodeIndexTracer(object):
             dim = [dim]
         for d in dim:
             cur_idx = input_node_idx_trace[d]
-            self.idx_trace_list[idx]['compute'].append(cur_idx)
+            if cur_idx not in self.idx_trace_list[idx]['compute']:
+                self.idx_trace_list[idx]['compute'].append(cur_idx)
 
     def find_trace_from_node(self, node):
         node_idx = _find_idx_by_name(node.name, self.nodes_list)
@@ -56,8 +58,11 @@ class NodeIndexTracer(object):
 
     def find_idx_trace_from_node(self, node):
         node_idx = _find_idx_by_name(node.name, self.nodes_list)
-        node_idx_trace = self.idx_trace_list[node_idx]['idx']
-        return node_idx_trace
+        return self.idx_trace_list[node_idx]['idx']
+
+    def find_compute_trace_from_node(self, node):
+        node_idx = _find_idx_by_name(node.name, self.nodes_list)
+        return self.idx_trace_list[node_idx]['compute']
 
     def assign_index_as_input(self, node, node_idx):
         input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list)
@@ -82,6 +87,18 @@ class NodeIndexTracer(object):
             new_idx_trace[tranpose_dim[1]] = input_node_idx_trace[tranpose_dim[0]]
 
         self.idx_trace_list[node_idx]['idx'] = new_idx_trace
+        self.inherit_computation(node.args[0], node)
+
+    def assign_permute_index(self, node, node_idx):
+        permute_dim = node.args[1:]
+        input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
+
+        new_idx_trace = copy.deepcopy(input_node_idx_trace)
+        for idx, d in enumerate(permute_dim):
+            new_idx_trace[idx] = input_node_idx_trace[d]
+
+        self.idx_trace_list[node_idx]['idx'] = new_idx_trace
+        self.inherit_computation(node.args[0], node)
 
     def assign_linear_index(self, node, node_idx):
         input_node, weight, bias = node.args
@@ -100,10 +117,99 @@ class NodeIndexTracer(object):
             bias_idx_trace = self.find_idx_trace_from_node(bias)
             self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
 
+    def assign_matmul_index(self, node, node_idx):
+        matmul_left, matmul_right = node.args
+        matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left)
+        matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right)
+
+        assert len(matmul_left_idx_trace) == len(matmul_right_idx_trace)
+        new_idx_trace = copy.deepcopy(matmul_left_idx_trace)
+        new_idx_trace[-1] = matmul_right_idx_trace[-1]
+        self.idx_trace_list[node_idx]['idx'] = new_idx_trace
+
+        self.inherit_computation(matmul_left, node)
+        self.inherit_computation(matmul_right, node)
+        self.mark_computation(node, node_idx, [-1])
+        self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])
+
     def assign_layernorm_index(self, node, idx):
         self.assign_index_as_input(node, idx)
+        self.inherit_computation(node.args[0], node)
         self.mark_computation(node, idx, [-1, -2])
-    
+
+    def assign_elementwise_index(self, node, idx):
+        self.assign_index_as_input(node, idx)
+        for node_in in node.args:
+            if type(node_in) not in (int, float):
+                self.inherit_computation(node_in, node)
+
+    def assign_softmax_index(self, node, idx):
+        self.assign_index_as_input(node, idx)
+        self.mark_computation(node, idx, [node.kwargs['dim']])
+
+    def assign_view_reshape_index(self, node, node_idx):
+        # collect the target shape, turning any node args into concrete numbers
+        origin_node = node.args[0]
+        origin_shape = origin_node.meta['tensor_meta'].shape
+        target_shape = []
+        for i in range(1, len(node.args)):
+            if isinstance(node.args[i], int):
+                target_shape.append(node.args[i])
+            else:
+                target_shape.append(node.args[i].meta['fwd_out'][0])
+
+        # compute the value of -1
+        if -1 in target_shape:
+            origin_product = 1
+            for i in origin_shape:
+                origin_product *= i
+            target_product = -1
+            for i in target_shape:
+                target_product *= i
+            shape_idx = target_shape.index(-1)
+            target_shape[shape_idx] = origin_product // target_product
+
+        # determine changed dim
+        len_diff = len(origin_shape) - len(target_shape)
+        if len_diff == 1:
+            # dim merge
+            dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
+            dim_to = [dim_equal.index(False)]
+            dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
+        elif len_diff == -1:
+            # dim expand
+            dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
+            dim_from = [dim_equal.index(False)]
+            dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
+        else:
+            raise NotImplementedError("shape " + str(origin_shape) + " and " + str(target_shape) + " view not implemented")
+
+        # get new index
+        origin_trace = self.find_idx_trace_from_node(origin_node)
+        new_trace = copy.deepcopy(origin_trace)
+        dim_from.reverse()
+        for i in dim_from:
+            new_trace.pop(i)
+        for i in dim_to:
+            new_trace.insert(i, self.add_index())
+        self.idx_trace_list[node_idx]['idx'] = new_trace
+
+        # inherit computation
+        self.inherit_computation(origin_node, node)
+        compute_log = self.find_compute_trace_from_node(origin_node)
+        for i in dim_from:
+            if origin_trace[i] in compute_log:
+                for j in dim_to:
+                    self.mark_computation(node, node_idx, [j])
+                break
+
+        # log view
+        view_dict = {"idx_from": [origin_trace[i] for i in dim_from],
+                     "dim_from": dim_from,
+                     "idx_to": [new_trace[i] for i in dim_to],
+                     "dim_to": dim_to}
+        self.idx_view_list.append(view_dict)
+
     def trace_node_idx(self):
         for idx, node in enumerate(self.nodes_list):
             if node.op == 'placeholder':
@@ -111,15 +217,21 @@ class NodeIndexTracer(object):
             elif node.op == 'call_method':
                 if 'transpose' in node.name:
                     self.assign_transpose_index(node, idx)
-                elif 'view' in node.name:
-                    pass
                 elif 'permute' in node.name:
-                    pass
+                    self.assign_permute_index(node, idx)
+                elif 'view' in node.name or 'reshape' in node.name:
+                    self.assign_view_reshape_index(node, idx)
                 else:
                     raise NotImplementedError(node.name, "method not implemented yet!")
             elif node.op == 'call_function':
                 if 'linear' in node.name:
                     self.assign_linear_index(node, idx)
+                elif 'matmul' in node.name:
+                    self.assign_matmul_index(node, idx)
+                elif 'softmax' in node.name:
+                    self.assign_softmax_index(node, idx)
+                elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']):
+                    self.assign_elementwise_index(node, idx)
                 elif 'getattr' in node.name:
                     continue  # get attr like shape
                 elif 'getitem' in node.name:
@@ -127,12 +239,14 @@ class NodeIndexTracer(object):
                 else:
                     raise NotImplementedError(node.name, "function not implemented yet!")
             elif node.op == 'call_module':
-                if 'layernorm' in node.name:
+                if any(n in node.name for n in ['layernorm', 'norm']):
                     self.assign_layernorm_index(node, idx)
                 else:
                     raise NotImplementedError(node.name, "module not implemented yet!")
             elif node.op == 'get_attr':
                 self.assign_all_index(node, idx)  # get param
+            elif node.op == 'output':
+                continue
             else:
                 raise NotImplementedError(node.op, "op not implemented yet!")
 
@@ -297,6 +411,7 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
         # node is an operation, calculate tmp, output node and delete node memory
         else:
             # forward memory
+            # TODO: permute will create a tmp copy if not contiguous
             act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
             act_memory += _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
             # record max act memory
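
A few illustrative notes follow; none of them are part of the patch itself.

First, what the new assign_permute_index does to an index trace, with hypothetical letter ids standing in for the integer indices handed out by add_index():

    input_trace = ['a', 'b', 'c']                      # trace of node.args[0]
    permute_dim = (0, 2, 1)                            # node.args[1:]
    # output dim i carries the input's dim permute_dim[i], as in the patch's loop
    new_trace = [input_trace[d] for d in permute_dim]
    assert new_trace == ['a', 'c', 'b']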
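
Second, a worked example (hypothetical shapes) of the two tricky steps in assign_view_reshape_index: resolving a -1 in the target shape and locating the merged dims. A view from (2, 3, 4) to (2, -1) gives len_diff == 1, so input dims 1 and 2 merge into output dim 1:

    origin_shape, target_shape = (2, 3, 4), [2, -1]

    # resolve the -1: starting target_product at -1 cancels the -1 entry,
    # leaving the product of the known target dims
    origin_product = 1
    for i in origin_shape:
        origin_product *= i                  # 24
    target_product = -1
    for i in target_shape:
        target_product *= i                  # (-1) * 2 * (-1) = 2
    target_shape[target_shape.index(-1)] = origin_product // target_product
    assert target_shape == [2, 12]

    # a merge shortens the shape by one; the first mismatch marks where it happens
    dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]  # [True, False]
    dim_to = [dim_equal.index(False)]        # output dim 1 ...
    dim_from = [dim_to[0], dim_to[0] + 1]    # ... absorbs input dims 1 and 2
    assert (dim_from, dim_to) == ([1, 2], [1])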
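
Finally, a rough sketch of how the tracer might be driven end to end. The constructor signature is assumed from the gm.graph.nodes access in __init__, the module below is hypothetical, and torch.fx's ShapeProp is what fills the meta['tensor_meta'] that assign_view_reshape_index reads:

    import torch
    import torch.fx
    from torch.fx.passes.shape_prop import ShapeProp

    class TinyAttention(torch.nn.Module):
        # minimal module touching the new handlers: norm (call_module),
        # matmul and softmax (call_function), transpose (call_method)
        def __init__(self):
            super().__init__()
            self.norm = torch.nn.LayerNorm(64)

        def forward(self, x):
            x = self.norm(x)
            scores = torch.matmul(x, x.transpose(-1, -2))
            attn = torch.nn.functional.softmax(scores, dim=-1)
            return torch.matmul(attn, x)

    gm = torch.fx.symbolic_trace(TinyAttention())
    ShapeProp(gm).propagate(torch.rand(2, 16, 64))  # annotate nodes with tensor_meta
    tracer = NodeIndexTracer(gm)                    # assumed: ctor takes the GraphModule
    tracer.trace_node_idx()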