diff --git a/chunk_codegen.py b/chunk_codegen.py index 79cefddf0..18d9a0c8d 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -896,23 +896,22 @@ class IndexTracer(object): def _find_inherit_dim(self, input_node, input_dim, node): input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list) - node_idx = _find_idx_by_name(node.name, self.nodes_list) node_trace_source = self._find_source_trace_from_node(node) for node_dim in range(len(_get_node_shape(node))): if ( input_node_idx in node_trace_source[node_dim] and node_trace_source[node_dim][input_node_idx] == input_dim ): - return {node_idx: node_dim} - return {} + return node_dim + return None def check_index_duplicate(self, chunk_infos): input_dim_after_node = {} for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): - input_dim_after_node.update( - self._find_inherit_dim(input_node, v, self.nodes_list[k]) - ) + inherit_dim = self._find_inherit_dim(input_node, v, self.nodes_list[k]) + if inherit_dim: + input_dim_after_node[k] = inherit_dim for node in self.nodes_list[ chunk_infos["region"][0] : chunk_infos["region"][1] + 1 @@ -934,8 +933,8 @@ class IndexTracer(object): class MemoryEstimator(object): - def __init__(self) -> None: - pass + def __init__(self, index_tracer: IndexTracer) -> None: + self.index_tracer = index_tracer def _get_meta_node_size(self, x): x = x.meta["tensor_meta"] @@ -950,6 +949,8 @@ class MemoryEstimator(object): } out_size = activation_size(fwd_out) out_node = [n.name] if out_size > 0 else [] + # if any(i in n.name for i in ['transpose', 'permute', 'view']): + # out_size = 0 return out_size, out_node def _get_output_node_size(self, n): @@ -961,11 +962,19 @@ class MemoryEstimator(object): if i not in active_list: active_list.append(i) - def _get_delete_node(self, user, user_to_last_uses): + def _get_delete_node(self, user, user_to_last_uses, to_keep=None): delete_size = 0 delete_node = [] if user.op not in ("placeholder", "output"): nodes_to_delete = user_to_last_uses.get(user, []) + if to_keep is not None: + keep_list = [] + for n in nodes_to_delete: + if n.name in to_keep: + keep_list.append(n) + for n in keep_list: + if n in nodes_to_delete: + nodes_to_delete.remove(n) if len(nodes_to_delete): out_node = [self._get_output_node(i) for i in nodes_to_delete] delete_size = sum([i[0] for i in out_node]) @@ -974,15 +983,30 @@ class MemoryEstimator(object): delete_node.append(out_node[i][1][0]) elif nodes_to_delete[i].op == "placeholder": delete_node.append(nodes_to_delete[i].name) + # elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']): + # delete_node.append(nodes_to_delete[i].name) return delete_size, delete_node - def _get_delete_node_size(self, user, user_to_last_uses): - return self._get_delete_node(user, user_to_last_uses)[0] + def _get_delete_node_size(self, user, user_to_last_uses, to_keep): + return self._get_delete_node(user, user_to_last_uses, to_keep)[0] def _remove_deactive_node(self, user, user_to_last_uses, active_list): delete_node = self._get_delete_node(user, user_to_last_uses)[1] for i in delete_node: - active_list.remove(i) + if i in active_list: + active_list.remove(i) + + def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx): + nodes_to_delete = [] + for chunk_input in chunk_inputs + chunk_inputs_non_chunk: + chunk_input_users = chunk_input.users.keys() + chunk_input_users_idx = [_find_idx_by_name(i.name, node_list) for i in chunk_input_users] + if all(i <= chunk_end_idx for i in chunk_input_users_idx): + if chunk_input not in nodes_to_delete: + nodes_to_delete.append(chunk_input) + out_node = [self._get_output_node(i) for i in nodes_to_delete] + delete_size = sum([i[0] for i in out_node]) + return delete_size def _get_last_usr(self, nodes): node_to_last_use: Dict[Node, Node] = {} @@ -1000,7 +1024,8 @@ class MemoryEstimator(object): def _get_contiguous_memory(self, node, not_contiguous_list, delete=False): mem = 0 - not_contiguous_ops = ["transpose", "permute"] + not_contiguous_ops = ["permute"] + inherit_contiguous_ops = ["transpose", "view"] if node.op == "call_function" and any( n in node.name for n in ["matmul", "reshape"] @@ -1020,30 +1045,36 @@ class MemoryEstimator(object): ): if node not in not_contiguous_list: not_contiguous_list.append(node) - elif any(i in node.args for i in not_contiguous_list): - if node not in not_contiguous_list: - not_contiguous_list.append(node) - return mem - def _get_chunk_ratio(self, node, chunk_dim, chunk_size): - sorted_dim = sorted(chunk_dim, key=lambda x: list(x.keys())[0]) - dim = list(sorted_dim[-1].values())[0] - shape = node.meta["tensor_meta"].shape - chunk_ratio = float(chunk_size) / shape[dim] - return chunk_ratio + def _get_chunk_ratio(self, node, chunk_inputs, chunk_inputs_dim, chunk_size): + node_shape = _get_node_shape(node) + node_source = self.index_tracer._find_source_trace_from_node(node) + for (input_node, input_node_dim) in zip(chunk_inputs, chunk_inputs_dim): + for k, v in input_node_dim.items(): + inherit_dim = self.index_tracer._find_inherit_dim(input_node, v, self.index_tracer.nodes_list[k]) + if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list): + chunk_ratio = float(chunk_size) / node_shape[inherit_dim] + return chunk_ratio + for dim, source in enumerate(node_source): + if k in source and source[k] == inherit_dim: + chunk_ratio = float(chunk_size) / node_shape[dim] + return chunk_ratio + return 1. def _get_chunk_delete_node_size( - self, user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node + self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names ): + # if any(j in user.name for j in ['transpose', 'permute', 'view']): + # return 0 if user.op in ("placeholder", "output"): return 0 nodes_to_delete = user_to_last_uses.get(user, []) delete_size = 0 for n in nodes_to_delete: - node_idx = _find_idx_by_name(n.name, node_list) - if start_node <= node_idx < end_node: - delete_size += self._get_output_node_size(n) * chunk_ratio + if n.name in chunk_inputs_names: + continue + delete_size += self._get_output_node_size(n) * chunk_ratio return delete_size def _print_mem_log(self, log, nodes, title=None): @@ -1071,10 +1102,7 @@ class MemoryEstimator(object): def estimate_chunk_inference_mem( self, gm: torch.fx.GraphModule, - start_nodes=None, - end_nodes=None, - chunk_dims=None, - chunk_sizes=None, + chunk_infos=None, ): act_memory = 0.0 act_memory_peak_log = [] @@ -1087,36 +1115,53 @@ class MemoryEstimator(object): user_to_last_uses_no_free_var = self._get_last_usr(node_list) _delete_free_var_from_last_use(user_to_last_uses_no_free_var) - use_chunk = all( - i is not None for i in [start_nodes, end_nodes, chunk_dims, chunk_sizes] - ) + use_chunk = True if chunk_infos is not None else False chunk_within = False chunk_region_idx = None chunk_ratio = 1 # use it to estimate chunk mem + chunk_size = 1 + chunk_inputs_names = [] + + if use_chunk: + chunk_regions = [i["region"] for i in chunk_infos] + chunk_starts = [i[0] for i in chunk_regions] + chunk_ends = [i[1] for i in chunk_regions] + chunk_inputs = [i["inputs"] for i in chunk_infos] + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] + chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] + chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ + j.name for i in chunk_inputs_non_chunk for j in i + ] + chunk_outputs = [i["outputs"][0] for i in chunk_infos] for idx, node in enumerate(node_list): # if node in chunk start nodes, change chunk ratio and add chunk_tensor - if use_chunk and idx in start_nodes: + if use_chunk and idx in chunk_starts: chunk_within = True - chunk_region_idx = start_nodes.index(idx) + chunk_region_idx = chunk_starts.index(idx) + act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2) + + # determine chunk ratio for current node + if chunk_within: chunk_ratio = self._get_chunk_ratio( - node, chunk_dims[chunk_region_idx], chunk_sizes[chunk_region_idx] + node, chunk_inputs[chunk_region_idx], chunk_inputs_dim[chunk_region_idx], chunk_size ) - act_memory += self._get_output_node_size( - node_list[end_nodes[chunk_region_idx]] - ) / (1024**2) # if node is placeholder, just add the size of the node if node.op == "placeholder": act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2) act_memory_peak_log.append(act_memory) - active_node_list.append(node.name) # skip output elif node.op == "output": continue - # node is an operation, calculate tmp, output node and delete node memory + # no change for non compute node + elif _is_non_compute_node_except_placeholder(node): + act_memory_peak_log.append(act_memory) + # node is a compute op + # calculate tmp, output node and delete node memory else: # forward memory + # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose act_memory += ( self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio @@ -1133,29 +1178,35 @@ class MemoryEstimator(object): * chunk_ratio / (1024**2) ) + # delete unused vars not in chunk_input_list + # we can't delete input nodes until chunk ends if chunk_within: act_memory -= self._get_chunk_delete_node_size( node, user_to_last_uses_no_free_var, chunk_ratio, - node_list, - start_nodes[chunk_region_idx], - end_nodes[chunk_region_idx], + chunk_inputs_names ) / (1024**2) else: - act_memory -= self._get_delete_node_size( - node, user_to_last_uses_no_free_var - ) / (1024**2) + act_memory -= (self._get_delete_node_size( + node, user_to_last_uses_no_free_var, chunk_inputs_names + ) / (1024**2)) - # log active node + # log active node, only effective without chunk self._add_active_node(node, active_node_list) self._remove_deactive_node(node, user_to_last_uses, active_node_list) # if node in chunk end nodes, restore chunk settings - if use_chunk and idx in end_nodes: + if use_chunk and idx in chunk_ends: act_memory -= ( self._get_output_node_size(node) * chunk_ratio / (1024**2) ) + act_memory -= self._get_chunk_inputs_size( + chunk_inputs[chunk_region_idx], + chunk_inputs_non_chunk[chunk_region_idx], + node_list, + chunk_regions[chunk_region_idx][1] + ) / (1024**2) chunk_within = False chunk_ratio = 1 chunk_region_idx = None @@ -1178,11 +1229,11 @@ class ChunkRegionSearch(object): def __init__(self, gm) -> None: self.gm = gm self.node_list = list(gm.graph.nodes) - self.memory_estimator = MemoryEstimator() self.index_tracer = IndexTracer(gm) self.index_tracer.trace_index() self.flow_tracer = FlowTracer(gm) self.flow_tracer.trace_flow() + self.memory_estimator = MemoryEstimator(self.index_tracer) def _find_peak_node(self, mem_peak): max_value = max(mem_peak) @@ -1210,7 +1261,7 @@ class ChunkRegionSearch(object): min_var = self._get_min_free_var(active_node, free_vars) # from peak_node to free_var - chunk_region_start = None + chunk_region_start = len(free_vars) for i in range(peak_node, -1, -1): if len(active_node[i]) == min_var: chunk_region_start = i + 1 @@ -1218,7 +1269,7 @@ class ChunkRegionSearch(object): if i in free_vars or i == 0: raise RuntimeError() # from peak_node to len-2 - chunk_region_end = None + chunk_region_end = len(active_node) - 1 for i in range(peak_node, len(active_node)): if len(active_node[i]) == min_var: chunk_region_end = i @@ -1352,7 +1403,7 @@ class ChunkRegionSearch(object): return False def search_region(self): - chunk_regions = [] + chunk_infos = [] ( init_mem_peak, _, @@ -1361,25 +1412,19 @@ class ChunkRegionSearch(object): mem_peak = init_mem_peak while True: - chunk_region = self._step_search(mem_peak, active_node, chunk_regions) - if chunk_region is None: + chunk_info = self._step_search(mem_peak, active_node, chunk_infos) + if chunk_info is None: break - chunk_regions.append(chunk_region) + chunk_infos.append(chunk_info) ( mem_peak, _, active_node, - ) = self.memory_estimator.estimate_chunk_inference_mem( - self.gm, - [i["region"][0] for i in chunk_regions], - [i["region"][1] for i in chunk_regions], - [i["inputs_dim"] for i in chunk_regions], - [1] * len(chunk_regions), - ) + ) = self.memory_estimator.estimate_chunk_inference_mem(self.gm, chunk_infos) if self._stop_search(init_mem_peak, mem_peak): break - return chunk_regions + return chunk_infos def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): @@ -1415,7 +1460,7 @@ def _gen_loop_end( chunk_slice = _gen_chunk_slice_dim( chunk_outputs_dim, "chunk_idx", chunk_output_shape ) - context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name) + context = " chunk_result%s = %s; %s = None\n" % (chunk_slice, chunk_outputs_name, chunk_outputs_name) context += ( chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None" )