mirror of https://github.com/hpcaitech/ColossalAI
work with outerproductmean and msa
parent 5de9e46381
commit 31a2c5d09f
chunk_codegen.py (264 changes)
@@ -134,7 +134,7 @@ class FlowTracer(object):
                 return name, i
         raise RuntimeError("invalid node")
 
-    def get_flow_mix(self, node):
+    def _get_flow_mix_node(self, node):
         if self._is_non_compute_node(node):
             return None
         _, node_trace = self.find_node_flow(node)
@@ -145,7 +145,7 @@ class FlowTracer(object):
         vars = list(node_trace["outside_depend"][0].values())[0]
         return vars
 
-    def get_same_flow_node(self, node_list, node):
+    def _get_same_flow_node(self, node_list, node):
         name, _ = self.find_node_flow(node)
         result = []
         for i in self.flow_trace[name]:
@@ -181,13 +181,14 @@ class FlowTracer(object):
         )
         return self.flow_trace
 
-    def _detect_flow(self, start_idx, start_dim, end_idx, end_dim):
-        inputs, outputs = _find_chunk_input_and_output_nodes(
+    def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer):
+        inputs, outputs = _find_chunk_compute_input_and_output_nodes(
             self.node_list[start_idx : end_idx + 1]
         )
         chunk_info = {
             "region": (start_idx, end_idx),
             "inputs": inputs,
+            "inputs_non_chunk": [],
             "inputs_dim": start_dim,
             "outputs": outputs,
             "outputs_dim": end_dim,
@@ -197,31 +198,71 @@ class FlowTracer(object):
 
         for idx in range(start_idx, end_idx + 1):
             node = self.node_list[idx]
-            mix_flow_var = self.get_flow_mix(node)
-            if mix_flow_var is None:
+            mix_flow_node = self._get_flow_mix_node(node)
+            if mix_flow_node is None:
                 continue
 
-            # if there is a flow mix, op must be in [mul, add, div, matmul]
+            # if there is a flow mix, op must be in [mul, add, matmul]
             # element-wise op requires dim to be equal in every dim
             if any(n in node.name for n in ["mul", "add"]):
                 for i in node.args:
-                    if type(i) == type(mix_flow_var) and i != mix_flow_var:
+                    if type(i) == type(mix_flow_node) and i != mix_flow_node:
                         main_flow_var = i
                         # if mix flow is a broadcast in chunk dim,
                         # TODO need to move that flow out of the chunk
-                        if mix_flow_var.meta["tensor_meta"].shape[dim_idx] == 1:
-                            for i in self.get_same_flow_node(
-                                chunk_info["inputs"], mix_flow_var
+                        mix_flow_node_dim = index_tracer._get_node_chunk_dim(
+                            self.node_list[end_idx], end_dim, node
+                        )
+                        if mix_flow_node_dim is None:
+                            flow_flag = True
+                            break
+                        if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1:
+                            flow_flag = False
+                            for i in self._get_same_flow_node(
+                                chunk_info["inputs"], mix_flow_node
                             ):
                                 chunk_info["inputs"].remove(i)
                         # else, we need to chunk mix var as well
                         else:
                             # TODO chunk another value
-                            flow_flag = False
+                            flow_flag = True
                             break
             else:
                 raise NotImplementedError("%s not implemented" % node.name)
 
+        inputs_dim = []
+        remove_inputs = []
+        for input_node in chunk_info["inputs"]:
+            input_dict = {}
+            for user in input_node.users.keys():
+                if _is_non_compute_node(user):
+                    continue
+                user_idx = _find_idx_by_name(user.name, self.node_list)
+                dim = None
+                if start_dim <= user_idx < end_idx:
+                    dim = index_tracer._get_node_chunk_dim(
+                        self.node_list[end_idx], end_dim, input_node
+                    )
+                elif user_idx == end_idx:
+                    dim = end_dim
+                # n has relation with chunk dim
+                if dim is not None and _get_node_shape(user)[dim] != 1:
+                    input_dict[user_idx] = dim
+            if len(input_dict) == 0:
+                remove_inputs.append(input_node)
+            else:
+                inputs_dim.append(input_dict)
+        chunk_info["inputs_dim"] = inputs_dim
+        for i in remove_inputs:
+            if i in chunk_info["inputs"]:
+                chunk_info["inputs"].remove(i)
+
+        # we need to log input nodes to avoid deleteing them in the loop
+        non_chunk_inputs = _find_chunk_all_input_nodes(self.node_list[start_idx : end_idx + 1])
+        for i in non_chunk_inputs:
+            if i not in chunk_info["inputs"]:
+                chunk_info["inputs_non_chunk"].append(i)
+
         return flow_flag, chunk_info
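
Note: the broadcast escape hatch above can be checked numerically. When the mixed-in flow node has size 1 along the chunk dimension, slicing only the main flow and reusing the broadcast operand unchanged reproduces the full result, which is why such a node can be dropped from chunk_info["inputs"]. A minimal sketch of the invariant, with made-up shapes and dim 1 as the chunk dimension:

    import torch

    a = torch.randn(4, 128, 64)  # main flow, chunked along dim 1
    b = torch.randn(4, 1, 64)    # mixed-in flow, size 1 in the chunk dim
    full = a * b

    chunk_size = 32
    for i in range(0, a.shape[1], chunk_size):
        # b is not sliced: its size-1 chunk dim broadcasts inside every chunk,
        # which is exactly the case _detect_flow removes from chunk_info["inputs"]
        part = a[:, i : i + chunk_size, :] * b
        assert torch.equal(part, full[:, i : i + chunk_size, :])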
@@ -367,6 +408,20 @@ class IndexTracer(object):
         node_dict = self.idx_trace_list[node_idx]
         return node_dict
 
+    def _find_source_trace_from_node(self, node):
+        """
+        Find node source trace by the node.
+
+        Args:
+            node (node)
+        Returns:
+            idx (list): idx of the node
+            compute (list): computed idx of the node.
+        """
+        node_idx = _find_idx_by_name(node.name, self.nodes_list)
+        node_dict = self.idx_trace_list[node_idx]
+        return node_dict["source"]
+
     def _find_idx_trace_from_node(self, node):
         """
         Find node idx trace by the node.
@@ -836,6 +891,15 @@ class IndexTracer(object):
         #     return False
         # return True
 
+    def _get_node_chunk_dim(self, node_from, node_from_dim, node_to):
+        node_from_source = self._find_source_trace_from_node(node_from)
+        dim_source = node_from_source[node_from_dim]
+        node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list)
+        for k, v in dim_source.items():
+            if k == node_to_idx:
+                return v
+        return None
+
 
 class MemoryEstimator(object):
     def __init__(self) -> None:
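
Note: _get_node_chunk_dim answers the question "given that node_from is chunked along node_from_dim, which dimension of node_to feeds that dim?" It reads node_from's source trace, which (judging from the indexing above) maps each dim to a {producer node index: producer dim} dict. A self-contained mock of the lookup, with an invented trace:

    def get_node_chunk_dim(source_trace, node_from_dim, node_to_idx):
        # source_trace: per-dim {producer_node_idx: producer_dim}, mirroring
        # the dim_source lookup in IndexTracer._get_node_chunk_dim
        dim_source = source_trace[node_from_dim]
        for k, v in dim_source.items():
            if k == node_to_idx:
                return v
        return None

    # dim 1 of the end node originates from dim 2 of node 3 and dim 0 of node 7
    trace = {0: {}, 1: {3: 2, 7: 0}, 2: {}}
    assert get_node_chunk_dim(trace, 1, 3) == 2
    assert get_node_chunk_dim(trace, 1, 5) is None  # unrelated node -> None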
@@ -931,8 +995,10 @@ class MemoryEstimator(object):
         return mem
 
     def _get_chunk_ratio(self, node, chunk_dim, chunk_size):
+        sorted_dim = sorted(chunk_dim, key=lambda x: list(x.keys())[0])
+        dim = list(sorted_dim[-1].values())[0]
         shape = node.meta["tensor_meta"].shape
-        chunk_ratio = float(chunk_size) / shape[chunk_dim]
+        chunk_ratio = float(chunk_size) / shape[dim]
         return chunk_ratio
 
     def _get_chunk_delete_node_size(
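
Note: the two added lines fix a type bug: chunk_dim is no longer a single int but a list of {user node index: dim} dicts (the inputs_dim records built in _detect_flow), so the old shape[chunk_dim] would raise a TypeError. The fix sorts the records by node index and takes the dim seen by the last user. A standalone sketch of the same arithmetic:

    def get_chunk_ratio(shape, chunk_dim, chunk_size):
        # chunk_dim: list of {user_node_idx: dim} dicts, as built by _detect_flow
        sorted_dim = sorted(chunk_dim, key=lambda x: list(x.keys())[0])
        dim = list(sorted_dim[-1].values())[0]  # dim used by the latest user
        return float(chunk_size) / shape[dim]

    # two users of the input: node 12 consumes dim 1, node 30 consumes dim 2
    assert get_chunk_ratio((4, 128, 64), [{12: 1}, {30: 2}], chunk_size=2) == 2 / 64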
@@ -1157,6 +1223,8 @@ class ChunkRegionSearch(object):
         return chunk_infos
 
     def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
+        if start_idx == 71 and end_idx == 126:
+            print(1)
         start_traces = input_trace[start_idx]
         end_trace = output_trace[end_idx]
         end_node = self.node_list[end_idx]
@@ -1188,7 +1256,7 @@ class ChunkRegionSearch(object):
                     continue
                 # detect flow meet
                 flow_flag, chunk_info = self.flow_tracer._detect_flow(
-                    start_idx, start_dim, end_idx, end_dim
+                    start_idx, start_dim, end_idx, end_dim, self.index_tracer
                 )
                 if flow_flag:
                     continue
@@ -1301,56 +1369,53 @@ def _get_first_non_single_dim(shape):
     raise RuntimeError("can not get first non single dim for shape", shape)
 
 
-def _gen_loop_start(chunk_input_meta, chunk_output, chunk_dim, chunk_size=2):
-    if len(chunk_input_meta) == 1:
-        node = chunk_input_meta[0]
-        node_shape = node.meta["tensor_meta"].shape
-        free_shape = [
-            node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape))
-        ]
-        chunk_dim = _get_first_non_single_dim(free_shape)
-        chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
-        out_shape = str(list(chunk_output.meta["tensor_meta"].shape))
-        context = (
-            "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range"
-            % (out_shape, node.name, node.name, chunk_size)
-        )
-        context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
-        context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice)
-    else:
-        raise NotImplementedError(
-            "input with size %d not implemented" % len(chunk_input_meta)
-        )
+def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2):
+    input_node = chunk_input[0]
+
+    out_shape = _get_node_shape(chunk_output)
+    out_str = str(list(out_shape))
+
+    context = (
+        "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range"
+        % (out_str, input_node.name, input_node.name, chunk_size)
+    )
+    context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
+
+    # node = chunk_input[0]
+    # node_shape = node.meta["tensor_meta"].shape
+    # free_shape = [
+    #     node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape))
+    # ]
+    # chunk_dim = _get_first_non_single_dim(free_shape)
+    # chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
+    # out_shape = str(list(chunk_output.meta["tensor_meta"].shape))
+
+    # context = (
+    #     "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range"
+    #     % (out_shape, node.name, node.name, chunk_size)
+    # )
+    # context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
+    # context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice)
     return context
 
 
-def _gen_loop_end(chunk_outputs, chunk_inputs, node_list, chunk_dim):
-    chunk_inputs_name = chunk_inputs[0].name
+def _gen_loop_end(chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list):
     chunk_outputs_name = chunk_outputs.name
     chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list)
     chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
-    free_shape = [
-        chunk_output_shape[i] if i in chunk_dim else 1
-        for i in range(len(chunk_output_shape))
-    ]
-    chunk_dim = _get_first_non_single_dim(free_shape)
-    chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", chunk_output_shape)
+    chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
     context = "    chunk_result%s = %s\n" % (chunk_slice, chunk_outputs_name)
-    context += (
-        chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
-    )
+    context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")
 
     # determine if its the last use for chunk input
-    users_name = list(chunk_inputs[0].users.keys())
-    if all(
-        [
-            _find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
-            for user in users_name
-        ]
-    ):
-        context += "; %s = None" % chunk_inputs_name
+    for chunk_input in (chunk_inputs + chunk_non_compute_inputs):
+        if all(
+            [
+                _find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
+                for user in chunk_input.users.keys()
+            ]
+        ):
+            context += "; %s = None" % chunk_input.name
 
     context += "\n"
     return context
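
Note: the rewritten pair now loops over the output's chunk dimension instead of deriving one from the first input's shape. A sketch of the source the two functions would emit for a hypothetical region (input node named linear_1, output node named softmax_2 with shape [4, 128, 64], output chunk dim 1; the slice string stands in for what _gen_chunk_slice_dim is assumed to build):

    out_shape = [4, 128, 64]
    chunk_ouput_dim = 1  # spelling follows the code above

    prologue = (
        "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\n"
        "for chunk_idx in range(0, %d, chunk_size):\n"
        % (str(out_shape), "linear_1", "linear_1", 2, out_shape[chunk_ouput_dim])
    )
    epilogue = (
        "    chunk_result[:, chunk_idx:chunk_idx + chunk_size, :] = softmax_2\n"
        "softmax_2 = chunk_result; chunk_result = None; chunk_size = None; linear_1 = None\n"
    )
    print(prologue + "    ...region body with sliced inputs...\n" + epilogue)

The trailing "; linear_1 = None" only appears when the region output is the last user of that input, per the loop over chunk_inputs + chunk_non_compute_inputs above.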
@@ -1382,7 +1447,24 @@ def _find_input_and_output_nodes(nodes: List[Node]):
     return input_nodes, output_nodes
 
 
-def _find_chunk_input_and_output_nodes(nodes: List[Node]):
+def _find_chunk_all_input_nodes(nodes: List[Node]):
+    """
+    Find non-compute input and output node names.
+    input nodes are nodes used in the list
+    output nodes are nodes will use nodes in the list
+    """
+    input_nodes = []
+    for node in nodes:
+        for input_node in node._input_nodes.keys():
+            if (
+                input_node not in nodes
+                and input_node not in input_nodes
+            ):
+                input_nodes.append(input_node)
+    return input_nodes
+
+
+def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
     """
     Find non-compute input and output node names.
     input nodes are nodes used in the list
@@ -1410,7 +1492,7 @@ def _find_chunk_input_and_output_nodes(nodes: List[Node]):
         if (
             output_node not in nodes
             and node not in output_nodes
-            and not _is_non_compute_node_except_placeholder(input_node)
+            and not _is_non_compute_node_except_placeholder(output_node)
         ):
             output_nodes.append(node)
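
Note: the split into _find_chunk_all_input_nodes and _find_chunk_compute_input_and_output_nodes gives _detect_flow two views of a region: every producer feeding it (kept alive via inputs_non_chunk) and only the compute-relevant producers (candidates for slicing). A toy illustration with a minimal stand-in for torch.fx.Node (the real helpers walk _input_nodes on fx nodes):

    class FakeNode:
        def __init__(self, name, inputs=()):
            self.name = name
            self._input_nodes = {i: None for i in inputs}

    a = FakeNode("a")                 # e.g. a placeholder or size constant
    b = FakeNode("b")                 # a compute producer
    c = FakeNode("c", inputs=(a, b))  # the single node inside the region

    def find_all_inputs(nodes):
        found = []
        for node in nodes:
            for inp in node._input_nodes:
                if inp not in nodes and inp not in found:
                    found.append(inp)
        return found

    # both a and b are region inputs; the compute variant would additionally
    # filter nodes like a through _is_non_compute_node_except_placeholder
    assert find_all_inputs([c]) == [a, b]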
@@ -1454,44 +1536,34 @@ def emit_code_with_chunk(
         emit_node_func: function to emit node
         delete_unused_value_func: function to remove the unused value
     """
 
-    # find the offload regions
-    chunk_region_search = ChunkRegionSearch(meta_graph)
-    chunk_search = chunk_region_search.search_region()
-    chunk_regions = [i["region"] for i in chunk_search]
-    chunk_dims = [i["dim"] for i in chunk_search]
-    chunk_infos = [i["chunk_info"] for i in chunk_search]
-
-    chunk_starts = [item[0] for item in chunk_regions]
-    chunk_ends = [item[1] for item in chunk_regions]
-    chunk_inputs = [[j["inputs"][0] for j in i] for i in chunk_infos]
-    chunk_outputs = [[j["outputs"][0] for j in i] for i in chunk_infos]
-    within_chunk_region = False
-
     node_list = list(nodes)
 
-    # find the input and output var names for each offload region
-    # for idx, (start, end) in enumerate(chunk_regions):
-    #     offload_node_list = node_list[start:end + 1]
-    #     inputs, outputs = _find_input_and_output_nodes(offload_node_list)
-    #     chunk_inputs.append(inputs)
-    #     chunk_outputs.append(outputs)
-    chunk_inputs_idx = [
-        [_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_inputs
-    ]
-    chunk_outputs_idx = [
-        [_find_idx_by_name(j.name, node_list) for j in i] for i in chunk_outputs
-    ]
-    chunk_inputs_names = []
-    for i in chunk_inputs:
-        for j in i:
-            chunk_inputs_names.append(j.name)
+    # find the chunk regions
+    chunk_region_search = ChunkRegionSearch(meta_graph)
+    chunk_search = chunk_region_search.search_region()
+
+    chunk_regions = [i["region"] for i in chunk_search]
+    chunk_starts = [i[0] for i in chunk_regions]
+    chunk_ends = [i[1] for i in chunk_regions]
+
+    chunk_inputs = [i["inputs"] for i in chunk_search]
+    chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_search]
+    chunk_inputs_dim = [i["inputs_dim"] for i in chunk_search]
+    chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
+
+    chunk_outputs = [i["outputs"][0] for i in chunk_search]
+    chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search]
+    chunk_outputs_idx = [
+        _find_idx_by_name(i.name, node_list) for i in chunk_outputs
+    ]
 
     # this flag is to prevent repeated insert of save tensors
     # hooks definition in ckpt_func
     node_idx = 0
     region_idx = 0
+    within_chunk_region = False
 
     while node_idx < len(node_list):
         node = node_list[node_idx]
@@ -1500,21 +1572,24 @@ def emit_code_with_chunk(
             region_idx = chunk_starts.index(node_idx)
 
             # add for loop
-            chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]]
             body.append(
                 _gen_loop_start(
-                    chunk_input_meta,
-                    node_list[chunk_ends[region_idx]],
-                    chunk_dims[region_idx],
+                    chunk_inputs[region_idx],
+                    chunk_outputs[region_idx],
+                    chunk_outputs_dim[region_idx],
                 )
             )
 
         if within_chunk_region:
             emit_node_func(node, body)
             # replace input var with chunk var
-            body[-1] = _replace_name(
-                body[-1], chunk_inputs[region_idx][0].name, "chunk_tensor"
-            )
+            for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
+                for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
+                    if idx == node_idx:
+                        chunk_slice = _gen_chunk_slice_dim(dim, "chunk_idx", _get_node_shape(input_node))
+                        body[-1] = _replace_name(
+                            body[-1], input_node.name, input_node.name + chunk_slice
+                        )
             body[-1] = "    " + body[-1]
             delete_unused_value_func(node, body, chunk_inputs_names)
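
Note: inside a region, each emitted line is now rewritten so chunked inputs are read through a slice instead of the old single chunk_tensor variable. A sketch of the substitution with a naive stand-in for _replace_name (the real helper is presumably token-aware):

    def replace_name(line, name, new_name):
        return line.replace(name, new_name)  # naive stand-in for _replace_name

    line = "mul_3 = torch.matmul(linear_1, weight_2)"
    chunk_slice = "[:, chunk_idx:chunk_idx + chunk_size, :]"
    line = replace_name(line, "linear_1", "linear_1" + chunk_slice)
    assert line == "mul_3 = torch.matmul(linear_1[:, chunk_idx:chunk_idx + chunk_size, :], weight_2)"
    print("    " + line)  # indented into the generated for-loop body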
@@ -1526,7 +1601,10 @@ def emit_code_with_chunk(
         if node_idx in chunk_ends:
             body.append(
                 _gen_loop_end(
-                    node, chunk_inputs[region_idx], node_list, chunk_dims[region_idx]
+                    chunk_inputs[region_idx],
+                    chunk_inputs_non_chunk[region_idx],
+                    chunk_outputs[region_idx],
+                    chunk_outputs_dim[region_idx], node_list
                 )
             )
             within_chunk_region = False