diff --git a/chunk_codegen.py b/chunk_codegen.py index 9930a0570..c1d9e26e7 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -10,6 +10,13 @@ CODEGEN_AVAILABLE = True __all__ = ['ChunkCodeGen'] +def _delete_free_var_from_last_use(user_to_last_uses): + for key, value in user_to_last_uses.items(): + for n in value: + if n.op == 'placeholder': + user_to_last_uses[key].remove(n) + + class NodeIndexTracer(object): def __init__(self, gm) -> None: self.gm = gm @@ -19,7 +26,7 @@ class NodeIndexTracer(object): self.idx_view_list = [] self.idx_count = -1 - def add_index(self): + def _add_index(self): """ Update the count and return it. To record the idx number. @@ -29,7 +36,7 @@ class NodeIndexTracer(object): self.idx_count += 1 return self.idx_count - def inherit_computation(self, node_from, node_to): + def _inherit_computation(self, node_from, node_to): """ Inherit computed dim from node_from to node_to. If a dim in node_from is marked as computed and exists in node_to, @@ -39,13 +46,13 @@ class NodeIndexTracer(object): node_from (node): node to be inherited node_to (node): new node to inherit """ - _, compute_from = self.find_trace_from_node(node_from) - idx_to, compute_to = self.find_trace_from_node(node_to) + _, compute_from = self._find_trace_from_node(node_from) + idx_to, compute_to = self._find_trace_from_node(node_to) for i in compute_from: if i in idx_to and i not in compute_to: compute_to.append(i) - def mark_idx_equal(self, idx1, idx2): + def _mark_idx_equal(self, idx1, idx2): """ Mark 2 index to be equal. @@ -55,7 +62,7 @@ class NodeIndexTracer(object): """ self.idx_trace_equal.append((idx1, idx2)) - def mark_computation(self, node, idx, dim): + def _mark_computation(self, node, idx, dim): """ Mark some dims of node as computed. @@ -64,7 +71,7 @@ class NodeIndexTracer(object): idx (int): node index dim (list or int): dims to be marked as computed """ - input_node_idx_trace = self.find_idx_trace_from_node(node) + input_node_idx_trace = self._find_idx_trace_from_node(node) if isinstance(dim, int): dim = [dim] for d in dim: @@ -72,7 +79,7 @@ class NodeIndexTracer(object): if cur_idx not in self.idx_trace_list[idx]['compute']: self.idx_trace_list[idx]['compute'].append(cur_idx) - def find_trace_from_node(self, node): + def _find_trace_from_node(self, node): """ Find node idx and compute trace by the node. @@ -86,7 +93,7 @@ class NodeIndexTracer(object): node_dict = self.idx_trace_list[node_idx] return node_dict['idx'], node_dict['compute'] - def find_idx_trace_from_node(self, node): + def _find_idx_trace_from_node(self, node): """ Find node idx trace by the node. @@ -98,7 +105,7 @@ class NodeIndexTracer(object): node_idx = _find_idx_by_name(node.name, self.nodes_list) return self.idx_trace_list[node_idx]['idx'] - def find_compute_trace_from_node(self, node): + def _find_compute_trace_from_node(self, node): """ Find node compute trace by the node. @@ -110,7 +117,7 @@ class NodeIndexTracer(object): node_idx = _find_idx_by_name(node.name, self.nodes_list) return self.idx_trace_list[node_idx]['compute'] - def assign_index_as_input(self, node, node_idx): + def _assign_index_as_input(self, node, node_idx): """ Assign node's trace as its input node. @@ -124,7 +131,7 @@ class NodeIndexTracer(object): new_idx_trace = copy.deepcopy(input_node_idx_trace) self.idx_trace_list[node_idx]['idx'] = new_idx_trace - def assign_all_index(self, node, node_idx): + def _assign_all_index(self, node, node_idx): """ Add new index for all node's dims. @@ -135,10 +142,10 @@ class NodeIndexTracer(object): shape = node.meta['tensor_meta'].shape new_trace = [] for _ in shape: - new_trace.append(self.add_index()) + new_trace.append(self._add_index()) self.idx_trace_list[node_idx]['idx'] = new_trace - def assign_transpose_index(self, node, node_idx): + def _assign_transpose_index(self, node, node_idx): """ Assign index for transpose op. 1. swap input's dim according to transpose args @@ -149,16 +156,16 @@ class NodeIndexTracer(object): node_idx (int) """ tranpose_dim = node.args[1:] - input_node_idx_trace = self.find_idx_trace_from_node(node.args[0]) + input_node_idx_trace = self._find_idx_trace_from_node(node.args[0]) new_idx_trace = copy.deepcopy(input_node_idx_trace) new_idx_trace[tranpose_dim[0]] = input_node_idx_trace[tranpose_dim[1]] new_idx_trace[tranpose_dim[1]] = input_node_idx_trace[tranpose_dim[0]] self.idx_trace_list[node_idx]['idx'] = new_idx_trace - self.inherit_computation(node.args[0], node) + self._inherit_computation(node.args[0], node) - def assign_permute_index(self, node, node_idx): + def _assign_permute_index(self, node, node_idx): """ Assign index for permute op. 1. swap input's dim according to permute args @@ -169,16 +176,16 @@ class NodeIndexTracer(object): node_idx (int) """ permute_dim = node.args[1:] - input_node_idx_trace = self.find_idx_trace_from_node(node.args[0]) + input_node_idx_trace = self._find_idx_trace_from_node(node.args[0]) new_idx_trace = copy.deepcopy(input_node_idx_trace) for idx, d in enumerate(permute_dim): new_idx_trace[idx] = input_node_idx_trace[d] self.idx_trace_list[node_idx]['idx'] = new_idx_trace - self.inherit_computation(node.args[0], node) + self._inherit_computation(node.args[0], node) - def assign_linear_index(self, node, node_idx): + def _assign_linear_index(self, node, node_idx): """ Assign index for linear op. 1. copy trace from input node and change last index accroding to weight @@ -190,22 +197,22 @@ class NodeIndexTracer(object): node_idx (int) """ input_node, weight, bias = node.args - input_node_idx_trace = self.find_idx_trace_from_node(input_node) - weight_idx_trace = self.find_idx_trace_from_node(weight) + input_node_idx_trace = self._find_idx_trace_from_node(input_node) + weight_idx_trace = self._find_idx_trace_from_node(weight) new_idx_trace = copy.deepcopy(input_node_idx_trace) new_idx_trace[-1] = weight_idx_trace[1] self.idx_trace_list[node_idx]['idx'] = new_idx_trace - self.inherit_computation(input_node, node) - self.mark_computation(node, node_idx, [-1]) - self.mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0]) + self._inherit_computation(input_node, node) + self._mark_computation(node, node_idx, [-1]) + self._mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0]) if bias: - bias_idx_trace = self.find_idx_trace_from_node(bias) - self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0]) + bias_idx_trace = self._find_idx_trace_from_node(bias) + self._mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0]) - def assign_matmul_index(self, node, node_idx): + def _assign_matmul_index(self, node, node_idx): """ Assign index for matmul op. 1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length) @@ -217,20 +224,20 @@ class NodeIndexTracer(object): node_idx (int) """ matmul_left, matmul_right = node.args - matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left) - matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right) + matmul_left_idx_trace = self._find_idx_trace_from_node(matmul_left) + matmul_right_idx_trace = self._find_idx_trace_from_node(matmul_right) assert(len(matmul_left_idx_trace) == len(matmul_right_idx_trace)) new_idx_trace = copy.deepcopy(matmul_left_idx_trace) new_idx_trace[-1] = matmul_right_idx_trace[-1] self.idx_trace_list[node_idx]['idx'] = new_idx_trace - self.inherit_computation(matmul_left, node) - self.inherit_computation(matmul_right, node) - self.mark_computation(node, node_idx, [-1]) - self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2]) + self._inherit_computation(matmul_left, node) + self._inherit_computation(matmul_right, node) + self._mark_computation(node, node_idx, [-1]) + self._mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2]) - def assign_layernorm_index(self, node, idx): + def _assign_layernorm_index(self, node, idx): """ Assign index for layernorm op. 1. assign index as input node @@ -240,11 +247,11 @@ class NodeIndexTracer(object): node (node) node_idx (int) """ - self.assign_index_as_input(node, idx) - self.inherit_computation(node.args[0], node) - self.mark_computation(node, idx, [-1, -2]) + self._assign_index_as_input(node, idx) + self._inherit_computation(node.args[0], node) + self._mark_computation(node, idx, [-1, -2]) - def assign_elementwise_index(self, node, idx): + def _assign_elementwise_index(self, node, idx): """ Assign index for element-wise op (eg. relu sigmoid add mul). 1. assign index as input node @@ -254,12 +261,12 @@ class NodeIndexTracer(object): node (node) node_idx (int) """ - self.assign_index_as_input(node, idx) + self._assign_index_as_input(node, idx) for node_in in node.args: if type(node_in) not in (int, float): - self.inherit_computation(node_in, node) + self._inherit_computation(node_in, node) - def assign_softmax_index(self, node, idx): + def _assign_softmax_index(self, node, idx): """ Assign index for softmax op. 1. assign index as input node @@ -269,11 +276,11 @@ class NodeIndexTracer(object): node (node) node_idx (int) """ - self.assign_index_as_input(node, idx) - self.inherit_computation(node.args[0], node) - self.mark_computation(node, idx, [node.kwargs['dim']]) + self._assign_index_as_input(node, idx) + self._inherit_computation(node.args[0], node) + self._mark_computation(node, idx, [node.kwargs['dim']]) - def assign_view_reshape_index(self, node, node_idx): + def _assign_view_reshape_index(self, node, node_idx): """ Assign index for view and reshape op. 1. get origin shape and target shape by meta info. @@ -325,22 +332,22 @@ class NodeIndexTracer(object): raise NotImplementedError("shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented") # get new index - origin_trace = self.find_idx_trace_from_node(origin_node) + origin_trace = self._find_idx_trace_from_node(origin_node) new_trace = copy.deepcopy(origin_trace) dim_from.reverse() for i in dim_from: new_trace.pop(i) for i in dim_to: - new_trace.insert(i, self.add_index()) + new_trace.insert(i, self._add_index()) self.idx_trace_list[node_idx]['idx'] = new_trace # inherit computation - self.inherit_computation(origin_node, node) - compute_log = self.find_compute_trace_from_node(origin_node) + self._inherit_computation(origin_node, node) + compute_log = self._find_compute_trace_from_node(origin_node) for i in dim_from: if origin_trace[i] in compute_log: for j in dim_to: - self.mark_computation(node, node_idx, [j]) + self._mark_computation(node, node_idx, [j]) break # log view, not used now @@ -353,25 +360,25 @@ class NodeIndexTracer(object): def trace_node_idx(self): for idx, node in enumerate(self.nodes_list): if node.op == 'placeholder': - self.assign_all_index(node, idx) + self._assign_all_index(node, idx) elif node.op == 'call_method': if 'transpose' in node.name: - self.assign_transpose_index(node, idx) + self._assign_transpose_index(node, idx) elif 'permute' in node.name: - self.assign_permute_index(node, idx) + self._assign_permute_index(node, idx) elif 'view' in node.name or 'reshape' in node.name: - self.assign_view_reshape_index(node, idx) + self._assign_view_reshape_index(node, idx) else: raise NotImplementedError(node.name, "method not implemented yet!") elif node.op == 'call_function': if 'linear' in node.name: - self.assign_linear_index(node, idx) + self._assign_linear_index(node, idx) elif 'matmul' in node.name: - self.assign_matmul_index(node, idx) + self._assign_matmul_index(node, idx) elif 'softmax' in node.name: - self.assign_softmax_index(node, idx) + self._assign_softmax_index(node, idx) elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']): - self.assign_elementwise_index(node, idx) + self._assign_elementwise_index(node, idx) elif 'getattr' in node.name: continue # get attr like shape elif 'getitem' in node.name: @@ -380,206 +387,198 @@ class NodeIndexTracer(object): raise NotImplementedError(node.name, "function not implemented yet!") elif node.op == 'call_module': if any(n in node.name for n in ['layernorm', 'norm']): - self.assign_layernorm_index(node, idx) + self._assign_layernorm_index(node, idx) else: raise NotImplementedError(node.name, "module not implemented yet!") elif node.op == 'get_attr': - self.assign_all_index(node, idx) # get param + self._assign_all_index(node, idx) # get param elif node.op == 'output': continue else: raise NotImplementedError(node.op, "op not implemented yet!") -def _get_meta_node_size(x): - x = x.meta['tensor_meta'] - x = x.numel * torch.tensor([], dtype=x.dtype).element_size() - return x +class MemoryEstimator(object): + def __init__(self) -> None: + pass + def _get_meta_node_size(self, x): + x = x.meta['tensor_meta'] + x = x.numel * torch.tensor([], dtype=x.dtype).element_size() + return x -def _get_output_node_size(n): - fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')} - return activation_size(fwd_out) + def _get_output_node_size(self, n): + fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')} + return activation_size(fwd_out) - -def _get_delete_node_size(user, user_to_last_uses): - if user.op in ('placeholder', 'output'): + def _get_delete_node_size(self, user, user_to_last_uses): + if user.op in ('placeholder', 'output'): + return 0 + nodes_to_delete = user_to_last_uses.get(user, []) + if len(nodes_to_delete): + delete_size = sum([self._get_output_node_size(i) for i in nodes_to_delete]) + return delete_size return 0 - nodes_to_delete = user_to_last_uses.get(user, []) - if len(nodes_to_delete): - delete_size = sum([_get_output_node_size(i) for i in nodes_to_delete]) - return delete_size - return 0 + def _get_last_usr(self, nodes): + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} -def _get_last_usr(nodes): - node_to_last_use: Dict[Node, Node] = {} - user_to_last_uses: Dict[Node, List[Node]] = {} + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) - def register_last_uses(n: Node, user: Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - user_to_last_uses.setdefault(user, []).append(n) + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + return user_to_last_uses - for node in reversed(nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - return user_to_last_uses + def _get_contiguous_memory(self, node, not_contiguous_list, delete=False): + mem = 0 + not_contiguous_ops = ['transpose', 'permute'] + if node.op == 'call_function' and 'matmul' in node.name: + for n in node.args: + if n in not_contiguous_list: + # matmul won't change origin tensor, but create a tmp copy + mem += self._get_output_node_size(n) + elif node.op == 'call_module': + for n in node.args: + if n in not_contiguous_list: + # module will just make origin tensor to contiguous + if delete: + not_contiguous_list.remove(n) + elif node.op == 'call_method' and any(i in node.name for i in not_contiguous_ops): + if node not in not_contiguous_list: + not_contiguous_list.append(node) + elif any(i in node.args for i in not_contiguous_list): + if node not in not_contiguous_list: + not_contiguous_list.append(node) -def _delete_free_var_from_last_use(user_to_last_uses): - for key, value in user_to_last_uses.items(): - for n in value: - if n.op == 'placeholder': - user_to_last_uses[key].remove(n) + return mem - -def _get_contiguous_memory(node, not_contiguous_list, delete=False): - mem = 0 - not_contiguous_ops = ['transpose', 'permute'] - - if node.op == 'call_function' and 'matmul' in node.name: - for n in node.args: - if n in not_contiguous_list: - # matmul won't change origin tensor, but create a tmp copy - mem += _get_output_node_size(n) - elif node.op == 'call_module': - for n in node.args: - if n in not_contiguous_list: - # module will just make origin tensor to contiguous - if delete: - not_contiguous_list.remove(n) - elif node.op == 'call_method' and any(i in node.name for i in not_contiguous_ops): - if node not in not_contiguous_list: - not_contiguous_list.append(node) - elif any(i in node.args for i in not_contiguous_list): - if node not in not_contiguous_list: - not_contiguous_list.append(node) - - return mem - - -def _estimate_inference_mem(gm: torch.fx.GraphModule): - act_memory = 0.0 - act_memory_peak_log = [] - act_memory_after_node_log = [] - not_contiguous_list = [] - user_to_last_uses = _get_last_usr(list(gm.graph.nodes)) - _delete_free_var_from_last_use(user_to_last_uses) - for node in gm.graph.nodes: - # if node is placeholder, just add the size of the node - if node.op == 'placeholder': - act_memory += _get_meta_node_size(node) / (1024 ** 2) - act_memory_peak_log.append(act_memory) - act_memory_after_node_log.append(act_memory) - # skip output - elif node.op == 'output': - continue - # node is an operation, calculate tmp, output node and delete node memory - else: - # forward memory - act_memory += _get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2) - act_memory += _get_output_node_size(node) / (1024 ** 2) - # record max act memory - act_memory_peak_log.append(act_memory) - # delete useless memory - act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2) - act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2) - act_memory_after_node_log.append(act_memory) - - print("no chunk") - _print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak") - _print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after") - - param_memory = parameter_size(gm) - return act_memory + param_memory, param_memory - - -def _get_chunk_ratio(node, chunk_dim, chunk_size): - shape = node.meta['tensor_meta'].shape - chunk_ratio = float(chunk_size) / shape[chunk_dim] - return chunk_ratio - - -def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node): - if user.op in ('placeholder', 'output'): - return 0 - nodes_to_delete = user_to_last_uses.get(user, []) - delete_size = 0 - for n in nodes_to_delete: - node_idx = _find_idx_by_name(n.name, node_list) - if start_node <= node_idx < end_node: - delete_size += _get_output_node_size(n) * chunk_ratio - return delete_size - - -def _print_mem_log(log, nodes, title=None): - if title: - print(title) - for idx, (l, n) in enumerate(zip(log, nodes)): - print("%s:%.2f \t" % (n.name, l), end='') - if (idx + 1) % 3 == 0: - print("") - print("\n") - - -def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes): - act_memory = 0.0 - act_memory_peak_log = [] - act_memory_after_node_log = [] - not_contiguous_list = [] - user_to_last_uses = _get_last_usr(list(gm.graph.nodes)) - _delete_free_var_from_last_use(user_to_last_uses) - within_chunk = False - region_idx = 0 - chunk_ratio = 1 # use it to estimate chunk mem - node_list = list(gm.graph.nodes) - - for idx, node in enumerate(node_list): - # if node in chunk start nodes, change chunk ratio and add chunk_tensor - if idx in start_nodes: - within_chunk = True - chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx]) - act_memory += _get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2) - - # if node is placeholder, just add the size of the node - if node.op == 'placeholder': - act_memory += _get_meta_node_size(node) * chunk_ratio / (1024 ** 2) - act_memory_peak_log.append(act_memory) - # skip output - elif node.op == 'output': - continue - # node is an operation, calculate tmp, output node and delete node memory - else: - # forward memory - # TODO: permute will create a tmp copy if not contiguous - act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2) - act_memory += _get_output_node_size(node) * chunk_ratio / (1024 ** 2) - # record max act memory - act_memory_peak_log.append(act_memory) - # delete useless memory - act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2) - if within_chunk: - act_memory -= _get_chunk_delete_node_size( - node, user_to_last_uses, chunk_ratio, node_list, - start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2) + def estimate_inference_mem(self, gm: torch.fx.GraphModule): + act_memory = 0.0 + act_memory_peak_log = [] + act_memory_after_node_log = [] + not_contiguous_list = [] + user_to_last_uses = self._get_last_usr(list(gm.graph.nodes)) + _delete_free_var_from_last_use(user_to_last_uses) + for node in gm.graph.nodes: + # if node is placeholder, just add the size of the node + if node.op == 'placeholder': + act_memory += self._get_meta_node_size(node) / (1024 ** 2) + act_memory_peak_log.append(act_memory) + act_memory_after_node_log.append(act_memory) + # skip output + elif node.op == 'output': + continue + # node is an operation, calculate tmp, output node and delete node memory else: - act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2) - - if idx in end_nodes: - act_memory -= _get_output_node_size(node) * chunk_ratio / (1024 ** 2) - within_chunk = False - chunk_ratio = 1 - region_idx += 1 + # forward memory + act_memory += self._get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2) + act_memory += self._get_output_node_size(node) / (1024 ** 2) + # record max act memory + act_memory_peak_log.append(act_memory) + # delete useless memory + act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2) + act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2) + act_memory_after_node_log.append(act_memory) + + print("no chunk") + self._print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak") + self._print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after") - act_memory_after_node_log.append(act_memory) + param_memory = parameter_size(gm) + return act_memory + param_memory, param_memory - print("chunk") - _print_mem_log(act_memory_peak_log, node_list, "peak") - _print_mem_log(act_memory_after_node_log, node_list, "after") - param_memory = parameter_size(gm) - return act_memory + param_memory, param_memory + def _get_chunk_ratio(self, node, chunk_dim, chunk_size): + shape = node.meta['tensor_meta'].shape + chunk_ratio = float(chunk_size) / shape[chunk_dim] + return chunk_ratio + + + def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node): + if user.op in ('placeholder', 'output'): + return 0 + nodes_to_delete = user_to_last_uses.get(user, []) + delete_size = 0 + for n in nodes_to_delete: + node_idx = _find_idx_by_name(n.name, node_list) + if start_node <= node_idx < end_node: + delete_size += self._get_output_node_size(n) * chunk_ratio + return delete_size + + + def _print_mem_log(self, log, nodes, title=None): + if title: + print(title) + for idx, (l, n) in enumerate(zip(log, nodes)): + print("%s:%.2f \t" % (n.name, l), end='') + if (idx + 1) % 3 == 0: + print("") + print("\n") + + + def estimate_chunk_inference_mem(self, gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes): + act_memory = 0.0 + act_memory_peak_log = [] + act_memory_after_node_log = [] + not_contiguous_list = [] + user_to_last_uses = self._get_last_usr(list(gm.graph.nodes)) + _delete_free_var_from_last_use(user_to_last_uses) + within_chunk = False + region_idx = 0 + chunk_ratio = 1 # use it to estimate chunk mem + node_list = list(gm.graph.nodes) + + for idx, node in enumerate(node_list): + # if node in chunk start nodes, change chunk ratio and add chunk_tensor + if idx in start_nodes: + within_chunk = True + chunk_ratio = self._get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx]) + act_memory += self._get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2) + + # if node is placeholder, just add the size of the node + if node.op == 'placeholder': + act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024 ** 2) + act_memory_peak_log.append(act_memory) + # skip output + elif node.op == 'output': + continue + # node is an operation, calculate tmp, output node and delete node memory + else: + # forward memory + # TODO: permute will create a tmp copy if not contiguous + act_memory += self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2) + act_memory += self._get_output_node_size(node) * chunk_ratio / (1024 ** 2) + # record max act memory + act_memory_peak_log.append(act_memory) + # delete useless memory + act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2) + if within_chunk: + act_memory -= self._get_chunk_delete_node_size( + node, user_to_last_uses, chunk_ratio, node_list, + start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2) + else: + act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2) + + if idx in end_nodes: + act_memory -= self._get_output_node_size(node) * chunk_ratio / (1024 ** 2) + within_chunk = False + chunk_ratio = 1 + region_idx += 1 + + act_memory_after_node_log.append(act_memory) + + print("chunk") + self._print_mem_log(act_memory_peak_log, node_list, "peak") + self._print_mem_log(act_memory_after_node_log, node_list, "after") + + param_memory = parameter_size(gm) + return act_memory + param_memory, param_memory def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): @@ -695,8 +694,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v within_chunk_region = False node_list = list(nodes) - _estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2]) - _estimate_inference_mem(meta_graph) + memory_estimator = MemoryEstimator() + memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2]) + memory_estimator.estimate_inference_mem(meta_graph) node_index_tracer = NodeIndexTracer(meta_graph) node_index_tracer.trace_node_idx()