mirror of https://github.com/hpcaitech/ColossalAI
[autochunk] support multi outputs chunk search (#2538)
Support multi-output chunk search. Previously only single-output chunk search was supported; the new strategy is more flexible and improves performance by a large margin. For transformers it reduces memory by 40% compared with the previous search strategy.

1. rewrite the search strategy to support multi-output chunk search
2. fix many, many bugs
3. update tests

Branch: pull/2540/head
parent f477a14f4a
commit 05671fcb42
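The codegen changes below rework the emitted chunk loop so that each output of a region gets its own result buffer. A hand-written illustration of the kind of source _gen_loop_start now produces, with invented tensor names and shapes (the real strings are assembled from node metadata):

    # illustrative only: two chunked outputs of one region, chunked along dim 1
    out_a = torch.empty([4, 1024, 64], dtype=inp.dtype, device=inp.device); out_b = torch.empty([4, 1024, 128], dtype=inp.dtype, device=inp.device); chunk_size = 2
    for chunk_idx in range(0, 1024, chunk_size):
        # region body, with inputs and outputs sliced along their chunk dims
        ...

_gen_loop_end correspondingly shrinks: results are written straight into the per-output buffers, so it only resets chunk_size, frees inputs on their last use, and re-materializes non-tensor (shape-like) outputs.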
@@ -25,7 +25,7 @@ if AUTOCHUNK_AVAILABLE:
 from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg

 from .search_chunk import SearchChunk
-from .utils import delete_free_var_from_last_use, find_idx_by_name, get_logger, get_node_shape
+from .utils import delete_free_var_from_last_use, get_logger, get_node_name, get_node_shape


 def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
@@ -51,7 +51,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
     return new_shape


-def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2) -> str:
+def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_ouput_dim: int, chunk_size=2) -> str:
     """
     Generate chunk loop start
@@ -70,22 +70,28 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim
         context (str): generated str
     """
     input_node = chunk_input[0]
-    out_shape = get_node_shape(chunk_output)
-    out_str = str(list(out_shape))
-    context = (
-        "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" %
-        (out_str, input_node.name, input_node.name, chunk_size))
-    context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
+    context = ""
+    for i in range(len(chunk_output)):
+        shape_str = str(list(get_node_shape(chunk_output[i])))
+        if get_node_name(chunk_output[i]) == "split":
+            tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
+                                                                                  input_node.name)
+            tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
+            tensor_str = "[" + tensor_str[:-2] + "]"
+            context += "%s = %s; " % (chunk_output[i].name, tensor_str)
+        else:
+            context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
+                                                                                     input_node.name, input_node.name)
+
+    out_shape = get_node_shape(chunk_output[0])
+    chunk_shape = out_shape[chunk_ouput_dim[0]]
+    context += "chunk_size = %d\nfor chunk_idx in range(0, %d, chunk_size):\n" % (chunk_size, chunk_shape)
     return context


-def _gen_loop_end(
-    chunk_inputs: List[Node],
-    chunk_non_compute_inputs: List[Node],
-    chunk_outputs: Node,
-    chunk_outputs_dim: int,
-    node_list: List[Node],
-) -> str:
+def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
+                  chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str:
     """
     Generate chunk loop end
@@ -102,22 +108,13 @@ def _gen_loop_end(
     Returns:
         context (str): generated str
     """
-    chunk_outputs_name = chunk_outputs.name
-    chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
-    chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
-    chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
-    context = " chunk_result%s = %s; %s = None\n" % (
-        chunk_slice,
-        chunk_outputs_name,
-        chunk_outputs_name,
-    )
-    context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")
+    context = "chunk_size = None"

     # determine if its the last use for chunk input
     for chunk_input in chunk_inputs + chunk_non_compute_inputs:
-        if all([find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
+        if all([search_chunk.node_mgr.find_node_idx(user) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
            context += "; %s = None" % chunk_input.name
+    for chunk_output_non_tensor, chunk_output_non_tensor_val in chunk_outputs_non_tensor.items():
+        context += "; %s = %s" % (chunk_output_non_tensor.name, chunk_output_non_tensor_val)
     context += "\n"
     return context
@@ -158,7 +155,7 @@ def _replace_ones_like(
     add chunk slice for new tensor op such as ones like
     """
     if "ones_like" in node.name:
-        meta_node = search_chunk.trace_indice.node_list[node_idx]
+        meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
         chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
         if get_node_shape(meta_node)[chunk_dim] != 1:
             source_node = meta_node.args[0].args[0]
@@ -169,21 +166,37 @@ def _replace_ones_like(
     return body


-def _replace_input_node(
-    chunk_inputs: List[Node],
+def _add_node_slice(
+    chunk_nodes: List[Node],
     region_idx: int,
-    chunk_inputs_dim: Dict,
+    chunk_nodes_dim: Dict,
     node_idx: int,
     body: List[str],
+    node: Node,
 ) -> List[str]:
     """
     add chunk slice for input nodes
     """
-    for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
-        for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
-            if idx == node_idx:
-                chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(input_node))
-                body[-1] = _replace_name(body[-1], input_node.name, input_node.name + chunk_slice)
+    for chunk_node_idx, chunk_node in enumerate(chunk_nodes[region_idx]):
+        # inputs node
+        if isinstance(chunk_nodes_dim[region_idx][chunk_node_idx], dict):
+            for idx, dim in chunk_nodes_dim[region_idx][chunk_node_idx].items():
+                if idx == node_idx:
+                    chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(chunk_node))
+                    body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
+        # outputs node
+        else:
+            if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
+                chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
+                                                   get_node_shape(chunk_node))
+                if get_node_name(chunk_node) == "split":
+                    split_chunk_slice = ""
+                    for i in range(len(chunk_node.meta['tensor_meta'])):
+                        split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
+                    split_chunk_slice = split_chunk_slice[:-2]
+                    body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
+                else:
+                    body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
     return body
@@ -222,7 +235,8 @@ def emit_code_with_chunk(
     chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]

     # chunk outputs
-    chunk_outputs = [i["outputs"][0] for i in chunk_infos]
+    chunk_outputs = [i["outputs"] for i in chunk_infos]
+    chunk_outputs_non_tensor = [i["outputs_non_tensor"] for i in chunk_infos]
     chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]

     node_list = search_chunk.reorder_graph.reorder_node_list(node_list)
@@ -248,7 +262,9 @@ def emit_code_with_chunk(
         if within_chunk_region:
             emit_node_func(node, body)
             # replace input var with chunk var
-            body = _replace_input_node(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body)
+            body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node)
+            # replace output var with chunk var
+            body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node)
             # ones like
             body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
             # reassgin reshape size
@@ -263,13 +279,8 @@ def emit_code_with_chunk(
         # generate chunk region end
         if node_idx in chunk_ends:
             body.append(
-                _gen_loop_end(
-                    chunk_inputs[region_idx],
-                    chunk_inputs_non_chunk[region_idx],
-                    chunk_outputs[region_idx],
-                    chunk_outputs_dim[region_idx],
-                    node_list,
-                ))
+                _gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
+                              chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
             within_chunk_region = False

         node_idx += 1
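From here on, every file replaces find_idx_by_name(node.name, node_list) scans with a shared NodeMgr owned by SearchChunk. Its implementation lives in .utils and is not part of this diff, so the following is only a minimal sketch consistent with the call sites above; every detail of the body is an assumption:

    from typing import List
    from torch.fx.node import Node

    class NodeMgr:
        # assumed implementation: one authoritative node list plus a name index
        def __init__(self, gm) -> None:
            self._node_list = list(gm.graph.nodes)
            self._reindex()

        def _reindex(self) -> None:
            self._node_idx = {n.name: i for i, n in enumerate(self._node_list)}

        def get_node_list(self) -> List[Node]:
            return self._node_list

        def get_node_by_idx(self, idx: int) -> Node:
            return self._node_list[idx]

        def get_node_slice_by_idx(self, start: int, end: int) -> List[Node]:
            return self._node_list[start:end]

        def find_node_idx(self, node: Node) -> int:
            return self._node_idx[node.name]

        def find_node_idx_by_name(self, name: str) -> int:
            return self._node_idx[name]

        def update_node_list(self, node_list: List[Node]) -> None:
            self._node_list = node_list
            self._reindex()

Centralizing the index also turns the repeated O(n) name scans into dictionary lookups, which is presumably part of the speedup the commit message claims. The next hunks thread it through the memory estimator (class EstimateMemory).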
@@ -6,7 +6,7 @@ from torch.fx.node import Node, map_arg

 from colossalai.fx.profiler import activation_size, parameter_size

-from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape, is_non_memory_node
+from .utils import NodeMgr, delete_free_var_from_last_use, get_node_shape, is_non_memory_node


 class EstimateMemory(object):
@@ -14,8 +14,8 @@ class EstimateMemory(object):
     Estimate memory with chunk
     """

-    def __init__(self) -> None:
-        pass
+    def __init__(self, node_mgr: NodeMgr) -> None:
+        self.node_mgr = node_mgr

     def _get_meta_node_size(self, x):
         x = x.meta["tensor_meta"]
@@ -78,7 +78,7 @@ class EstimateMemory(object):
         nodes_to_delete = []
         for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
             chunk_input_users = chunk_input.users.keys()
-            chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
+            chunk_input_users_idx = [self.node_mgr.find_node_idx(i) for i in chunk_input_users]
             if all(i <= chunk_end_idx for i in chunk_input_users_idx):
                 if chunk_input not in nodes_to_delete:
                     nodes_to_delete.append(chunk_input)
@@ -212,7 +212,7 @@ class EstimateMemory(object):
         chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
         chunk_inputs_names = [j.name for i in chunk_inputs for j in i
                              ] + [j.name for i in chunk_inputs_non_chunk for j in i]
-        chunk_outputs = [i["outputs"][0] for i in chunk_infos]
+        chunk_outputs = [i["outputs"] for i in chunk_infos]
         chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
         chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]

@@ -221,7 +221,7 @@ class EstimateMemory(object):
             if use_chunk and idx in chunk_starts:
                 chunk_within = True
                 chunk_region_idx = chunk_starts.index(idx)
-                act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)
+                act_memory += sum(self._get_output_node_size(i) for i in chunk_outputs[chunk_region_idx]) / (1024**2)

             # determine chunk ratio for current node
             if chunk_within:
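The act_memory change in the last hunk is where multi-output regions hit the estimator: every chunked output keeps a full-size buffer alive for the whole loop, so their sizes must be summed. A toy calculation with invented fp16 shapes:

    shapes = [(4, 1024, 64), (4, 1024, 128)]                  # two outputs of one chunk region
    total_bytes = sum(a * b * c * 2 for a, b, c in shapes)    # 2 bytes per fp16 element
    print(total_bytes / 1024**2)    # 1.5 MB, vs 0.5 MB if only the first output were counted

The graph reorderer (class ReorderGraph) is updated next.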
@@ -1,5 +1,5 @@
 from .trace_indice import TraceIndice
-from .utils import find_idx_by_name
+from .utils import NodeMgr


 class ReorderGraph(object):
@@ -7,31 +7,27 @@ class ReorderGraph(object):
     Reorder node list and indice trace list
     """

-    def __init__(self, trace_indice: TraceIndice) -> None:
+    def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
         self.trace_indice = trace_indice
-        self.all_reorder_map = {
-            i: i for i in range(len(self.trace_indice.indice_trace_list))
-        }
+        self.node_mgr = node_mgr
+        self.all_reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}

     def _get_reorder_map(self, chunk_info):
-        reorder_map = {i: i for i in range(len(self.trace_indice.node_list))}
+        reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}

         chunk_region_start = chunk_info["region"][0]
         chunk_region_end = chunk_info["region"][1]
         chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
-        chunk_prepose_nodes_idx = [
-            find_idx_by_name(i.name, self.trace_indice.node_list)
-            for i in chunk_prepose_nodes
-        ]
+        chunk_prepose_nodes_idx = [self.node_mgr.find_node_idx(i) for i in chunk_prepose_nodes]
         # put prepose nodes ahead
         for idx, n in enumerate(chunk_prepose_nodes):
             n_idx = chunk_prepose_nodes_idx[idx]
             reorder_map[n_idx] = chunk_region_start + idx
         # put other nodes after prepose nodes
-        for n in self.trace_indice.node_list[chunk_region_start : chunk_region_end + 1]:
+        for n in self.node_mgr.get_node_slice_by_idx(chunk_region_start, chunk_region_end + 1):
             if n in chunk_prepose_nodes:
                 continue
-            n_idx = find_idx_by_name(n.name, self.trace_indice.node_list)
+            n_idx = self.node_mgr.find_node_idx(n)
             pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
             reorder_map[n_idx] = n_idx + pos

@@ -44,7 +40,7 @@ class ReorderGraph(object):
             chunk_info["region"][1],
         )
         new_inputs_dim = []
-        for idx, input_dim in enumerate(chunk_info["inputs_dim"]):
+        for _, input_dim in enumerate(chunk_info["inputs_dim"]):
             new_input_dim = {}
             for k, v in input_dim.items():
                 new_input_dim[reorder_map[k]] = v
@@ -57,16 +53,14 @@ class ReorderGraph(object):
             self.all_reorder_map[origin_idx] = reorder_map[map_idx]

     def _reorder_self_node_list(self, reorder_map):
-        new_node_list = [None for _ in range(len(self.trace_indice.node_list))]
+        new_node_list = [None for _ in range(len(self.node_mgr.get_node_list()))]
         for old_idx, new_idx in reorder_map.items():
-            new_node_list[new_idx] = self.trace_indice.node_list[old_idx]
-        self.trace_indice.node_list = new_node_list
+            new_node_list[new_idx] = self.node_mgr.get_node_by_idx(old_idx)
+        self.node_mgr.update_node_list(new_node_list)

     def _reorder_idx_trace(self, reorder_map):
         # reorder list
-        new_idx_trace_list = [
-            None for _ in range(len(self.trace_indice.indice_trace_list))
-        ]
+        new_idx_trace_list = [None for _ in range(len(self.trace_indice.indice_trace_list))]
         for old_idx, new_idx in reorder_map.items():
             new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
         self.trace_indice.indice_trace_list = new_idx_trace_list
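For intuition about _get_reorder_map above: given a toy region (3, 6) whose single prepose node sits at original index 5, the prepose node is mapped to the region start and the nodes it jumps over shift down by one (all indices invented):

    reorder_map = {0: 0, 1: 1, 2: 2, 3: 4, 4: 5, 5: 3, 6: 6}

Nodes outside the region map to themselves; _reorder_self_node_list and _reorder_idx_trace then apply the same map to the node list (now via node_mgr.update_node_list) and to the indice traces, so both stay aligned. The search driver (class SearchChunk) follows.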
@@ -9,6 +9,7 @@ from .select_chunk import SelectChunk
 from .trace_flow import TraceFlow
 from .trace_indice import TraceIndice
 from .utils import (
+    NodeMgr,
     find_chunk_compute_input_and_output_nodes,
     get_logger,
     get_node_shape,
@@ -49,15 +50,17 @@ class SearchChunk(object):
     def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
         self.print_mem = print_mem
         self.print_progress = print_progress
-        self.trace_indice = TraceIndice(list(gm.graph.nodes))
-        self.estimate_memory = EstimateMemory()
+        self.node_mgr = NodeMgr(gm)
+        self.trace_indice = TraceIndice(self.node_mgr)
+        self.estimate_memory = EstimateMemory(self.node_mgr)
         self._init_trace()
-        self.trace_flow = TraceFlow(self.trace_indice)
-        self.reorder_graph = ReorderGraph(self.trace_indice)
+        self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr)
+        self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr)
         self.select_chunk = SelectChunk(
             self.trace_indice,
             self.estimate_memory,
             self.reorder_graph,
+            self.node_mgr,
             max_memory=max_memory,
         )

@@ -67,7 +70,7 @@ class SearchChunk(object):
         reduce the computation complexity of trace_indice
         """
         # find all max ranges
-        active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list)
+        active_nodes = self.estimate_memory.get_active_nodes(self.node_mgr.get_node_list())
         cur_node_idx = len(self._get_free_var_idx())
         max_chunk_region_list = []
         while True:
@@ -100,7 +103,7 @@ class SearchChunk(object):
             free_var_idx (List): all indexs of free vars
         """
         free_var_idx = []
-        for idx, n in enumerate(self.trace_indice.node_list):
+        for idx, n in enumerate(self.node_mgr.get_node_list()):
             if n.op == "placeholder" and get_node_shape(n) is not None:
                 free_var_idx.append(idx)
         return free_var_idx
@@ -164,6 +167,44 @@ class SearchChunk(object):
             chunk_region_end = region[0] - 1
         return chunk_region_start, chunk_region_end

+    def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
+        """
+        Find chunk info for a region.
+
+        We are given the region start and region end, and need to find out all chunk info for it.
+        We first loop every dim of start node and end node, to see if we can find dim pair,
+        which is linked in a flow and not computed.
+        If found, we then search flow in the whole region to find out all chunk infos.
+
+        Args:
+            input_trace (List): node's input trace in region
+            output_trace (List): node's output trace in region
+            start_idx (int): region start node index
+            end_idx (int): region end node index
+
+        Returns:
+            chunk_infos: possible regions found
+        """
+        start_traces = input_trace[start_idx]
+        if len(start_traces) > 1:    # TODO need to be removed
+            return []
+        end_trace = output_trace[end_idx]
+        end_node = self.node_mgr.get_node_by_idx(end_idx)
+
+        chunk_infos = []
+        for end_dim, _ in enumerate(end_trace["indice"]):
+            for start_node, start_trace in start_traces.items():
+                for start_dim, _ in enumerate(start_trace["indice"]):
+                    if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
+                                                                  end_idx):
+                        continue
+                    # flow search
+                    chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
+                    if chunk_info is None:
+                        continue
+                    chunk_infos.append(chunk_info)
+        return chunk_infos
+
     def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
         """
         Search every possible region within the max chunk region.
@@ -178,7 +219,7 @@ class SearchChunk(object):
         possible_chunk_region = []
         output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
         input_trace = []    # trace of a node's input nodes
-        for _, n in enumerate(self.trace_indice.node_list):
+        for _, n in enumerate(self.node_mgr.get_node_list()):
             cur_trace = {}
             for arg in n.args:
                 if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
@@ -188,11 +229,11 @@ class SearchChunk(object):
         for start_idx in range(max_chunk_region[0], peak_node + 1):
             for end_idx in range(peak_node, max_chunk_region[1] + 1):
                 # skip non compute nodes
-                if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
-                        self.trace_indice.node_list[end_idx]):
+                if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
+                        self.node_mgr.get_node_by_idx(end_idx)):
                     continue
                 # select free dim
-                chunk_info = self.trace_flow.find_chunk_info(input_trace, output_trace, start_idx, end_idx)
+                chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
                 if len(chunk_info) > 0:
                     possible_chunk_region.extend(chunk_info)
         return possible_chunk_region
@@ -254,7 +295,7 @@ class SearchChunk(object):
             init_mem_peak,
             _,
             active_node,
-        ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
+        ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
         mem_peak = init_mem_peak

         while True:
@@ -267,7 +308,7 @@ class SearchChunk(object):
                 mem_peak,
                 _,
                 active_node,
-            ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
+            ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), chunk_infos)

             if self.print_progress:
                 get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
@@ -277,5 +318,7 @@ class SearchChunk(object):
                 break
         if self.print_mem:
             self.print_mem = False
-            self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
+            self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
+                                                              chunk_infos,
+                                                              print_mem=True)
         return chunk_infos
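Stripped to its skeleton, the search driver at the end of this file does the following; this is a paraphrase with hypothetical helper names, not code from the repository:

    def search_region_sketch(estimate, pick_region, node_list):
        # keep chunking the current memory peak until no legal region remains
        chunk_infos = []
        mem_peak, _, active_node = estimate(node_list, [])    # estimate_chunk_inference_mem
        while True:
            region = pick_region(mem_peak, active_node, chunk_infos)    # select_chunk and friends
            if region is None:
                break
            chunk_infos.append(region)
            mem_peak, _, active_node = estimate(node_list, chunk_infos)
        return chunk_infos

The region selector (class SelectChunk) gets the same NodeMgr plumbing next.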
@@ -1,7 +1,7 @@
 from .estimate_memory import EstimateMemory
 from .reorder_graph import ReorderGraph
 from .trace_indice import TraceIndice
-from .utils import is_non_compute_node
+from .utils import NodeMgr, is_non_compute_node


 class SelectChunk(object):
@@ -11,11 +11,13 @@ class SelectChunk(object):
         trace_indice: TraceIndice,
         estimate_memory: EstimateMemory,
         reorder_graph: ReorderGraph,
+        node_mgr: NodeMgr,
         max_memory=None,
     ):
         self.trace_indice = trace_indice
         self.estimate_memory = estimate_memory
         self.reorder_graph = reorder_graph
+        self.node_mgr = node_mgr
         if max_memory is not None:
             self.stratge = "fit_memory"
             self.max_memory = max_memory    # MB
@@ -68,7 +70,7 @@ class SelectChunk(object):
         regions_dict = []
         for region in possible_chunk_regions:
             cur_region = region.copy()
-            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
+            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
             cur_chunk_infos = chunk_infos + [cur_region]
             cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
             cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
@@ -134,7 +136,7 @@ class SelectChunk(object):

     def _get_compute_node_num(self, start, end):
         count = 0
-        for i in self.trace_indice.node_list[start:end + 1]:
+        for i in self.node_mgr.get_node_slice_by_idx(start, end + 1):
             if not is_non_compute_node(i):
                 count += 1
         return count
@@ -161,7 +163,7 @@ class SelectChunk(object):
         regions_dict_list = []
         for region in possible_chunk_regions:
             cur_region = region.copy()
-            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
+            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
             cur_chunk_infos = chunk_infos + [cur_region]
             cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
             cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
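The selection above judges each candidate by the estimated memory curve restricted to the candidate's own region. A worked toy example (values invented; taking the max of the slice as the score is an assumption, since the ranking logic falls outside these hunks):

    cur_mem_peak = [10, 12, 30, 28, 30, 11]          # estimated MB per node index
    max_possible_chunk_region = (2, 4)
    cur_chunk_region_peak = cur_mem_peak[2:4 + 1]    # [30, 28, 30]
    score = max(cur_chunk_region_peak)               # 30 MB for this candidate

The largest rewrite, the flow tracer (class TraceFlow), comes next.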
@@ -4,9 +4,10 @@ from torch.fx.node import Node

 from .trace_indice import TraceIndice
 from .utils import (
+    NodeMgr,
     find_chunk_all_input_nodes,
     find_chunk_compute_input_and_output_nodes,
-    find_idx_by_name,
+    find_tensor_shape_node,
     flat_list,
     get_node_name,
     get_node_shape,
@@ -16,8 +17,9 @@ from .utils import (

 class TraceFlow(object):

-    def __init__(self, trace_indice: TraceIndice) -> None:
+    def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
         self.trace_indice = trace_indice
+        self.node_mgr = node_mgr

     def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
         """
@@ -31,7 +33,8 @@ class TraceFlow(object):
         Returns:
             bool: True if check pass
         """
-        start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
+        # we use start_node_idx instead of real chunk index
+        start_node_idx = self.node_mgr.find_node_idx(start_node)
         end_node_trace = self.trace_indice._find_trace_from_node(end_node)
         end_node_trace_source = end_node_trace["source"][end_dim]
         sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
@@ -39,7 +42,7 @@ class TraceFlow(object):
             if node_idx == start_node_idx and start_dim in node_dim:
                 return True
             # it means we meet a node outside the loop, and the node is not input node
-            if node_idx < start_idx:
+            if node_idx < start_node_idx:
                 return False
         return False

@@ -61,29 +64,12 @@ class TraceFlow(object):
                 return False
         return True

-    def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
-        node_from_source = self.trace_indice._find_source_trace_from_node(node_from)
-        dim_source = node_from_source[node_from_dim]
-        node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list)
-        for k, v in dim_source.items():
-            if k == node_to_idx:
-                return v
-        return None
-
-    def _find_inherit_dim(self, input_node, input_dim, node):
-        input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
-        node_trace_source = self.trace_indice._find_source_trace_from_node(node)
-        for node_dim in range(len(get_node_shape(node))):
-            if (input_node_idx in node_trace_source[node_dim]
-                    and input_dim[0] in node_trace_source[node_dim][input_node_idx]):
-                return node_dim
-        return None
-
     def _assgin_single_node_flow(
         self,
         arg_node: Node,
         start_idx: int,
         end_idx: int,
+        cur_node: Node,
         cur_node_dim: int,
         cur_node_compute: Dict,
         cur_node_source: Dict,
@@ -109,7 +95,7 @@ class TraceFlow(object):
         Returns:
             bool: True if this node can be added to the flow, vice versa.
         """
-        arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
+        arg_idx = self.node_mgr.find_node_idx(arg_node)
         # arg in chunk range or be inputs
         if not (start_idx <= arg_idx < end_idx):
             return True
@@ -126,6 +112,11 @@ class TraceFlow(object):
             # chunk dim should be None if shape size is 1
             if get_node_shape(arg_node)[arg_dim] == 1:
                 arg_dim = None
+            # chunk shape should equal cur node
+            elif get_node_shape(arg_node)[arg_dim] != 1:
+                if cur_node_dim is not None and get_node_shape(cur_node)[cur_node_dim] != 1:
+                    if get_node_shape(arg_node)[arg_dim] != get_node_shape(cur_node)[cur_node_dim]:
+                        return False
         else:
             arg_dim = None

@@ -150,7 +141,7 @@ class TraceFlow(object):
         return True

     def _get_all_node_info(self, end_dim, start_idx, end_idx):
-        cur_node_list = [self.trace_indice.node_list[end_idx]]    # start from the last node
+        cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)]    # start from the last node
         all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}

         while len(cur_node_list) > 0:
@@ -178,6 +169,7 @@ class TraceFlow(object):
                     arg,
                     start_idx,
                     end_idx,
+                    cur_node,
                     cur_node_chunk_dim,
                     cur_node_compute,
                     cur_node_source,
@@ -194,7 +186,7 @@ class TraceFlow(object):
         for arg in arg_list:
             if get_node_shape(arg) is None:
                 continue
-            if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
+            if not (start_idx <= self.node_mgr.find_node_idx(arg) < end_idx):
                 continue
             arg_chunk_dim = all_node_info[arg]["chunk_dim"]
             arg_fix_dim = all_node_info[arg]["fix_dim"]
@@ -232,7 +224,7 @@ class TraceFlow(object):
         remove_inputs = []
         for input_node in inputs:
             input_dict = {}
-            input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
+            input_node_idx = self.node_mgr.find_node_idx(input_node)
             for user in input_node.users.keys():
                 # skip non compute
                 if is_non_compute_node(user):
@@ -240,7 +232,7 @@ class TraceFlow(object):
                 # untraced node, mostly non compute
                 if user not in all_node_info:
                     continue
-                user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
+                user_idx = self.node_mgr.find_node_idx(user)
                 if start_idx <= user_idx <= end_idx:
                     chunk_dim = all_node_info[user]["chunk_dim"]
                     if chunk_dim is not None:
@@ -262,7 +254,7 @@ class TraceFlow(object):
                 inputs.remove(i)
         return inputs, inputs_dim

-    def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
+    def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int, chunk_info) -> List[Node]:
         """
         get all useless nodes in chunk region and prepose them

@@ -279,8 +271,11 @@ class TraceFlow(object):
         for node, node_info in all_node_info.items():
             if node_info["chunk_dim"] is None:
                 maybe_prepose_nodes.append(node)
+        for node in self.node_mgr.get_node_slice_by_idx(start_idx, end_idx):
+            if node not in all_node_info and node not in chunk_info["outputs"]:
+                maybe_prepose_nodes.append(node)
         maybe_prepose_nodes.sort(
-            key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
+            key=lambda x: self.node_mgr.find_node_idx(x),
             reverse=True,
         )    # from last node to first node
         prepose_nodes = []
@@ -303,8 +298,7 @@ class TraceFlow(object):
                     if type(cur_prepose_node_arg) != type(cur_prepose_node):
                         continue
                     # out of loop
-                    if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) <
-                            end_idx):
+                    if not (start_idx <= self.node_mgr.find_node_idx(cur_prepose_node_arg) < end_idx):
                         continue
                     # compute op in loop
                     elif cur_prepose_node_arg in all_node_info:
@@ -328,13 +322,12 @@ class TraceFlow(object):
             if n in maybe_prepose_nodes:
                 maybe_prepose_nodes.remove(n)
         # sort by index
-        prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list))
-        return prepose_nodes
+        prepose_nodes.sort(key=lambda x: self.node_mgr.find_node_idx(x))
+        chunk_info["args"]["prepose_nodes"] = prepose_nodes

     def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
         # we need to log input nodes to avoid deleteing them in the loop
-        chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1]
+        chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
         # also need to get some prepose node's arg out of non_chunk_inputs
         for n in chunk_info["args"]["prepose_nodes"]:
             chunk_node_list.remove(n)
@@ -345,34 +338,41 @@ class TraceFlow(object):
         return chunk_info

     def flow_search(self, start_idx, start_dim, end_idx, end_dim):
-        inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1])
-        # only single ouput
-        if len(outputs) > 1:
-            return None
+        inputs, outputs = find_chunk_compute_input_and_output_nodes(
+            self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))

         # get every node's chunk dim and fix dim
         all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
         if all_node_info is None:
             return None

-        # get input nodes' chunk dim
-        inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
-        if inputs is None:
-            return None
-
         chunk_info = {
             "region": (start_idx, end_idx),
-            "inputs": inputs,
+            "inputs": [],
             "inputs_non_chunk": [],
-            "inputs_dim": inputs_dim,
-            "outputs": outputs,
-            "outputs_dim": end_dim,
+            "inputs_dim": [],
+            "outputs": [self.node_mgr.get_node_by_idx(end_idx)],
+            "outputs_non_tensor": {},
+            "outputs_dim": [end_dim],
             "node_chunk_dim": all_node_info,
             "args": {},
         }

+        # find chunk info for other outputs
+        if len(find_tensor_shape_node(outputs)) > 1:
+            chunk_info = self._get_other_output_info(outputs, start_idx, start_dim, end_idx, end_dim, chunk_info)
+            if chunk_info is None:
+                return None
+
+        # get input nodes' chunk dim
+        inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
+        if inputs is None:
+            return None
+        chunk_info["inputs"] = inputs
+        chunk_info["inputs_dim"] = inputs_dim
+
         # move useless nodes ahead of loop
-        chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx)
+        self._get_prepose_nodes(all_node_info, start_idx, end_idx, chunk_info)

         # find non chunk inputs
         chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
@@ -382,6 +382,63 @@ class TraceFlow(object):

         return chunk_info

+    def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
+                               chunk_info: Dict):
+        start_node = self.node_mgr.get_node_by_idx(start_idx)
+        # loop all outputs
+        for output in outputs:
+            output_legal = False
+            output_idx = self.node_mgr.find_node_idx(output)
+            # skip the origin output
+            if output_idx == end_idx:
+                continue
+            # skip non tensor
+            if get_node_shape(output) is None:
+                # log shape tensor
+                if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
+                    chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
+                continue
+            # loop every dim of outputs, try to find a legal one
+            for output_dim in range(len(get_node_shape(output))):
+                if not self.check_region_start_end(start_node, start_dim, start_idx, output, output_dim, output_idx):
+                    continue
+                new_all_node_info = self._get_all_node_info(output_dim, start_idx, output_idx)
+                if new_all_node_info is None:
+                    continue
+                # check node info legal
+                if self._update_chunk_info(chunk_info, new_all_node_info, output, output_dim) == True:
+                    output_legal = True
+                    break
+            # not legal
+            if output_legal == False:
+                return None
+        return chunk_info
+
+    def _update_chunk_info(self, chunk_info: Dict, new_all_node_info: Dict, output: Node, output_dim: int) -> bool:
+        """
+        check if there is conflict between new node info and old chunk info. If not, update old chunk info
+        """
+        # check if conflict
+        overlap_flag = False
+        for k, v in new_all_node_info.items():
+            if k in chunk_info["node_chunk_dim"]:
+                overlap_flag = True
+                if chunk_info["node_chunk_dim"][k]["chunk_dim"] != v["chunk_dim"]:
+                    return False
+        # if no overlap, we just consider them as prepose nodes, instead of new output
+        if overlap_flag == False:
+            return True
+        # update chunk info
+        for k, v in new_all_node_info.items():
+            if k in chunk_info["node_chunk_dim"]:
+                chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
+                    set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
+            else:
+                chunk_info["node_chunk_dim"][k] = v
+        chunk_info["outputs"].append(output)
+        chunk_info["outputs_dim"].append(output_dim)
+        return True
+
     def _reassgin_reshape_size(self, chunk_info):
         """
         Some shape args in reshape may have changed due to chunk
@@ -389,10 +446,17 @@ class TraceFlow(object):
         """
         chunk_region = chunk_info["region"]
         reshape_size = {}
-        chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
-        for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
+        chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"][0]]
+        for node in self.node_mgr.get_node_slice_by_idx(chunk_region[0], chunk_region[1] + 1):
             if any(i == get_node_name(node) for i in ["reshape", "view"]):
+                if node in chunk_info["args"]["prepose_nodes"]:
+                    continue
+                if node.args[0] in chunk_info["inputs_non_chunk"]:
+                    continue
                 reshape_args = flat_list(node.args[1:])
+                if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
+                        reshape_args[0].meta['fwd_out']) > 1:
+                    continue
                 chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
                 new_shape = ""
                 for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
@@ -409,44 +473,7 @@ class TraceFlow(object):
         chunk_info["reshape_size"] = reshape_size
         return chunk_info

-    def find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
-        """
-        Find chunk info for a region.
-
-        We are given the region start and region end, and need to find out all chunk info for it.
-        We first loop every dim of start node and end node, to see if we can find dim pair,
-        which is linked in a flow and not computed.
-        If found, we then search flow in the whole region to find out all chunk infos.
-
-        Args:
-            input_trace (List): node's input trace in region
-            output_trace (List): node's output trace in region
-            start_idx (int): region start node index
-            end_idx (int): region end node index
-
-        Returns:
-            chunk_infos: possible regions found
-        """
-        start_traces = input_trace[start_idx]
-        if len(start_traces) > 1:    # TODO need to be removed
-            return []
-        end_trace = output_trace[end_idx]
-        end_node = self.trace_indice.node_list[end_idx]
-
-        chunk_infos = []
-        for end_dim, _ in enumerate(end_trace["indice"]):
-            for start_node, start_trace in start_traces.items():
-                for start_dim, _ in enumerate(start_trace["indice"]):
-                    if not self._check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, end_idx):
-                        continue
-                    # flow search
-                    chunk_info = self.flow_search(start_idx, start_dim, end_idx, end_dim)
-                    if chunk_info is None:
-                        continue
-                    chunk_infos.append(chunk_info)
-        return chunk_infos
-
-    def _check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
-                                end_idx: int) -> bool:
+    def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
+                               end_idx: int) -> bool:
         """
         check if region start and end is legal
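After this rewrite, flow_search builds and returns a record with the shape below. The keys are taken from the hunks above; the example values are invented:

    chunk_info = {
        "region": (12, 40),          # node-index span of the chunk loop
        "inputs": [],                # chunked input nodes, filled in by _get_input_nodes_dim
        "inputs_non_chunk": [],      # inputs used whole, never sliced
        "inputs_dim": [],            # per-input {user_idx: chunk dims}
        "outputs": [],               # now a list: the end node plus any extra legal outputs
        "outputs_non_tensor": {},    # shape-like (int) outputs, logged as strings
        "outputs_dim": [],           # one chunk dim per entry in "outputs"
        "node_chunk_dim": {},        # per-node {"chunk_dim": ..., "fix_dim": [...]}
        "args": {},                  # e.g. "prepose_nodes" added by _get_prepose_nodes
    }

The last file updates the dimension tracer (class TraceIndice).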
@ -3,14 +3,7 @@ from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
|
|
||||||
from .utils import (
|
from .utils import NodeMgr, find_first_tensor_arg, flat_list, get_module_node_name, get_node_name, get_node_shape
|
||||||
find_first_tensor_arg,
|
|
||||||
find_idx_by_name,
|
|
||||||
flat_list,
|
|
||||||
get_module_node_name,
|
|
||||||
get_node_name,
|
|
||||||
get_node_shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TraceIndice(object):
|
class TraceIndice(object):
|
||||||
|
@ -35,8 +28,8 @@ class TraceIndice(object):
|
||||||
node_list (List)
|
node_list (List)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_list: List[Node]) -> None:
|
def __init__(self, node_mgr: NodeMgr) -> None:
|
||||||
self.node_list = node_list
|
self.node_mgr = node_mgr
|
||||||
self.indice_trace_list = self._init_indice_trace_list()
|
         self.indice_trace_list = self._init_indice_trace_list()
         self.indice_view_list = {}
         self.indice_count = -1

@@ -45,7 +38,7 @@ class TraceIndice(object):
     def _init_indice_trace_list(self) -> List:
         indice_trace_list = []
-        for n in self.node_list:
+        for n in self.node_mgr.get_node_list():
             if get_node_shape(n) != None:
                 cur_trace = {
                     "indice": [None for _ in range(len(get_node_shape(n)))],

@@ -99,7 +92,7 @@ class TraceIndice(object):
         node_from_trace_source = self._find_source_trace_from_node(node_from)
         node_to_dim = self._transform_indice(node_to, node_to_dim)
         node_to_trace_source = self._find_source_trace_from_node(node_to)
-        node_from_idx = find_idx_by_name(node_from.name, self.node_list)
+        node_from_idx = self.node_mgr.find_node_idx(node_from)
         if init:
             node_to_trace_source[node_to_dim] = {}
         # add dim to cur new source

@@ -200,7 +193,7 @@ class TraceIndice(object):
             idx (list): idx of the node
             compute (list): computed idx of the node.
         """
-        node_idx = find_idx_by_name(node.name, self.node_list)
+        node_idx = self.node_mgr.find_node_idx(node)
         node_dict = self.indice_trace_list[node_idx]
         return node_dict

@@ -214,7 +207,7 @@ class TraceIndice(object):
             idx (list): idx of the node
             compute (list): computed idx of the node.
         """
-        node_idx = find_idx_by_name(node.name, self.node_list)
+        node_idx = self.node_mgr.find_node_idx(node)
         node_dict = self.indice_trace_list[node_idx]
         return node_dict["source"]

@@ -227,7 +220,7 @@ class TraceIndice(object):
         Returns:
             idx (list): idx of the node
         """
-        node_idx = find_idx_by_name(node.name, self.node_list)
+        node_idx = self.node_mgr.find_node_idx(node)
         return self.indice_trace_list[node_idx]["indice"]

     def _find_compute_trace_from_node(self, node: Node) -> List:

@@ -239,7 +232,7 @@ class TraceIndice(object):
         Returns:
             compute (list): computed idx of the node.
         """
-        node_idx = find_idx_by_name(node.name, self.node_list)
+        node_idx = self.node_mgr.find_node_idx(node)
         return self.indice_trace_list[node_idx]["compute"]

     def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None:

@@ -454,8 +447,6 @@ class TraceIndice(object):
             node (node)
             node_idx (int)
         """
-        for _ in range(len(get_node_shape(node.args[0]))):
-            self._add_dim(node_idx, 0)
         self._assign_indice_as_input(node, node_idx)
         dim_idx = node.kwargs["dim"]
         self._del_dim(node_idx, dim_idx)
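Every lookup replaced above used find_idx_by_name, a linear scan over the node list; node_mgr answers the same query from a name-to-index dict built once (the NodeMgr class added to the utils further down in this commit). A minimal sketch of the difference, with illustrative names:

# Before: each lookup walks the node list, O(n) per call.
def find_idx_by_name(name, node_list):
    for idx, node in enumerate(node_list):
        if node.name == name:
            return idx

# After: one pass builds the index, then each lookup is an O(1) dict hit.
node_dict = {node.name: idx for idx, node in enumerate(node_list)}

def find_node_idx(name):
    return node_dict[name]

TraceIndice queries an index for nearly every node it touches, so replacing the per-query scan removes a quadratic factor from tracing.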
@@ -702,21 +693,20 @@ class TraceIndice(object):
             if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
                     and view_dict["dim_from"] == dim_to):
                 # inheirt indice from current node
-                for dim_to_i in dim_to:
-                    for dim_from_i in dim_from:
-                        self._inherit_indice(origin_node, dim_from_i, node, dim_to_i, init=False)
+                if len_diff == 1:
+                    if origin_shape[dim_from[0]] == 1:
+                        self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
+                    elif origin_shape[dim_from[1]] == 1:
+                        self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
+                elif len_diff == -1:
+                    if target_shape[dim_to[0]] == 1:
+                        self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
+                    elif target_shape[dim_to[1]] == 1:
+                        self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
                 # inherid indice from input node of last view
                 for dim_to_i in dim_to:
                     self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)

-        # inherit computation
-        compute_log = self._find_compute_trace_from_node(origin_node)
-        for i in dim_from:
-            if origin_trace[i] in compute_log:
-                for j in dim_to:
-                    self._mark_computation(node, node_idx, [j])
-                break
-
         # log view, not used now
         view_dict = {
             "idx_from": [origin_trace[i] for i in dim_from],
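The old nested loop copied every source index onto every target index once a matching view was found. The new branches only handle views that drop or insert a single size-1 axis (len_diff of 1 or -1, presumably len(origin_shape) - len(target_shape)) and pick the one meaningful mapping. A worked sketch of the two cases, with illustrative shapes rather than repo code:

import torch

# len_diff == 1: the view drops one size-1 axis.
x = torch.randn(4, 1, 8)     # origin_shape = [4, 1, 8]
y = x.view(4, 8)             # target_shape = [4, 8]
# dim_from = [1, 2] collapse into dim_to = [1]; origin_shape[1] == 1, so only
# dim 2's index survives: _inherit_indice(origin, dim_from[1], node, dim_to[0]).

# len_diff == -1: the view inserts one size-1 axis.
z = y.view(4, 1, 8)          # dim_from = [1] expands into dim_to = [1, 2]
# target_shape[1] == 1, so the index flows to the non-singleton axis:
# _inherit_indice(origin, dim_from[0], node, dim_to[1]).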
@@ -742,7 +732,7 @@ class TraceIndice(object):

         active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
         active_nodes = set(flat_list(active_nodes))
-        active_nodes = [find_idx_by_name(i, self.node_list) for i in active_nodes]
+        active_nodes = [self.node_mgr.find_node_idx_by_name(i) for i in active_nodes]
         for i in range(trace_range[0], trace_range[1] + 1):
             trace = self.indice_trace_list[i]
             # clear compute

@@ -758,7 +748,7 @@ class TraceIndice(object):
                         dim_source.pop(k)

     def trace_indice(self) -> None:
-        for idx, node in enumerate(self.node_list):
+        for idx, node in enumerate(self.node_mgr.get_node_list()):
             node_name = get_node_name(node)
             if node.op == "placeholder":
                 self._assign_all_indice(node, idx)
@@ -9,6 +9,59 @@ NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "siz
 logger = get_dist_logger()


+class NodeMgr(object):
+
+    def __init__(self, gm) -> None:
+        self._node_list = list(gm.graph.nodes)
+        self._node_dict = {}
+        self._set_node_dict()
+
+    def _set_node_dict(self) -> None:
+        """
+        create a dict {node_name: node_idx}
+        """
+        self._node_dict.clear()
+        for idx, node in enumerate(self._node_list):
+            self._node_dict[node.name] = idx
+
+    def find_node_idx(self, node: Node) -> int:
+        """
+        find node's index
+        """
+        return self._node_dict[node.name]
+
+    def find_node_idx_by_name(self, node_name: str) -> int:
+        """
+        find node's index
+        """
+        return self._node_dict[node_name]
+
+    def get_node_by_idx(self, idx: int) -> Node:
+        """
+        get a node by index
+        """
+        return self._node_list[idx]
+
+    def get_node_slice_by_idx(self, start: int, end: int) -> List[Node]:
+        """
+        get a slice of node by index
+        """
+        return self._node_list[start:end]
+
+    def get_node_list(self) -> List:
+        """
+        get full node list
+        """
+        return self._node_list
+
+    def update_node_list(self, node_list: List) -> None:
+        """
+        update node list, reset node dict
+        """
+        self._node_list = node_list
+        self._set_node_dict()
+
+
 def get_logger() -> Any:
     return logger
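A minimal usage sketch of the new NodeMgr against a plain fx-traced module; the toy module and the import location are assumptions for illustration:

import torch
from torch.fx import symbolic_trace

from colossalai.autochunk.utils import NodeMgr    # assumed import path for this commit

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

gm = symbolic_trace(Toy())        # NodeMgr reads gm.graph.nodes
mgr = NodeMgr(gm)

relu = mgr.get_node_list()[1]     # nodes: x, relu, add, output
assert mgr.find_node_idx(relu) == 1
assert mgr.find_node_idx_by_name("relu") == 1
assert mgr.get_node_by_idx(1) is relu

update_node_list plus _set_node_dict keeps the dict consistent after the codegen rewrites the graph, which is why lookups can stay O(1) across passes.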
@@ -42,6 +95,8 @@ def is_non_compute_node(node: Node) -> bool:
     if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME):
         return True
     if "getitem" in node.name:
+        if get_node_shape(node) is not None:
+            return False
         node_args = flat_list(node.args[1:])
         for node_arg in node_args:
             if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
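The new early return keeps getitem nodes that carry a real tensor shape, such as the outputs pulled from a multi-output op, out of the non-compute bucket. A sketch of the traced pattern this guards, using plain torch.fx:

import torch
from torch.fx import symbolic_trace

class Toy(torch.nn.Module):
    def forward(self, x):
        chunks = torch.split(x, 2, dim=0)    # multi-output op
        return chunks[0] + chunks[1]         # indexing traces as getitem nodes

gm = symbolic_trace(Toy())
print([n.name for n in gm.graph.nodes])
# ['x', 'split', 'getitem', 'getitem_1', 'add', 'output']
# After shape propagation each getitem carries a tensor shape, so the new guard
# returns False for it: it is treated as a tensor-producing node, not index sugar.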
@@ -53,6 +108,8 @@ def is_non_compute_node(node: Node) -> bool:


 def get_node_shape(node: Node) -> List:
+    if get_node_name(node) == "split":
+        return node.meta["tensor_meta"][0].shape
     if hasattr(node.meta["tensor_meta"], "shape"):
         return node.meta["tensor_meta"].shape
     return None
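torch.fx's shape propagation stores a tuple of TensorMetadata for multi-output ops, so a split node has no single .shape; the new branch reports the first output's shape instead. A sketch of the metadata layout this relies on (toy module for illustration):

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.split(x, 2, dim=0)

gm = symbolic_trace(Toy())
ShapeProp(gm).propagate(torch.randn(4, 3))
split_node = [n for n in gm.graph.nodes if n.op == "call_function"][0]
meta = split_node.meta["tensor_meta"]     # a tuple of TensorMetadata, one per chunk
print(meta[0].shape)                      # torch.Size([2, 3]): what get_node_shape returns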
@@ -78,7 +135,7 @@ def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
     return is_non_compute_node_except_placeholder(node)


-def find_idx_by_name(name: str, nodes_list: List) -> int:
+def find_node_idx(name: str, nodes_list: List) -> int:
     for idx, node in enumerate(nodes_list):
         if node.name == name:
             return idx
@@ -162,3 +219,28 @@ def get_node_name(node: Node) -> str:
         else:
             break
     return node_name
+
+
+def find_tensor_node(node_list: List[Node]) -> List[Node]:
+    """
+    find tensor nodes from a node list
+    """
+    out = []
+    for node in node_list:
+        if get_node_shape(node) is not None:
+            out.append(node)
+    return out
+
+
+def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
+    """
+    find tensor and shape nodes from a node list
+    """
+    out = []
+    for node in node_list:
+        if get_node_shape(node) is not None:
+            out.append(node)
+        elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
+                node.meta['fwd_out'][0], int):
+            out.append(node)
+    return out
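The two filters differ only in the second branch: find_tensor_shape_node also keeps nodes whose recorded forward output (the fwd_out meta field read above) is a list of ints, i.e. shape-like results. A sketch with hypothetical stand-in objects, not real fx Nodes, showing only the meta fields the helpers read:

from types import SimpleNamespace

# tensor-producing node: tensor_meta carries a shape
tensor_node = SimpleNamespace(meta={"tensor_meta": SimpleNamespace(shape=(4, 3)), "fwd_out": []})
# shape-producing node: no tensor shape, but fwd_out is a list of ints
shape_node = SimpleNamespace(meta={"tensor_meta": None, "fwd_out": [4, 3]})

# find_tensor_node would keep only tensor_node; find_tensor_shape_node would
# keep both, because shape_node's fwd_out is a non-empty list of ints.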
@@ -23,6 +23,7 @@ def assert_codegen_run(
     concrete_args: List = None,
     max_memory: int = None,
     print_mem: bool = False,
+    print_est_mem: bool = False,
     print_progress: bool = False,
     print_code: bool = False,
 ) -> List[Dict]:

@@ -41,7 +42,7 @@ def assert_codegen_run(
     codegen = AutoChunkCodeGen(
         meta_graph,
         max_memory=max_memory,
-        print_mem=print_mem,
+        print_mem=print_est_mem,
         print_progress=print_progress,
     )
     chunks = codegen.chunk_infos

@@ -61,13 +62,20 @@ def assert_codegen_run(
     code = graph.python_code("self").src
     if print_code:
         print(code)
-    assert "chunk_result = None; chunk_size = None;" in code
+    assert "chunk_size = None; " in code

     # assert result
     inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
     model.cuda()
     with torch.no_grad():
-        out_gm = gm(*inputs)
+        if print_mem:
+            torch.cuda.reset_peak_memory_stats()
+            now_mem = torch.cuda.memory_allocated() / 1024**2
+        out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        if print_mem:
+            new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+            print("mem: %.2fMB" % (new_max_mem - now_mem))
         out_model = model(*inputs)
     out_gm = flat_list(out_gm)
     out_model = flat_list(out_model)
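The test now measures real peak CUDA memory around the chunked forward (print_mem), separately from the search-time estimate (print_est_mem). The measurement pattern in isolation, as a generic sketch assuming any callable fn and an available CUDA device:

import torch

def measure_peak_mem_mb(fn, *args):
    """Run fn(*args) and report the extra peak CUDA memory in MB, as the test does."""
    torch.cuda.reset_peak_memory_stats()
    base = torch.cuda.memory_allocated() / 1024**2
    out = fn(*args)
    peak = torch.cuda.max_memory_allocated() / 1024**2
    return out, peak - base

The hunk clones the inputs for the measured call, presumably so the original tensors stay untouched for the reference model run that follows.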
@@ -85,9 +93,10 @@ def run_test(
     max_memory: int,
     get_model: Any,
     get_data: Any,
-    print_code: bool,
-    print_mem: bool,
-    print_progress: bool,
+    print_code: bool = False,
+    print_mem: bool = False,
+    print_est_mem: bool = False,
+    print_progress: bool = False,
     get_chunk_target: Any = None,
 ) -> None:
     # launch colossalai

@@ -110,6 +119,7 @@ def run_test(
         max_memory=max_memory,
         print_code=print_code,
         print_mem=print_mem,
+        print_est_mem=print_est_mem,
         print_progress=print_progress,
     )
@@ -55,9 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:

 def get_chunk_target() -> Dict:
     return {
-        None: [(118, 123), (219, 237), (264, 289), (302, 309), (97, 104), (144, 152), (185, 193), (241, 242), (21, 46)],
-        20: [(118, 123), (230, 237), (275, 282), (305, 306), (100, 101), (32, 39), (73, 79)],
-        24: [(118, 123)],
+        None: [(120, 123), (222, 237), (269, 289), (305, 311), (100, 105), (146, 152), (187, 193), (241, 242),
+               (25, 50)],
+        20: [(120, 123), (232, 237), (277, 282), (305, 306), (100, 101), (34, 39)],
+        24: [(120, 123)],
     }

@@ -75,9 +76,6 @@ def test_evoformer_block(data_args, max_memory):
         get_model=get_model,
         get_data=get_data,
         get_chunk_target=get_chunk_target,
-        print_code=False,
-        print_mem=False,
-        print_progress=False,
     )
     mp.spawn(run_func, nprocs=1)

@@ -86,10 +84,12 @@ if __name__ == "__main__":
     run_test(
         rank=0,
         data_args=(32, 64),
-        max_memory=20,
+        max_memory=24,
         get_model=get_model,
         get_data=get_data,
+        get_chunk_target=get_chunk_target,
         print_code=False,
         print_mem=False,
+        print_est_mem=False,
         print_progress=False,
     )
@@ -70,9 +70,6 @@ def test_evoformer_stack(data_args, max_memory):
         max_memory=max_memory,
         get_model=get_model,
         get_data=get_data,
-        print_code=False,
-        print_mem=False,
-        print_progress=False,
     )
     mp.spawn(run_func, nprocs=1)

@@ -81,7 +78,7 @@ if __name__ == "__main__":
     run_test(
         rank=0,
         data_args=(32, 64),
-        max_memory=20,
+        max_memory=None,
         get_model=get_model,
         get_data=get_data,
         print_code=False,
@@ -55,10 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:

 def get_chunk_target() -> Dict:
     return {
-        None: [(126, 131), (227, 245), (272, 297), (310, 317), (105, 112), (152, 160), (193, 201), (249, 250),
-               (33, 46)],
-        20: [(126, 131), (238, 245), (283, 290), (313, 314), (108, 109), (35, 46)],
-        24: [(126, 131)],
+        None: [(128, 131), (230, 245), (277, 297), (313, 319), (108, 113), (154, 160), (195, 201), (249, 250),
+               (36, 46)],
+        20: [(128, 131), (240, 245), (285, 290), (313, 314), (108, 109), (41, 46)],
+        24: [(128, 131)],
     }

@@ -75,9 +75,7 @@ def test_extramsa_block(data_args, max_memory):
         max_memory=max_memory,
         get_model=get_model,
         get_data=get_data,
-        print_code=False,
-        print_mem=False,
-        print_progress=False,
+        get_chunk_target=get_chunk_target,
     )
     mp.spawn(run_func, nprocs=1)

@@ -86,7 +84,7 @@ if __name__ == "__main__":
     run_test(
         rank=0,
         data_args=(32, 64),
-        max_memory=20,
+        max_memory=None,
         get_model=get_model,
         get_data=get_data,
         get_chunk_target=get_chunk_target,
@@ -17,8 +17,8 @@ from test_transformer_utils import run_test

 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE

-BATCH_SIZE = 2
-SEQ_LENGTH = 256
+BATCH_SIZE = 1
+SEQ_LENGTH = 512


 def get_data(shape: tuple) -> Tuple[List, List]:

@@ -37,17 +37,14 @@ def get_data(shape: tuple) -> Tuple[List, List]:
 )
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
-@pytest.mark.parametrize("max_memory", [None, 4.5, 5])
-def test_gpt(model, shape, max_memory):
+@pytest.mark.parametrize("max_memory", [None, 6, 8])
+def test_autochunk_gpt(model, shape, max_memory):
     run_func = partial(
         run_test,
         data=get_data(shape),
         max_memory=max_memory,
         model=model,
         config=GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4),
-        print_code=False,
-        print_mem=False,
-        print_progress=False,
     )
     mp.spawn(run_func, nprocs=1)

@@ -59,7 +56,8 @@ if __name__ == "__main__":
         max_memory=None,
         model=GPT2Model,
         config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
-        print_code=True,
-        print_mem=True,
+        print_code=False,
+        print_est_mem=False,
+        print_mem=False,
         print_progress=False,
     )
@@ -20,6 +20,7 @@ def assert_codegen_run(
     model: Any,
     data: tuple,
     max_memory: int = None,
+    print_est_mem: bool = False,
     print_mem: bool = False,
     print_progress: bool = False,
     print_code: bool = False,

@@ -41,7 +42,7 @@ def assert_codegen_run(
     codegen = AutoChunkCodeGen(
         meta_graph,
         max_memory=max_memory,
-        print_mem=print_mem,
+        print_mem=print_est_mem,
         print_progress=print_progress,
     )
     chunks = codegen.chunk_infos

@@ -61,7 +62,7 @@ def assert_codegen_run(
     code = graph.python_code("self").src
     if print_code:
         print(code)
-    assert "chunk_result = None; chunk_size = None;" in code
+    assert "chunk_size = None; " in code

     # assert result
     inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
@@ -69,26 +70,44 @@ def assert_codegen_run(
     model.cuda().eval()
     gm.eval()
     with torch.no_grad():
-        out_gm = gm(*inputs)
+        if print_mem:
+            torch.cuda.reset_peak_memory_stats()
+            now_mem = torch.cuda.memory_allocated() / 1024**2
+        out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        if print_mem:
+            new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+            print("mem: %.2fMB" % (new_max_mem - now_mem))
         out_model = model(*inputs)
-    for k in out_model.keys():
-        if torch.is_tensor(out_gm[k]):
-            assert torch.equal(
-                out_model[k], out_gm[k]
-            ), f'{model.__class__.__name__} has incorrect output {k}, expect {out_model[k]}, but got {out_gm[k]}'
+    assert_allclose(out_model, out_gm)
     return chunks


+def assert_allclose(out_model: Any, out_gm: Any) -> None:
+    """
+    assert allclose for out
+    """
+    if isinstance(out_model, torch.Tensor):
+        assert torch.allclose(out_model, out_gm,
+                              atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                                  torch.abs(out_model - out_gm))
+    elif isinstance(out_model, dict):
+        for k in out_model.keys():
+            assert_allclose(out_model[k], out_gm[k])
+    elif isinstance(out_model, tuple) or isinstance(out_model, list) or isinstance(out_model, set):
+        for i, j in zip(out_model, out_gm):
+            assert_allclose(i, j)
+
+
 def run_test(
     rank: int,
     model: Any,
     config: Any,
     data: tuple,
     max_memory: int,
-    print_code: bool,
-    print_mem: bool,
-    print_progress: bool,
+    print_code: bool = False,
+    print_est_mem: bool = False,
+    print_mem: bool = False,
+    print_progress: bool = False,
     get_chunk_target: Any = None,
 ) -> None:
     model = model(config=config)
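assert_allclose replaces the old exact torch.equal check with a tolerance (atol=1e-4) and recurses through dicts, tuples, lists, and sets, which fits the nested outputs HuggingFace-style models return; chunked execution can reorder floating-point ops, so bitwise equality is too strict. A usage sketch with toy values, assuming assert_allclose from the hunk above is in scope:

import torch

ref = {"logits": torch.ones(2, 3), "hidden": [torch.zeros(2)]}
got = {"logits": torch.ones(2, 3) + 1e-5, "hidden": [torch.zeros(2)]}
assert_allclose(ref, got)    # passes: the 1e-5 error is within atol=1e-4, and the
                             # nested dict/list structure is walked recursively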
@@ -108,6 +127,7 @@ def run_test(
         data=data,
         max_memory=max_memory,
         print_code=print_code,
+        print_est_mem=print_est_mem,
         print_mem=print_mem,
         print_progress=print_progress,
     )

@@ -119,5 +139,3 @@ def run_test(
             str(chunk_found),
             str(chunk_target),
         )
-
-    gpc.destroy()