mirror of https://github.com/hpcaitech/ColossalAI
code style
parent
31a2c5d09f
commit
b7b67c32ad
@@ -232,7 +232,7 @@ class FlowTracer(object):
         inputs_dim = []
         remove_inputs = []
-        for input_node in chunk_info['inputs']:
+        for input_node in chunk_info["inputs"]:
             input_dict = {}
             for user in input_node.users.keys():
                 if _is_non_compute_node(user):
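For orientation: the `chunk_info` dict these FlowTracer hunks fill in is not defined anywhere in this diff. The sketch below is an assumption pieced together from the keys the commit touches ("inputs", "inputs_dim", "inputs_non_chunk", "outputs", "outputs_dim"); the value types are inferred, not taken from the source.

    # hypothetical shape of chunk_info, inferred from the keys used in this commit
    chunk_info = {
        "inputs": [],            # fx.Node inputs that get sliced per chunk
        "inputs_dim": [],        # per-input {node_idx: dim} saying where to slice
        "inputs_non_chunk": [],  # inputs used whole inside the chunk loop
        "outputs": [],           # nodes whose results are written into chunk_result
        "outputs_dim": 0,        # dim along which chunk_result is assembled
    }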
@@ -252,15 +252,17 @@ class FlowTracer(object):
                 remove_inputs.append(input_node)
             else:
                 inputs_dim.append(input_dict)
-        chunk_info['inputs_dim'] = inputs_dim
+        chunk_info["inputs_dim"] = inputs_dim
         for i in remove_inputs:
-            if i in chunk_info['inputs']:
-                chunk_info['inputs'].remove(i)
+            if i in chunk_info["inputs"]:
+                chunk_info["inputs"].remove(i)

         # we need to log input nodes to avoid deleting them in the loop
-        non_chunk_inputs = _find_chunk_all_input_nodes(self.node_list[start_idx : end_idx + 1])
+        non_chunk_inputs = _find_chunk_all_input_nodes(
+            self.node_list[start_idx : end_idx + 1]
+        )
         for i in non_chunk_inputs:
-            if i not in chunk_info['inputs']:
+            if i not in chunk_info["inputs"]:
                 chunk_info["inputs_non_chunk"].append(i)

         return flow_flag, chunk_info
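The hunk above walks `input_node.users.keys()`. In torch.fx, `Node.users` is a dict mapping each consumer node to None, so iterating its keys enumerates the consumers of a node's output. A minimal runnable illustration (the module here is a made-up example):

    import torch
    from torch import fx

    class Tiny(torch.nn.Module):
        def forward(self, x):
            y = x + 1
            return y * 2

    traced = fx.symbolic_trace(Tiny())
    for node in traced.graph.nodes:
        # node.users.keys() yields the nodes consuming this node's output,
        # exactly what the loop over input_node.users.keys() relies on
        print(node.name, "->", [u.name for u in node.users.keys()])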
@@ -1371,44 +1373,32 @@ def _get_first_non_single_dim(shape):
 def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2):
     input_node = chunk_input[0]

     out_shape = _get_node_shape(chunk_output)
     out_str = str(list(out_shape))

     context = (
         "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range"
         % (out_str, input_node.name, input_node.name, chunk_size)
     )
     context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])

-    # node = chunk_input[0]
-    # node_shape = node.meta["tensor_meta"].shape
-    # free_shape = [
-    #     node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape))
-    # ]
-    # chunk_dim = _get_first_non_single_dim(free_shape)
-    # chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
-    # out_shape = str(list(chunk_output.meta["tensor_meta"].shape))
-
-    # context = (
-    #     "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range"
-    #     % (out_shape, node.name, node.name, chunk_size)
-    # )
-    # context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
-    # context += "    chunk_tensor = %s%s\n" % (node.name, chunk_slice)
     return context


-def _gen_loop_end(chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list):
+def _gen_loop_end(
+    chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list
+):
     chunk_outputs_name = chunk_outputs.name
     chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list)
     chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
-    chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
+    chunk_slice = _gen_chunk_slice_dim(
+        chunk_outputs_dim, "chunk_idx", chunk_output_shape
+    )
     context = "    chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name)
-    context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")
+    context += (
+        chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
+    )

     # determine if it's the last use for chunk input
-    for chunk_input in (chunk_inputs + chunk_non_compute_inputs):
+    for chunk_input in chunk_inputs + chunk_non_compute_inputs:
         if all(
             [
                 _find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
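Note that `_gen_loop_start` emits Python source text rather than executing anything. For a hypothetical input node named `x`, an output of shape [2, 512, 64] chunked along dim 1, and the default chunk_size=2, the string it builds reads:

    chunk_result = torch.empty([2, 512, 64], dtype=x.dtype, device=x.device); chunk_size = 2
    for chunk_idx in range(0, 512, chunk_size):

`_gen_loop_end` then appends the indented write-back of the region's output into `chunk_result` (through the slice text produced by `_gen_chunk_slice_dim`, which this diff does not show) followed by an unindented line that rebinds the output name and frees the loop temporaries.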
@@ -1456,10 +1446,7 @@ def _find_chunk_all_input_nodes(nodes: List[Node]):
     input_nodes = []
     for node in nodes:
         for input_node in node._input_nodes.keys():
-            if (
-                input_node not in nodes
-                and input_node not in input_nodes
-            ):
+            if input_node not in nodes and input_node not in input_nodes:
                 input_nodes.append(input_node)
     return input_nodes
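`_find_chunk_all_input_nodes` collects every producer outside the region that feeds a node inside it. A behaviorally equivalent restatement using the public `Node.all_input_nodes` property instead of the private `_input_nodes` dict (a sketch, not the project's code):

    from typing import List
    from torch.fx import Node

    def find_region_inputs(nodes: List[Node]) -> List[Node]:
        # a node produced outside the region but consumed inside it is a
        # region input; order is preserved and duplicates are skipped
        inputs: List[Node] = []
        for node in nodes:
            for producer in node.all_input_nodes:
                if producer not in nodes and producer not in inputs:
                    inputs.append(producer)
        return inputs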
@@ -1549,16 +1536,12 @@ def emit_code_with_chunk(
     chunk_inputs = [i["inputs"] for i in chunk_search]
     chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_search]
     chunk_inputs_dim = [i["inputs_dim"] for i in chunk_search]
-    chunk_inputs_idx = [
-        [_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs
-    ]
-    chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
-        j.name for i in chunk_inputs_non_chunk for j in i
-    ]
+    chunk_inputs_idx = [[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs]
+    chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]

     chunk_outputs = [i["outputs"][0] for i in chunk_search]
     chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search]
     chunk_outputs_idx = [
         _find_idx_by_name(i.name, node_list) for i in chunk_outputs
     ]

     node_idx = 0
     region_idx = 0
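The collapsed `chunk_inputs_names` comprehension flattens two lists-of-lists (one inner list per chunk region) into a single flat list of names. With strings standing in for the per-region `fx.Node` name lists:

    regions_inputs = [["a", "b"], ["c"]]   # stand-ins for per-region node names
    regions_non_chunk = [["d"]]
    names = [j for i in regions_inputs for j in i] + [j for i in regions_non_chunk for j in i]
    assert names == ["a", "b", "c", "d"]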
@@ -1586,7 +1569,9 @@ def emit_code_with_chunk(
             for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
                 for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
                     if idx == node_idx:
-                        chunk_slice = _gen_chunk_slice_dim(dim, "chunk_idx", _get_node_shape(input_node))
+                        chunk_slice = _gen_chunk_slice_dim(
+                            dim, "chunk_idx", _get_node_shape(input_node)
+                        )
                         body[-1] = _replace_name(
                             body[-1], input_node.name, input_node.name + chunk_slice
                         )
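The rewrite above is textual: when the freshly emitted line consumes a chunked input, the input's name is suffixed with a slice so the loop body reads one chunk at a time. Assuming `_replace_name` does a plain name substitution and the same assumed slice format as before:

    # before:  add_1 = torch.add(x, 1)
    # after:   add_1 = torch.add(x[:, chunk_idx:chunk_idx + chunk_size, :], 1)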
@@ -1604,7 +1589,8 @@ def emit_code_with_chunk(
                         chunk_inputs[region_idx],
                         chunk_inputs_non_chunk[region_idx],
                         chunk_outputs[region_idx],
-                        chunk_outputs_dim[region_idx], node_list
+                        chunk_outputs_dim[region_idx],
+                        node_list,
                     )
                 )
                 within_chunk_region = False
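Stitched together, the code emitted for one chunk region reads roughly as follows; the node names, the `torch.add` body, and the `[:, chunk_idx:chunk_idx + chunk_size, :]` slice produced by `_gen_chunk_slice_dim` are illustrative assumptions, since neither the traced graph nor that helper appears in this diff:

    # _gen_loop_start: allocate the result and open the chunk loop
    chunk_result = torch.empty([2, 512, 64], dtype=x.dtype, device=x.device); chunk_size = 2
    for chunk_idx in range(0, 512, chunk_size):
        # region body, with the chunked input sliced by _replace_name
        add_1 = torch.add(x[:, chunk_idx:chunk_idx + chunk_size, :], 1)
        # _gen_loop_end: write this chunk's result back
        chunk_result[:, chunk_idx:chunk_idx + chunk_size, :] = add_1
    add_1 = chunk_result; chunk_result = None; chunk_size = None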