diff --git a/chunk_codegen.py b/chunk_codegen.py index 191eab564..3bea84fae 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -229,10 +229,10 @@ class FlowTracer(object): break else: raise NotImplementedError("%s not implemented" % node.name) - + inputs_dim = [] remove_inputs = [] - for input_node in chunk_info['inputs']: + for input_node in chunk_info["inputs"]: input_dict = {} for user in input_node.users.keys(): if _is_non_compute_node(user): @@ -252,15 +252,17 @@ class FlowTracer(object): remove_inputs.append(input_node) else: inputs_dim.append(input_dict) - chunk_info['inputs_dim'] = inputs_dim + chunk_info["inputs_dim"] = inputs_dim for i in remove_inputs: - if i in chunk_info['inputs']: - chunk_info['inputs'].remove(i) - + if i in chunk_info["inputs"]: + chunk_info["inputs"].remove(i) + # we need to log input nodes to avoid deleteing them in the loop - non_chunk_inputs = _find_chunk_all_input_nodes(self.node_list[start_idx : end_idx + 1]) + non_chunk_inputs = _find_chunk_all_input_nodes( + self.node_list[start_idx : end_idx + 1] + ) for i in non_chunk_inputs: - if i not in chunk_info['inputs']: + if i not in chunk_info["inputs"]: chunk_info["inputs_non_chunk"].append(i) return flow_flag, chunk_info @@ -1371,44 +1373,32 @@ def _get_first_non_single_dim(shape): def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2): input_node = chunk_input[0] - out_shape = _get_node_shape(chunk_output) out_str = str(list(out_shape)) - context = ( "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" % (out_str, input_node.name, input_node.name, chunk_size) ) context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim]) - - # node = chunk_input[0] - # node_shape = node.meta["tensor_meta"].shape - # free_shape = [ - # node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape)) - # ] - # chunk_dim = _get_first_non_single_dim(free_shape) - # chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape) - # out_shape = str(list(chunk_output.meta["tensor_meta"].shape)) - - # context = ( - # "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range" - # % (out_shape, node.name, node.name, chunk_size) - # ) - # context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim) - # context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice) return context -def _gen_loop_end(chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list): +def _gen_loop_end( + chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list +): chunk_outputs_name = chunk_outputs.name chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list) chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape - chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape) + chunk_slice = _gen_chunk_slice_dim( + chunk_outputs_dim, "chunk_idx", chunk_output_shape + ) context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name) - context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None") + context += ( + chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None" + ) # determine if its the last use for chunk input - for chunk_input in (chunk_inputs + chunk_non_compute_inputs): + for chunk_input in chunk_inputs + chunk_non_compute_inputs: if all( [ _find_idx_by_name(user.name, node_list) <= chunk_outputs_idx @@ -1456,10 +1446,7 @@ def _find_chunk_all_input_nodes(nodes: List[Node]): input_nodes = [] for node in nodes: for input_node in node._input_nodes.keys(): - if ( - input_node not in nodes - and input_node not in input_nodes - ): + if input_node not in nodes and input_node not in input_nodes: input_nodes.append(input_node) return input_nodes @@ -1549,16 +1536,12 @@ def emit_code_with_chunk( chunk_inputs = [i["inputs"] for i in chunk_search] chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_search] chunk_inputs_dim = [i["inputs_dim"] for i in chunk_search] - chunk_inputs_idx = [ - [_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs + chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ + j.name for i in chunk_inputs_non_chunk for j in i ] - chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i] chunk_outputs = [i["outputs"][0] for i in chunk_search] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search] - chunk_outputs_idx = [ - _find_idx_by_name(i.name, node_list) for i in chunk_outputs - ] node_idx = 0 region_idx = 0 @@ -1586,7 +1569,9 @@ def emit_code_with_chunk( for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): if idx == node_idx: - chunk_slice = _gen_chunk_slice_dim(dim, "chunk_idx", _get_node_shape(input_node)) + chunk_slice = _gen_chunk_slice_dim( + dim, "chunk_idx", _get_node_shape(input_node) + ) body[-1] = _replace_name( body[-1], input_node.name, input_node.name + chunk_slice ) @@ -1604,7 +1589,8 @@ def emit_code_with_chunk( chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], chunk_outputs[region_idx], - chunk_outputs_dim[region_idx], node_list + chunk_outputs_dim[region_idx], + node_list, ) ) within_chunk_region = False