work with outerproductmean and msa

pull/2364/head
oahzxl 2022-12-12 17:24:06 +08:00
parent 5de9e46381
commit 31a2c5d09f
1 changed files with 171 additions and 93 deletions

View File

@ -134,7 +134,7 @@ class FlowTracer(object):
return name, i
raise RuntimeError("invalid node")
def get_flow_mix(self, node):
def _get_flow_mix_node(self, node):
if self._is_non_compute_node(node):
return None
_, node_trace = self.find_node_flow(node)
@ -145,7 +145,7 @@ class FlowTracer(object):
vars = list(node_trace["outside_depend"][0].values())[0]
return vars
def get_same_flow_node(self, node_list, node):
def _get_same_flow_node(self, node_list, node):
name, _ = self.find_node_flow(node)
result = []
for i in self.flow_trace[name]:
@ -181,13 +181,14 @@ class FlowTracer(object):
)
return self.flow_trace
def _detect_flow(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = _find_chunk_input_and_output_nodes(
def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer):
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
self.node_list[start_idx : end_idx + 1]
)
chunk_info = {
"region": (start_idx, end_idx),
"inputs": inputs,
"inputs_non_chunk": [],
"inputs_dim": start_dim,
"outputs": outputs,
"outputs_dim": end_dim,
@ -197,31 +198,71 @@ class FlowTracer(object):
for idx in range(start_idx, end_idx + 1):
node = self.node_list[idx]
mix_flow_var = self.get_flow_mix(node)
if mix_flow_var is None:
mix_flow_node = self._get_flow_mix_node(node)
if mix_flow_node is None:
continue
# if there is a flow mix, op must be in [mul, add, div, matmul]
# if there is a flow mix, op must be in [mul, add, matmul]
# element-wise op requires dim to be equal in every dim
if any(n in node.name for n in ["mul", "add"]):
for i in node.args:
if type(i) == type(mix_flow_var) and i != mix_flow_var:
if type(i) == type(mix_flow_node) and i != mix_flow_node:
main_flow_var = i
# if mix flow is a broadcast in chunk dim,
# TODO need to move that flow out of the chunk
if mix_flow_var.meta["tensor_meta"].shape[dim_idx] == 1:
mix_flow_node_dim = index_tracer._get_node_chunk_dim(
self.node_list[end_idx], end_dim, node
)
if mix_flow_node_dim is None:
flow_flag = True
for i in self.get_same_flow_node(
chunk_info["inputs"], mix_flow_var
break
if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1:
flow_flag = False
for i in self._get_same_flow_node(
chunk_info["inputs"], mix_flow_node
):
chunk_info["inputs"].remove(i)
# else, we need to chunk mix var as well
else:
# TODO chunk another value
flow_flag = False
flow_flag = True
break
else:
raise NotImplementedError("%s not implemented" % node.name)
inputs_dim = []
remove_inputs = []
for input_node in chunk_info['inputs']:
input_dict = {}
for user in input_node.users.keys():
if _is_non_compute_node(user):
continue
user_idx = _find_idx_by_name(user.name, self.node_list)
dim = None
if start_dim <= user_idx < end_idx:
dim = index_tracer._get_node_chunk_dim(
self.node_list[end_idx], end_dim, input_node
)
elif user_idx == end_idx:
dim = end_dim
# n has relation with chunk dim
if dim is not None and _get_node_shape(user)[dim] != 1:
input_dict[user_idx] = dim
if len(input_dict) == 0:
remove_inputs.append(input_node)
else:
inputs_dim.append(input_dict)
chunk_info['inputs_dim'] = inputs_dim
for i in remove_inputs:
if i in chunk_info['inputs']:
chunk_info['inputs'].remove(i)
# we need to log input nodes to avoid deleteing them in the loop
non_chunk_inputs = _find_chunk_all_input_nodes(self.node_list[start_idx : end_idx + 1])
for i in non_chunk_inputs:
if i not in chunk_info['inputs']:
chunk_info["inputs_non_chunk"].append(i)
return flow_flag, chunk_info
@ -367,6 +408,20 @@ class IndexTracer(object):
node_dict = self.idx_trace_list[node_idx]
return node_dict
def _find_source_trace_from_node(self, node):
"""
Find node source trace by the node.
Args:
node (node)
Returns:
idx (list): idx of the node
compute (list): computed idx of the node.
"""
node_idx = _find_idx_by_name(node.name, self.nodes_list)
node_dict = self.idx_trace_list[node_idx]
return node_dict["source"]
def _find_idx_trace_from_node(self, node):
"""
Find node idx trace by the node.
@ -836,6 +891,15 @@ class IndexTracer(object):
# return False
# return True
def _get_node_chunk_dim(self, node_from, node_from_dim, node_to):
node_from_source = self._find_source_trace_from_node(node_from)
dim_source = node_from_source[node_from_dim]
node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list)
for k, v in dim_source.items():
if k == node_to_idx:
return v
return None
class MemoryEstimator(object):
def __init__(self) -> None:
@ -931,8 +995,10 @@ class MemoryEstimator(object):
return mem
def _get_chunk_ratio(self, node, chunk_dim, chunk_size):
sorted_dim = sorted(chunk_dim, key=lambda x: list(x.keys())[0])
dim = list(sorted_dim[-1].values())[0]
shape = node.meta["tensor_meta"].shape
chunk_ratio = float(chunk_size) / shape[chunk_dim]
chunk_ratio = float(chunk_size) / shape[dim]
return chunk_ratio
def _get_chunk_delete_node_size(
@ -1157,6 +1223,8 @@ class ChunkRegionSearch(object):
return chunk_infos
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
if start_idx == 71 and end_idx == 126:
print(1)
start_traces = input_trace[start_idx]
end_trace = output_trace[end_idx]
end_node = self.node_list[end_idx]
@ -1188,7 +1256,7 @@ class ChunkRegionSearch(object):
continue
# detect flow meet
flow_flag, chunk_info = self.flow_tracer._detect_flow(
start_idx, start_dim, end_idx, end_dim
start_idx, start_dim, end_idx, end_dim, self.index_tracer
)
if flow_flag:
continue
@ -1301,56 +1369,53 @@ def _get_first_non_single_dim(shape):
raise RuntimeError("can not get first non single dim for shape", shape)
def _gen_loop_start(chunk_input_meta, chunk_output, chunk_dim, chunk_size=2):
if len(chunk_input_meta) == 1:
node = chunk_input_meta[0]
node_shape = node.meta["tensor_meta"].shape
free_shape = [
node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape))
]
chunk_dim = _get_first_non_single_dim(free_shape)
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
out_shape = str(list(chunk_output.meta["tensor_meta"].shape))
def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2):
input_node = chunk_input[0]
out_shape = _get_node_shape(chunk_output)
out_str = str(list(out_shape))
context = (
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range"
% (out_shape, node.name, node.name, chunk_size)
)
context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice)
else:
raise NotImplementedError(
"input with size %d not implemented" % len(chunk_input_meta)
)
context = (
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range"
% (out_str, input_node.name, input_node.name, chunk_size)
)
context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
# node = chunk_input[0]
# node_shape = node.meta["tensor_meta"].shape
# free_shape = [
# node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape))
# ]
# chunk_dim = _get_first_non_single_dim(free_shape)
# chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
# out_shape = str(list(chunk_output.meta["tensor_meta"].shape))
# context = (
# "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range"
# % (out_shape, node.name, node.name, chunk_size)
# )
# context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
# context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice)
return context
def _gen_loop_end(chunk_outputs, chunk_inputs, node_list, chunk_dim):
chunk_inputs_name = chunk_inputs[0].name
def _gen_loop_end(chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list):
chunk_outputs_name = chunk_outputs.name
chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list)
chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
free_shape = [
chunk_output_shape[i] if i in chunk_dim else 1
for i in range(len(chunk_output_shape))
]
chunk_dim = _get_first_non_single_dim(free_shape)
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", chunk_output_shape)
chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
context = " chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name)
context += (
chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
)
context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")
# determine if its the last use for chunk input
users_name = list(chunk_inputs[0].users.keys())
if all(
[
_find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
for user in users_name
]
):
context += "; %s = None" % chunk_inputs_name
for chunk_input in (chunk_inputs + chunk_non_compute_inputs):
if all(
[
_find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
for user in chunk_input.users.keys()
]
):
context += "; %s = None" % chunk_input.name
context += "\n"
return context
@ -1382,7 +1447,24 @@ def _find_input_and_output_nodes(nodes: List[Node]):
return input_nodes, output_nodes
def _find_chunk_input_and_output_nodes(nodes: List[Node]):
def _find_chunk_all_input_nodes(nodes: List[Node]):
"""
Find non-compute input and output node names.
input nodes are nodes used in the list
output nodes are nodes will use nodes in the list
"""
input_nodes = []
for node in nodes:
for input_node in node._input_nodes.keys():
if (
input_node not in nodes
and input_node not in input_nodes
):
input_nodes.append(input_node)
return input_nodes
def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
"""
Find non-compute input and output node names.
input nodes are nodes used in the list
@ -1410,7 +1492,7 @@ def _find_chunk_input_and_output_nodes(nodes: List[Node]):
if (
output_node not in nodes
and node not in output_nodes
and not _is_non_compute_node_except_placeholder(input_node)
and not _is_non_compute_node_except_placeholder(output_node)
):
output_nodes.append(node)
@ -1454,44 +1536,34 @@ def emit_code_with_chunk(
emit_node_func: function to emit node
delete_unused_value_func: function to remove the unused value
"""
# find the offload regions
chunk_region_search = ChunkRegionSearch(meta_graph)
chunk_search = chunk_region_search.search_region()
chunk_regions = [i["region"] for i in chunk_search]
chunk_dims = [i["dim"] for i in chunk_search]
chunk_infos = [i["chunk_info"] for i in chunk_search]
chunk_starts = [item[0] for item in chunk_regions]
chunk_ends = [item[1] for item in chunk_regions]
chunk_inputs = [[j["inputs"][0] for j in i] for i in chunk_infos]
chunk_outputs = [[j["outputs"][0] for j in i] for i in chunk_infos]
within_chunk_region = False
node_list = list(nodes)
# find the input and output var names for each offload region
# for idx, (start, end) in enumerate(chunk_regions):
# offload_node_list = node_list[start:end + 1]
# inputs, outputs = _find_input_and_output_nodes(offload_node_list)
# chunk_inputs.append(inputs)
# chunk_outputs.append(outputs)
# find the chunk regions
chunk_region_search = ChunkRegionSearch(meta_graph)
chunk_search = chunk_region_search.search_region()
chunk_regions = [i["region"] for i in chunk_search]
chunk_starts = [i[0] for i in chunk_regions]
chunk_ends = [i[1] for i in chunk_regions]
chunk_inputs = [i["inputs"] for i in chunk_search]
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_search]
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_search]
chunk_inputs_idx = [
[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs
]
chunk_outputs_idx = [
[_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs
]
chunk_inputs_names = []
for i in chunk_inputs:
for j in i:
chunk_inputs_names.append(j.name)
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
chunk_outputs = [i["outputs"][0] for i in chunk_search]
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search]
chunk_outputs_idx = [
_find_idx_by_name(i.name, node_list) for i in chunk_outputs
]
# this flag is to prevent repeated insert of save tensors
# hooks definition in ckpt_func
node_idx = 0
region_idx = 0
within_chunk_region = False
while node_idx < len(node_list):
node = node_list[node_idx]
@ -1500,21 +1572,24 @@ def emit_code_with_chunk(
region_idx = chunk_starts.index(node_idx)
# add for loop
chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]]
body.append(
_gen_loop_start(
chunk_input_meta,
node_list[chunk_ends[region_idx]],
chunk_dims[region_idx],
chunk_inputs[region_idx],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx],
)
)
if within_chunk_region:
emit_node_func(node, body)
# replace input var with chunk var
body[-1] = _replace_name(
body[-1], chunk_inputs[region_idx][0].name, "chunk_tensor"
)
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
if idx == node_idx:
chunk_slice = _gen_chunk_slice_dim(dim, "chunk_idx", _get_node_shape(input_node))
body[-1] = _replace_name(
body[-1], input_node.name, input_node.name + chunk_slice
)
body[-1] = " " + body[-1]
delete_unused_value_func(node, body, chunk_inputs_names)
@ -1526,7 +1601,10 @@ def emit_code_with_chunk(
if node_idx in chunk_ends:
body.append(
_gen_loop_end(
node, chunk_inputs[region_idx], node_list, chunk_dims[region_idx]
chunk_inputs[region_idx],
chunk_inputs_non_chunk[region_idx],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx], node_list
)
)
within_chunk_region = False