diff --git a/chunk_codegen.py b/chunk_codegen.py
index 9147aa9fc..191eab564 100644
--- a/chunk_codegen.py
+++ b/chunk_codegen.py
@@ -134,7 +134,7 @@ class FlowTracer(object):
             return name, i
         raise RuntimeError("invalid node")
 
-    def get_flow_mix(self, node):
+    def _get_flow_mix_node(self, node):
         if self._is_non_compute_node(node):
             return None
         _, node_trace = self.find_node_flow(node)
@@ -145,7 +145,7 @@ class FlowTracer(object):
         vars = list(node_trace["outside_depend"][0].values())[0]
         return vars
 
-    def get_same_flow_node(self, node_list, node):
+    def _get_same_flow_node(self, node_list, node):
         name, _ = self.find_node_flow(node)
         result = []
         for i in self.flow_trace[name]:
@@ -181,13 +181,14 @@ class FlowTracer(object):
         )
         return self.flow_trace
 
-    def _detect_flow(self, start_idx, start_dim, end_idx, end_dim):
-        inputs, outputs = _find_chunk_input_and_output_nodes(
+    def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer):
+        inputs, outputs = _find_chunk_compute_input_and_output_nodes(
             self.node_list[start_idx : end_idx + 1]
         )
         chunk_info = {
             "region": (start_idx, end_idx),
             "inputs": inputs,
+            "inputs_non_chunk": [],
             "inputs_dim": start_dim,
             "outputs": outputs,
             "outputs_dim": end_dim,
@@ -197,31 +198,71 @@ class FlowTracer(object):
         for idx in range(start_idx, end_idx + 1):
             node = self.node_list[idx]
-            mix_flow_var = self.get_flow_mix(node)
-            if mix_flow_var is None:
+            mix_flow_node = self._get_flow_mix_node(node)
+            if mix_flow_node is None:
                 continue
-            # if there is a flow mix, op must be in [mul, add, div, matmul]
+            # if there is a flow mix, op must be in [mul, add, matmul]
             # element-wise op requires dim to be equal in every dim
             if any(n in node.name for n in ["mul", "add"]):
                 for i in node.args:
-                    if type(i) == type(mix_flow_var) and i != mix_flow_var:
+                    if type(i) == type(mix_flow_node) and i != mix_flow_node:
                         main_flow_var = i
                 # if mix flow is a broadcast in chunk dim,
                 # TODO need to move that flow out of the chunk
-                if mix_flow_var.meta["tensor_meta"].shape[dim_idx] == 1:
+                mix_flow_node_dim = index_tracer._get_node_chunk_dim(
+                    self.node_list[end_idx], end_dim, node
+                )
+                if mix_flow_node_dim is None:
                     flow_flag = True
-                    for i in self.get_same_flow_node(
-                        chunk_info["inputs"], mix_flow_var
+                    break
+                if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1:
+                    flow_flag = False
+                    for i in self._get_same_flow_node(
+                        chunk_info["inputs"], mix_flow_node
                     ):
                         chunk_info["inputs"].remove(i)
                 # else, we need to chunk mix var as well
                 else:
                     # TODO chunk another value
-                    flow_flag = False
+                    flow_flag = True
                     break
             else:
                 raise NotImplementedError("%s not implemented" % node.name)
+
+        inputs_dim = []
+        remove_inputs = []
+        for input_node in chunk_info["inputs"]:
+            input_dict = {}
+            for user in input_node.users.keys():
+                if _is_non_compute_node(user):
+                    continue
+                user_idx = _find_idx_by_name(user.name, self.node_list)
+                dim = None
+                if start_idx <= user_idx < end_idx:
+                    dim = index_tracer._get_node_chunk_dim(
+                        self.node_list[end_idx], end_dim, input_node
+                    )
+                elif user_idx == end_idx:
+                    dim = end_dim
+                # the user's shape depends on this input's chunk dim
+                if dim is not None and _get_node_shape(user)[dim] != 1:
+                    input_dict[user_idx] = dim
+            if len(input_dict) == 0:
+                remove_inputs.append(input_node)
+            else:
+                inputs_dim.append(input_dict)
+        chunk_info["inputs_dim"] = inputs_dim
+        for i in remove_inputs:
+            if i in chunk_info["inputs"]:
+                chunk_info["inputs"].remove(i)
+
+        # we need to log input nodes to avoid deleting them in the loop
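+        # e.g. (illustrative) a weight or bias produced before the region is
+        # read inside it but never chunked: it is excluded from
+        # delete_unused_value_func inside the generated loop (via
+        # chunk_inputs_names) and released in _gen_loop_end instead.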
+        non_chunk_inputs = _find_chunk_all_input_nodes(
+            self.node_list[start_idx : end_idx + 1]
+        )
+        for i in non_chunk_inputs:
+            if i not in chunk_info["inputs"]:
+                chunk_info["inputs_non_chunk"].append(i)
+
         return flow_flag, chunk_info
 
@@ -367,6 +408,20 @@ class IndexTracer(object):
         node_dict = self.idx_trace_list[node_idx]
         return node_dict
 
+    def _find_source_trace_from_node(self, node):
+        """
+        Find node source trace by the node.
+
+        Args:
+            node (node)
+        Returns:
+            source (dict): the source trace of the node.
+        """
+        node_idx = _find_idx_by_name(node.name, self.nodes_list)
+        node_dict = self.idx_trace_list[node_idx]
+        return node_dict["source"]
+
     def _find_idx_trace_from_node(self, node):
         """
         Find node idx trace by the node.
@@ -836,6 +891,15 @@ class IndexTracer(object):
     #         return False
     #     return True
 
+    def _get_node_chunk_dim(self, node_from, node_from_dim, node_to):
+        node_from_source = self._find_source_trace_from_node(node_from)
+        dim_source = node_from_source[node_from_dim]
+        node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list)
+        for k, v in dim_source.items():
+            if k == node_to_idx:
+                return v
+        return None
+
 
 class MemoryEstimator(object):
     def __init__(self) -> None:
@@ -931,8 +995,10 @@ class MemoryEstimator(object):
         return mem
 
     def _get_chunk_ratio(self, node, chunk_dim, chunk_size):
+        sorted_dim = sorted(chunk_dim, key=lambda x: list(x.keys())[0])
+        dim = list(sorted_dim[-1].values())[0]
         shape = node.meta["tensor_meta"].shape
-        chunk_ratio = float(chunk_size) / shape[chunk_dim]
+        chunk_ratio = float(chunk_size) / shape[dim]
         return chunk_ratio
 
     def _get_chunk_delete_node_size(
@@ -1188,7 +1256,7 @@ class ChunkRegionSearch(object):
                 continue
             # detect flow meet
             flow_flag, chunk_info = self.flow_tracer._detect_flow(
-                start_idx, start_dim, end_idx, end_dim
+                start_idx, start_dim, end_idx, end_dim, self.index_tracer
             )
             if flow_flag:
                 continue
@@ -1301,56 +1369,53 @@ def _get_first_non_single_dim(shape):
     raise RuntimeError("can not get first non single dim for shape", shape)
 
 
-def _gen_loop_start(chunk_input_meta, chunk_output, chunk_dim, chunk_size=2):
-    if len(chunk_input_meta) == 1:
-        node = chunk_input_meta[0]
-        node_shape = node.meta["tensor_meta"].shape
-        free_shape = [
-            node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape))
-        ]
-        chunk_dim = _get_first_non_single_dim(free_shape)
-        chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
-        out_shape = str(list(chunk_output.meta["tensor_meta"].shape))
+def _gen_loop_start(chunk_input, chunk_output, chunk_output_dim, chunk_size=2):
+    input_node = chunk_input[0]
+
+    out_shape = _get_node_shape(chunk_output)
+    out_str = str(list(out_shape))
 
-        context = (
-            "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range"
-            % (out_shape, node.name, node.name, chunk_size)
-        )
-        context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
-        context += "    chunk_tensor = %s%s\n" % (node.name, chunk_slice)
-    else:
-        raise NotImplementedError(
-            "input with size %d not implemented" % len(chunk_input_meta)
-        )
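+    # e.g. (illustrative) for out_str "[4096, 4096]", an input named "x",
+    # chunk_output_dim 0 and chunk_size 2, the emitted prologue reads:
+    #   chunk_result = torch.empty([4096, 4096], dtype=x.dtype, device=x.device); chunk_size = 2
+    #   for chunk_idx in range(0, 4096, chunk_size):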
+    context = (
+        "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range"
+        % (out_str, input_node.name, input_node.name, chunk_size)
+    )
+    context += "(0, %d, chunk_size):\n" % (out_shape[chunk_output_dim])
     return context
 
 
-def _gen_loop_end(chunk_outputs, chunk_inputs, node_list, chunk_dim):
-    chunk_inputs_name = chunk_inputs[0].name
+def _gen_loop_end(
+    chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list
+):
     chunk_outputs_name = chunk_outputs.name
     chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list)
     chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
-    free_shape = [
-        chunk_output_shape[i] if i in chunk_dim else 1
-        for i in range(len(chunk_output_shape))
-    ]
-    chunk_dim = _get_first_non_single_dim(free_shape)
-    chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", chunk_output_shape)
+    chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
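+    # e.g. (illustrative) for an output named "matmul_3" chunked along dim 0,
+    # the emitted epilogue reads roughly:
+    #       chunk_result[chunk_idx:chunk_idx + chunk_size] = matmul_3
+    #   matmul_3 = chunk_result; chunk_result = None; chunk_size = None
+    # (the exact slice text comes from _gen_chunk_slice_dim)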
     context = "    chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name)
-
-    context += (
-        chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
-    )
+    context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")
 
-    # determine if its the last use for chunk input
-    users_name = list(chunk_inputs[0].users.keys())
-    if all(
-        [
-            _find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
-            for user in users_name
-        ]
-    ):
-        context += "; %s = None" % chunk_inputs_name
+    # determine if it's the last use of each chunk input
+    for chunk_input in (chunk_inputs + chunk_non_compute_inputs):
+        if all(
+            [
+                _find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
+                for user in chunk_input.users.keys()
+            ]
+        ):
+            context += "; %s = None" % chunk_input.name
 
     context += "\n"
     return context
@@ -1382,7 +1447,24 @@ def _find_input_and_output_nodes(nodes: List[Node]):
     return input_nodes, output_nodes
 
 
-def _find_chunk_input_and_output_nodes(nodes: List[Node]):
+def _find_chunk_all_input_nodes(nodes: List[Node]):
+    """
+    Find all input nodes of the node list:
+    nodes outside the list that are used by nodes inside it.
+    """
+    input_nodes = []
+    for node in nodes:
+        for input_node in node._input_nodes.keys():
+            if input_node not in nodes and input_node not in input_nodes:
+                input_nodes.append(input_node)
+    return input_nodes
+
+
+def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
     """
-    Find non-compute input and output node names.
+    Find compute input and output nodes of the node list,
+    ignoring non-compute nodes (except placeholders).
     input nodes are nodes used in the list
@@ -1410,7 +1492,7 @@
         if (
             output_node not in nodes
             and node not in output_nodes
-            and not _is_non_compute_node_except_placeholder(input_node)
+            and not _is_non_compute_node_except_placeholder(output_node)
         ):
             output_nodes.append(node)
 
@@ -1454,44 +1536,34 @@ def emit_code_with_chunk(
         emit_node_func: function to emit node
         delete_unused_value_func: function to remove the unused value
     """
-
-    # find the offload regions
-    chunk_region_search = ChunkRegionSearch(meta_graph)
-    chunk_search = chunk_region_search.search_region()
-    chunk_regions = [i["region"] for i in chunk_search]
-    chunk_dims = [i["dim"] for i in chunk_search]
-    chunk_infos = [i["chunk_info"] for i in chunk_search]
-
-    chunk_starts = [item[0] for item in chunk_regions]
-    chunk_ends = [item[1] for item in chunk_regions]
-    chunk_inputs = [[j["inputs"][0] for j in i] for i in chunk_infos]
-    chunk_outputs = [[j["outputs"][0] for j in i] for i in chunk_infos]
-    within_chunk_region = False
-
     node_list = list(nodes)
 
-    # find the input and output var names for each offload region
-    # for idx, (start, end) in enumerate(chunk_regions):
-    #     offload_node_list = node_list[start:end + 1]
-    #     inputs, outputs = _find_input_and_output_nodes(offload_node_list)
-    #     chunk_inputs.append(inputs)
-    #     chunk_outputs.append(outputs)
+    # find the chunk regions
+    chunk_region_search = ChunkRegionSearch(meta_graph)
+    chunk_search = chunk_region_search.search_region()
+    chunk_regions = [i["region"] for i in chunk_search]
+    chunk_starts = [i[0] for i in chunk_regions]
+    chunk_ends = [i[1] for i in chunk_regions]
+
+    chunk_inputs = [i["inputs"] for i in chunk_search]
+    chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_search]
+    chunk_inputs_dim = [i["inputs_dim"] for i in chunk_search]
     chunk_inputs_idx = [
         [_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs
     ]
-    chunk_outputs_idx = [
-        [_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs
-    ]
-    chunk_inputs_names = []
-    for i in chunk_inputs:
-        for j in i:
-            chunk_inputs_names.append(j.name)
+    chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
+        j.name for i in chunk_inputs_non_chunk for j in i
+    ]
+
+    chunk_outputs = [i["outputs"][0] for i in chunk_search]
+    chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search]
+    chunk_outputs_idx = [
+        _find_idx_by_name(i.name, node_list) for i in chunk_outputs
+    ]
 
-    # this flag is to prevent repeated insert of save tensors
-    # hooks definition in ckpt_func
     node_idx = 0
     region_idx = 0
+    within_chunk_region = False
+
     while node_idx < len(node_list):
         node = node_list[node_idx]
 
@@ -1500,21 +1572,24 @@
            region_idx = chunk_starts.index(node_idx)
 
            # add for loop
-            chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]]
            body.append(
                _gen_loop_start(
-                    chunk_input_meta,
-                    node_list[chunk_ends[region_idx]],
-                    chunk_dims[region_idx],
+                    chunk_inputs[region_idx],
+                    chunk_outputs[region_idx],
+                    chunk_outputs_dim[region_idx],
                )
            )
 
        if within_chunk_region:
            emit_node_func(node, body)
            # replace input var with chunk var
-            body[-1] = _replace_name(
-                body[-1], chunk_inputs[region_idx][0].name, "chunk_tensor"
-            )
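+            # e.g. (illustrative) an emitted line "add_1 = torch.add(x, bias)"
+            # becomes "add_1 = torch.add(x[chunk_idx:chunk_idx + chunk_size], bias)"
+            # when x is chunked along dim 0; the exact slice text is produced
+            # by _gen_chunk_slice_dim, so its form may differ per dim.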
+            for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
+                for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
+                    if idx == node_idx:
+                        chunk_slice = _gen_chunk_slice_dim(
+                            dim, "chunk_idx", _get_node_shape(input_node)
+                        )
+                        body[-1] = _replace_name(
+                            body[-1], input_node.name, input_node.name + chunk_slice
+                        )
            body[-1] = "    " + body[-1]
            delete_unused_value_func(node, body, chunk_inputs_names)
 
@@ -1526,7 +1601,10 @@
        if node_idx in chunk_ends:
            body.append(
                _gen_loop_end(
-                    node, chunk_inputs[region_idx], node_list, chunk_dims[region_idx]
+                    chunk_inputs[region_idx],
+                    chunk_inputs_non_chunk[region_idx],
+                    chunk_outputs[region_idx],
+                    chunk_outputs_dim[region_idx],
+                    node_list,
                )
            )
            within_chunk_region = False