mirror of https://github.com/hpcaitech/ColossalAI
[autochunk] support multi outputs chunk search (#2538)

Support multi-output chunk search. Previously we only supported single-output chunk search; multi-output search is more flexible and improves performance by a large margin. For transformers, it reduces memory by 40% compared with the previous search strategy.

1. rewrite the search strategy to support multi-output chunk search
2. fix many, many bugs
3. update tests

pull/2540/head
parent f477a14f4a
commit 05671fcb42
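
For orientation, here is a minimal hand-written sketch of the loop shape the new codegen emits for a region with two outputs; the tensor names and shapes are hypothetical, not taken from this diff. The key change is that every output of the region is pre-allocated up front and each loop iteration writes one slice of all of them, where the old codegen could only fill a single chunk_result buffer.

import torch

def chunked_region(x: torch.Tensor):
    # pre-allocate every output of the region, as the new _gen_loop_start does
    hidden = torch.empty([4, 512, 64], dtype=x.dtype, device=x.device)
    attn = torch.empty([4, 512, 512], dtype=x.dtype, device=x.device)
    chunk_size = 2
    for chunk_idx in range(0, 512, chunk_size):
        # each iteration writes one slice of every output along its chunk dim
        x_slice = x[:, chunk_idx:chunk_idx + chunk_size, :]
        hidden[:, chunk_idx:chunk_idx + chunk_size, :] = x_slice * 2
        attn[:, chunk_idx:chunk_idx + chunk_size, :] = x_slice @ x.transpose(-1, -2)
    chunk_size = None    # mirrors the "chunk_size = None" epilogue the tests assert on
    return hidden, attn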
@@ -25,7 +25,7 @@ if AUTOCHUNK_AVAILABLE:
 from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
 
 from .search_chunk import SearchChunk
-from .utils import delete_free_var_from_last_use, find_idx_by_name, get_logger, get_node_shape
+from .utils import delete_free_var_from_last_use, get_logger, get_node_name, get_node_shape
 
 
 def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
@@ -51,7 +51,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
     return new_shape
 
 
-def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2) -> str:
+def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_ouput_dim: int, chunk_size=2) -> str:
    """
    Generate chunk loop start
 
@@ -70,22 +70,28 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim
        context (str): generated str
    """
    input_node = chunk_input[0]
-    out_shape = get_node_shape(chunk_output)
-    out_str = str(list(out_shape))
-    context = (
-        "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" %
-        (out_str, input_node.name, input_node.name, chunk_size))
-    context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
+
+    context = ""
+    for i in range(len(chunk_output)):
+        shape_str = str(list(get_node_shape(chunk_output[i])))
+        if get_node_name(chunk_output[i]) == "split":
+            tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
+                                                                                  input_node.name)
+            tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
+            tensor_str = "[" + tensor_str[:-2] + "]"
+            context += "%s = %s; " % (chunk_output[i].name, tensor_str)
+        else:
+            context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
+                                                                                     input_node.name, input_node.name)
+
+    out_shape = get_node_shape(chunk_output[0])
+    chunk_shape = out_shape[chunk_ouput_dim[0]]
+    context += "chunk_size = %d\nfor chunk_idx in range(0, %d, chunk_size):\n" % (chunk_size, chunk_shape)
    return context
 
 
-def _gen_loop_end(
-    chunk_inputs: List[Node],
-    chunk_non_compute_inputs: List[Node],
-    chunk_outputs: Node,
-    chunk_outputs_dim: int,
-    node_list: List[Node],
-) -> str:
+def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
+                  chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str:
    """
    Generate chunk loop end
 
@@ -102,22 +108,13 @@ def _gen_loop_end(
    Returns:
        context (str): generated str
    """
-    chunk_outputs_name = chunk_outputs.name
-    chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
-    chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
-    chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
-    context = " chunk_result%s = %s; %s = None\n" % (
-        chunk_slice,
-        chunk_outputs_name,
-        chunk_outputs_name,
-    )
-    context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")
+    context = "chunk_size = None"
    # determine if its the last use for chunk input
    for chunk_input in chunk_inputs + chunk_non_compute_inputs:
-        if all([find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
+        if all([search_chunk.node_mgr.find_node_idx(user) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
            context += "; %s = None" % chunk_input.name
+
+    for chunk_output_non_tensor, chunk_output_non_tensor_val in chunk_outputs_non_tensor.items():
+        context += "; %s = %s" % (chunk_output_non_tensor.name, chunk_output_non_tensor_val)
    context += "\n"
    return context
@@ -158,7 +155,7 @@ def _replace_ones_like(
    add chunk slice for new tensor op such as ones like
    """
    if "ones_like" in node.name:
-        meta_node = search_chunk.trace_indice.node_list[node_idx]
+        meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
        chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
        if get_node_shape(meta_node)[chunk_dim] != 1:
            source_node = meta_node.args[0].args[0]
@@ -169,21 +166,37 @@ def _replace_ones_like(
    return body
 
 
-def _replace_input_node(
-    chunk_inputs: List[Node],
+def _add_node_slice(
+    chunk_nodes: List[Node],
    region_idx: int,
-    chunk_inputs_dim: Dict,
+    chunk_nodes_dim: Dict,
    node_idx: int,
    body: List[str],
+    node: Node,
 ) -> List[str]:
    """
    add chunk slice for input nodes
    """
-    for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
-        for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
-            if idx == node_idx:
-                chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(input_node))
-                body[-1] = _replace_name(body[-1], input_node.name, input_node.name + chunk_slice)
+    for chunk_node_idx, chunk_node in enumerate(chunk_nodes[region_idx]):
+        # inputs node
+        if isinstance(chunk_nodes_dim[region_idx][chunk_node_idx], dict):
+            for idx, dim in chunk_nodes_dim[region_idx][chunk_node_idx].items():
+                if idx == node_idx:
+                    chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(chunk_node))
+                    body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
+        # outputs node
+        else:
+            if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
+                chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
+                                                   get_node_shape(chunk_node))
+                if get_node_name(chunk_node) == "split":
+                    split_chunk_slice = ""
+                    for i in range(len(chunk_node.meta['tensor_meta'])):
+                        split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
+                    split_chunk_slice = split_chunk_slice[:-2]
+                    body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
+                else:
+                    body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
    return body
@@ -222,7 +235,8 @@ def emit_code_with_chunk(
    chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
 
    # chunk outputs
-    chunk_outputs = [i["outputs"][0] for i in chunk_infos]
+    chunk_outputs = [i["outputs"] for i in chunk_infos]
+    chunk_outputs_non_tensor = [i["outputs_non_tensor"] for i in chunk_infos]
    chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
 
    node_list = search_chunk.reorder_graph.reorder_node_list(node_list)
@@ -248,7 +262,9 @@ def emit_code_with_chunk(
        if within_chunk_region:
            emit_node_func(node, body)
            # replace input var with chunk var
-            body = _replace_input_node(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body)
+            body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node)
+            # replace output var with chunk var
+            body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node)
            # ones like
            body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
            # reassgin reshape size
@@ -263,13 +279,8 @@ def emit_code_with_chunk(
        # generate chunk region end
        if node_idx in chunk_ends:
            body.append(
-                _gen_loop_end(
-                    chunk_inputs[region_idx],
-                    chunk_inputs_non_chunk[region_idx],
-                    chunk_outputs[region_idx],
-                    chunk_outputs_dim[region_idx],
-                    node_list,
-                ))
+                _gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
+                              chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
            within_chunk_region = False
 
        node_idx += 1

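A note on the split branch in _gen_loop_start above: a split node produces several tensors, so the generated allocation must be a list of torch.empty calls rather than a single one. Rerunning that string assembly in isolation (with an illustrative shape, name, and split count, none of them from a real trace) shows what lands in the generated source:

shape_str = "[4, 512, 64]"    # str(list(get_node_shape(...))) of one split result
input_name = "x"              # first chunk input; its dtype and device are reused
num_splits = 3                # stands in for len(node.meta['tensor_meta'])
tensor_str = ("torch.empty(%s, dtype=%s.dtype, device=%s.device), " %
              (shape_str, input_name, input_name)) * num_splits
tensor_str = "[" + tensor_str[:-2] + "]"    # drop the trailing ", ", wrap as a list
print("split_1 = %s; " % tensor_str)
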
@@ -6,7 +6,7 @@ from torch.fx.node import Node, map_arg
 
 from colossalai.fx.profiler import activation_size, parameter_size
 
-from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape, is_non_memory_node
+from .utils import NodeMgr, delete_free_var_from_last_use, get_node_shape, is_non_memory_node
 
 
 class EstimateMemory(object):
@@ -14,8 +14,8 @@ class EstimateMemory(object):
    Estimate memory with chunk
    """
 
-    def __init__(self) -> None:
-        pass
+    def __init__(self, node_mgr: NodeMgr) -> None:
+        self.node_mgr = node_mgr
 
    def _get_meta_node_size(self, x):
        x = x.meta["tensor_meta"]
@@ -78,7 +78,7 @@ class EstimateMemory(object):
        nodes_to_delete = []
        for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
            chunk_input_users = chunk_input.users.keys()
-            chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
+            chunk_input_users_idx = [self.node_mgr.find_node_idx(i) for i in chunk_input_users]
            if all(i <= chunk_end_idx for i in chunk_input_users_idx):
                if chunk_input not in nodes_to_delete:
                    nodes_to_delete.append(chunk_input)
@@ -212,7 +212,7 @@ class EstimateMemory(object):
        chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
        chunk_inputs_names = [j.name for i in chunk_inputs for j in i
                             ] + [j.name for i in chunk_inputs_non_chunk for j in i]
-        chunk_outputs = [i["outputs"][0] for i in chunk_infos]
+        chunk_outputs = [i["outputs"] for i in chunk_infos]
        chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
        chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
@@ -221,7 +221,7 @@ class EstimateMemory(object):
            if use_chunk and idx in chunk_starts:
                chunk_within = True
                chunk_region_idx = chunk_starts.index(idx)
-                act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)
+                act_memory += sum(self._get_output_node_size(i) for i in chunk_outputs[chunk_region_idx]) / (1024**2)
 
            # determine chunk ratio for current node
            if chunk_within:

@@ -1,5 +1,5 @@
 from .trace_indice import TraceIndice
-from .utils import find_idx_by_name
+from .utils import NodeMgr
 
 
 class ReorderGraph(object):
@@ -7,31 +7,27 @@ class ReorderGraph(object):
    Reorder node list and indice trace list
    """
 
-    def __init__(self, trace_indice: TraceIndice) -> None:
+    def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
        self.trace_indice = trace_indice
-        self.all_reorder_map = {
-            i: i for i in range(len(self.trace_indice.indice_trace_list))
-        }
+        self.node_mgr = node_mgr
+        self.all_reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}
 
    def _get_reorder_map(self, chunk_info):
-        reorder_map = {i: i for i in range(len(self.trace_indice.node_list))}
+        reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}
 
        chunk_region_start = chunk_info["region"][0]
        chunk_region_end = chunk_info["region"][1]
        chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
-        chunk_prepose_nodes_idx = [
-            find_idx_by_name(i.name, self.trace_indice.node_list)
-            for i in chunk_prepose_nodes
-        ]
+        chunk_prepose_nodes_idx = [self.node_mgr.find_node_idx(i) for i in chunk_prepose_nodes]
        # put prepose nodes ahead
        for idx, n in enumerate(chunk_prepose_nodes):
            n_idx = chunk_prepose_nodes_idx[idx]
            reorder_map[n_idx] = chunk_region_start + idx
        # put other nodes after prepose nodes
-        for n in self.trace_indice.node_list[chunk_region_start : chunk_region_end + 1]:
+        for n in self.node_mgr.get_node_slice_by_idx(chunk_region_start, chunk_region_end + 1):
            if n in chunk_prepose_nodes:
                continue
-            n_idx = find_idx_by_name(n.name, self.trace_indice.node_list)
+            n_idx = self.node_mgr.find_node_idx(n)
            pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
            reorder_map[n_idx] = n_idx + pos
@@ -44,7 +40,7 @@ class ReorderGraph(object):
            chunk_info["region"][1],
        )
        new_inputs_dim = []
-        for idx, input_dim in enumerate(chunk_info["inputs_dim"]):
+        for _, input_dim in enumerate(chunk_info["inputs_dim"]):
            new_input_dim = {}
            for k, v in input_dim.items():
                new_input_dim[reorder_map[k]] = v
@@ -57,16 +53,14 @@ class ReorderGraph(object):
            self.all_reorder_map[origin_idx] = reorder_map[map_idx]
 
    def _reorder_self_node_list(self, reorder_map):
-        new_node_list = [None for _ in range(len(self.trace_indice.node_list))]
+        new_node_list = [None for _ in range(len(self.node_mgr.get_node_list()))]
        for old_idx, new_idx in reorder_map.items():
-            new_node_list[new_idx] = self.trace_indice.node_list[old_idx]
-        self.trace_indice.node_list = new_node_list
+            new_node_list[new_idx] = self.node_mgr.get_node_by_idx(old_idx)
+        self.node_mgr.update_node_list(new_node_list)
 
    def _reorder_idx_trace(self, reorder_map):
        # reorder list
-        new_idx_trace_list = [
-            None for _ in range(len(self.trace_indice.indice_trace_list))
-        ]
+        new_idx_trace_list = [None for _ in range(len(self.trace_indice.indice_trace_list))]
        for old_idx, new_idx in reorder_map.items():
            new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
        self.trace_indice.indice_trace_list = new_idx_trace_list

@@ -9,6 +9,7 @@ from .select_chunk import SelectChunk
 from .trace_flow import TraceFlow
 from .trace_indice import TraceIndice
 from .utils import (
+    NodeMgr,
     find_chunk_compute_input_and_output_nodes,
     get_logger,
     get_node_shape,
@@ -49,15 +50,17 @@ class SearchChunk(object):
    def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
        self.print_mem = print_mem
        self.print_progress = print_progress
-        self.trace_indice = TraceIndice(list(gm.graph.nodes))
-        self.estimate_memory = EstimateMemory()
+        self.node_mgr = NodeMgr(gm)
+        self.trace_indice = TraceIndice(self.node_mgr)
+        self.estimate_memory = EstimateMemory(self.node_mgr)
        self._init_trace()
-        self.trace_flow = TraceFlow(self.trace_indice)
-        self.reorder_graph = ReorderGraph(self.trace_indice)
+        self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr)
+        self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr)
        self.select_chunk = SelectChunk(
            self.trace_indice,
            self.estimate_memory,
            self.reorder_graph,
+            self.node_mgr,
            max_memory=max_memory,
        )
@@ -67,7 +70,7 @@ class SearchChunk(object):
        reduce the computation complexity of trace_indice
        """
        # find all max ranges
-        active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list)
+        active_nodes = self.estimate_memory.get_active_nodes(self.node_mgr.get_node_list())
        cur_node_idx = len(self._get_free_var_idx())
        max_chunk_region_list = []
        while True:
@@ -100,7 +103,7 @@ class SearchChunk(object):
            free_var_idx (List): all indexs of free vars
        """
        free_var_idx = []
-        for idx, n in enumerate(self.trace_indice.node_list):
+        for idx, n in enumerate(self.node_mgr.get_node_list()):
            if n.op == "placeholder" and get_node_shape(n) is not None:
                free_var_idx.append(idx)
        return free_var_idx
@@ -164,6 +167,44 @@ class SearchChunk(object):
            chunk_region_end = region[0] - 1
        return chunk_region_start, chunk_region_end
 
+    def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
+        """
+        Find chunk info for a region.
+
+        We are given the region start and region end, and need to find out all chunk info for it.
+        We first loop every dim of start node and end node, to see if we can find dim pair,
+        which is linked in a flow and not computed.
+        If found, we then search flow in the whole region to find out all chunk infos.
+
+        Args:
+            input_trace (List): node's input trace in region
+            output_trace (List): node's output trace in region
+            start_idx (int): region start node index
+            end_idx (int): region end node index
+
+        Returns:
+            chunk_infos: possible regions found
+        """
+        start_traces = input_trace[start_idx]
+        if len(start_traces) > 1:    # TODO need to be removed
+            return []
+        end_trace = output_trace[end_idx]
+        end_node = self.node_mgr.get_node_by_idx(end_idx)
+
+        chunk_infos = []
+        for end_dim, _ in enumerate(end_trace["indice"]):
+            for start_node, start_trace in start_traces.items():
+                for start_dim, _ in enumerate(start_trace["indice"]):
+                    if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
+                                                                  end_idx):
+                        continue
+                    # flow search
+                    chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
+                    if chunk_info is None:
+                        continue
+                    chunk_infos.append(chunk_info)
+        return chunk_infos
+
    def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
        """
        Search every possible region within the max chunk region.
@@ -178,7 +219,7 @@ class SearchChunk(object):
        possible_chunk_region = []
        output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
        input_trace = []    # trace of a node's input nodes
-        for _, n in enumerate(self.trace_indice.node_list):
+        for _, n in enumerate(self.node_mgr.get_node_list()):
            cur_trace = {}
            for arg in n.args:
                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
@@ -188,11 +229,11 @@ class SearchChunk(object):
        for start_idx in range(max_chunk_region[0], peak_node + 1):
            for end_idx in range(peak_node, max_chunk_region[1] + 1):
                # skip non compute nodes
-                if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
-                        self.trace_indice.node_list[end_idx]):
+                if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
+                        self.node_mgr.get_node_by_idx(end_idx)):
                    continue
                # select free dim
-                chunk_info = self.trace_flow.find_chunk_info(input_trace, output_trace, start_idx, end_idx)
+                chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
                if len(chunk_info) > 0:
                    possible_chunk_region.extend(chunk_info)
        return possible_chunk_region
@@ -254,7 +295,7 @@ class SearchChunk(object):
            init_mem_peak,
            _,
            active_node,
-        ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
+        ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
        mem_peak = init_mem_peak
 
        while True:
@@ -267,7 +308,7 @@ class SearchChunk(object):
                mem_peak,
                _,
                active_node,
-            ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
+            ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), chunk_infos)
 
            if self.print_progress:
                get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
@@ -277,5 +318,7 @@ class SearchChunk(object):
                break
        if self.print_mem:
            self.print_mem = False
-            self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
+            self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
+                                                              chunk_infos,
+                                                              print_mem=True)
        return chunk_infos

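The chunk_infos returned by this search can now describe several outputs per region. A hand-made example of what one entry may look like after this change, using the keys visible in flow_search further below (node objects replaced by strings, all values illustrative):

chunk_info = {
    "region": (12, 47),                     # node-index span of the chunked loop
    "inputs": ["embed_out"],                # chunked inputs, sliced every iteration
    "inputs_non_chunk": ["attn_mask"],      # inputs used whole inside the loop
    "inputs_dim": [{12: (1,)}],             # which dim of each input gets sliced
    "outputs": ["hidden_out", "attn_out"],  # several outputs per region now
    "outputs_dim": [1, 1],                  # chunk dim of each output
    "outputs_non_tensor": {},               # shape-like int outputs, logged as str
    "args": {"prepose_nodes": []},
}
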
@@ -1,7 +1,7 @@
 from .estimate_memory import EstimateMemory
 from .reorder_graph import ReorderGraph
 from .trace_indice import TraceIndice
-from .utils import is_non_compute_node
+from .utils import NodeMgr, is_non_compute_node
 
 
 class SelectChunk(object):
@@ -11,11 +11,13 @@ class SelectChunk(object):
        trace_indice: TraceIndice,
        estimate_memory: EstimateMemory,
        reorder_graph: ReorderGraph,
+        node_mgr: NodeMgr,
        max_memory=None,
    ):
        self.trace_indice = trace_indice
        self.estimate_memory = estimate_memory
        self.reorder_graph = reorder_graph
+        self.node_mgr = node_mgr
        if max_memory is not None:
            self.stratge = "fit_memory"
            self.max_memory = max_memory    # MB
@@ -68,7 +70,7 @@ class SelectChunk(object):
        regions_dict = []
        for region in possible_chunk_regions:
            cur_region = region.copy()
-            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
+            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
            cur_chunk_infos = chunk_infos + [cur_region]
            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
            cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
@@ -134,7 +136,7 @@ class SelectChunk(object):
 
    def _get_compute_node_num(self, start, end):
        count = 0
-        for i in self.trace_indice.node_list[start:end + 1]:
+        for i in self.node_mgr.get_node_slice_by_idx(start, end + 1):
            if not is_non_compute_node(i):
                count += 1
        return count
@@ -161,7 +163,7 @@ class SelectChunk(object):
        regions_dict_list = []
        for region in possible_chunk_regions:
            cur_region = region.copy()
-            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
+            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
            cur_chunk_infos = chunk_infos + [cur_region]
            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
            cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]

@@ -4,9 +4,10 @@ from torch.fx.node import Node
 
 from .trace_indice import TraceIndice
 from .utils import (
+    NodeMgr,
     find_chunk_all_input_nodes,
     find_chunk_compute_input_and_output_nodes,
-    find_idx_by_name,
+    find_tensor_shape_node,
     flat_list,
     get_node_name,
     get_node_shape,
@@ -16,8 +17,9 @@ from .utils import (
 
 class TraceFlow(object):
 
-    def __init__(self, trace_indice: TraceIndice) -> None:
+    def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
        self.trace_indice = trace_indice
+        self.node_mgr = node_mgr
 
    def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
        """
@@ -31,7 +33,8 @@ class TraceFlow(object):
        Returns:
            bool: True if check pass
        """
-        start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
+        # we use start_node_idx instead of real chunk index
+        start_node_idx = self.node_mgr.find_node_idx(start_node)
        end_node_trace = self.trace_indice._find_trace_from_node(end_node)
        end_node_trace_source = end_node_trace["source"][end_dim]
        sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
@@ -39,7 +42,7 @@ class TraceFlow(object):
            if node_idx == start_node_idx and start_dim in node_dim:
                return True
            # it means we meet a node outside the loop, and the node is not input node
-            if node_idx < start_idx:
+            if node_idx < start_node_idx:
                return False
        return False
@@ -61,29 +64,12 @@ class TraceFlow(object):
                return False
        return True
 
-    def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
-        node_from_source = self.trace_indice._find_source_trace_from_node(node_from)
-        dim_source = node_from_source[node_from_dim]
-        node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list)
-        for k, v in dim_source.items():
-            if k == node_to_idx:
-                return v
-        return None
-
-    def _find_inherit_dim(self, input_node, input_dim, node):
-        input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
-        node_trace_source = self.trace_indice._find_source_trace_from_node(node)
-        for node_dim in range(len(get_node_shape(node))):
-            if (input_node_idx in node_trace_source[node_dim]
-                    and input_dim[0] in node_trace_source[node_dim][input_node_idx]):
-                return node_dim
-        return None
-
    def _assgin_single_node_flow(
        self,
        arg_node: Node,
        start_idx: int,
        end_idx: int,
+        cur_node: Node,
        cur_node_dim: int,
        cur_node_compute: Dict,
        cur_node_source: Dict,
@@ -109,7 +95,7 @@ class TraceFlow(object):
        Returns:
            bool: True if this node can be added to the flow, vice versa.
        """
-        arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
+        arg_idx = self.node_mgr.find_node_idx(arg_node)
        # arg in chunk range or be inputs
        if not (start_idx <= arg_idx < end_idx):
            return True
@@ -126,6 +112,11 @@ class TraceFlow(object):
            # chunk dim should be None if shape size is 1
            if get_node_shape(arg_node)[arg_dim] == 1:
                arg_dim = None
+            # chunk shape should equal cur node
+            elif get_node_shape(arg_node)[arg_dim] != 1:
+                if cur_node_dim is not None and get_node_shape(cur_node)[cur_node_dim] != 1:
+                    if get_node_shape(arg_node)[arg_dim] != get_node_shape(cur_node)[cur_node_dim]:
+                        return False
        else:
            arg_dim = None
@@ -150,7 +141,7 @@ class TraceFlow(object):
        return True
 
    def _get_all_node_info(self, end_dim, start_idx, end_idx):
-        cur_node_list = [self.trace_indice.node_list[end_idx]]    # start from the last node
+        cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)]    # start from the last node
        all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
 
        while len(cur_node_list) > 0:
@@ -178,6 +169,7 @@ class TraceFlow(object):
                    arg,
                    start_idx,
                    end_idx,
+                    cur_node,
                    cur_node_chunk_dim,
                    cur_node_compute,
                    cur_node_source,
@@ -194,7 +186,7 @@ class TraceFlow(object):
        for arg in arg_list:
            if get_node_shape(arg) is None:
                continue
-            if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
+            if not (start_idx <= self.node_mgr.find_node_idx(arg) < end_idx):
                continue
            arg_chunk_dim = all_node_info[arg]["chunk_dim"]
            arg_fix_dim = all_node_info[arg]["fix_dim"]
@@ -232,7 +224,7 @@ class TraceFlow(object):
        remove_inputs = []
        for input_node in inputs:
            input_dict = {}
-            input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
+            input_node_idx = self.node_mgr.find_node_idx(input_node)
            for user in input_node.users.keys():
                # skip non compute
                if is_non_compute_node(user):
@@ -240,7 +232,7 @@ class TraceFlow(object):
                # untraced node, mostly non compute
                if user not in all_node_info:
                    continue
-                user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
+                user_idx = self.node_mgr.find_node_idx(user)
                if start_idx <= user_idx <= end_idx:
                    chunk_dim = all_node_info[user]["chunk_dim"]
                    if chunk_dim is not None:
@@ -262,7 +254,7 @@ class TraceFlow(object):
                inputs.remove(i)
        return inputs, inputs_dim
 
-    def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
+    def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int, chunk_info) -> List[Node]:
        """
        get all useless nodes in chunk region and prepose them
 
@@ -279,8 +271,11 @@ class TraceFlow(object):
        for node, node_info in all_node_info.items():
            if node_info["chunk_dim"] is None:
                maybe_prepose_nodes.append(node)
+        for node in self.node_mgr.get_node_slice_by_idx(start_idx, end_idx):
+            if node not in all_node_info and node not in chunk_info["outputs"]:
+                maybe_prepose_nodes.append(node)
        maybe_prepose_nodes.sort(
-            key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
+            key=lambda x: self.node_mgr.find_node_idx(x),
            reverse=True,
        )    # from last node to first node
        prepose_nodes = []
@@ -303,8 +298,7 @@ class TraceFlow(object):
                    if type(cur_prepose_node_arg) != type(cur_prepose_node):
                        continue
                    # out of loop
-                    if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) <
-                            end_idx):
+                    if not (start_idx <= self.node_mgr.find_node_idx(cur_prepose_node_arg) < end_idx):
                        continue
                    # compute op in loop
                    elif cur_prepose_node_arg in all_node_info:
@@ -328,13 +322,12 @@ class TraceFlow(object):
                if n in maybe_prepose_nodes:
                    maybe_prepose_nodes.remove(n)
        # sort by index
-        prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list))
-
-        return prepose_nodes
+        prepose_nodes.sort(key=lambda x: self.node_mgr.find_node_idx(x))
+        chunk_info["args"]["prepose_nodes"] = prepose_nodes
 
    def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
        # we need to log input nodes to avoid deleteing them in the loop
-        chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1]
+        chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
        # also need to get some prepose node's arg out of non_chunk_inputs
        for n in chunk_info["args"]["prepose_nodes"]:
            chunk_node_list.remove(n)
@@ -345,34 +338,41 @@ class TraceFlow(object):
        return chunk_info
 
    def flow_search(self, start_idx, start_dim, end_idx, end_dim):
-        inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1])
-        # only single ouput
-        if len(outputs) > 1:
-            return None
+        inputs, outputs = find_chunk_compute_input_and_output_nodes(
+            self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))
 
        # get every node's chunk dim and fix dim
        all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
        if all_node_info is None:
            return None
 
-        # get input nodes' chunk dim
-        inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
-        if inputs is None:
-            return None
-
        chunk_info = {
            "region": (start_idx, end_idx),
-            "inputs": inputs,
+            "inputs": [],
            "inputs_non_chunk": [],
-            "inputs_dim": inputs_dim,
-            "outputs": outputs,
-            "outputs_dim": end_dim,
+            "inputs_dim": [],
+            "outputs": [self.node_mgr.get_node_by_idx(end_idx)],
+            "outputs_non_tensor": {},
+            "outputs_dim": [end_dim],
            "node_chunk_dim": all_node_info,
            "args": {},
        }
 
+        # find chunk info for other outputs
+        if len(find_tensor_shape_node(outputs)) > 1:
+            chunk_info = self._get_other_output_info(outputs, start_idx, start_dim, end_idx, end_dim, chunk_info)
+            if chunk_info is None:
+                return None
+
+        # get input nodes' chunk dim
+        inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
+        if inputs is None:
+            return None
+        chunk_info["inputs"] = inputs
+        chunk_info["inputs_dim"] = inputs_dim
+
        # move useless nodes ahead of loop
-        chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx)
+        self._get_prepose_nodes(all_node_info, start_idx, end_idx, chunk_info)
 
        # find non chunk inputs
        chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
@@ -382,6 +382,63 @@ class TraceFlow(object):
 
        return chunk_info
 
+    def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
+                               chunk_info: Dict):
+        start_node = self.node_mgr.get_node_by_idx(start_idx)
+        # loop all outputs
+        for output in outputs:
+            output_legal = False
+            output_idx = self.node_mgr.find_node_idx(output)
+            # skip the origin output
+            if output_idx == end_idx:
+                continue
+            # skip non tensor
+            if get_node_shape(output) is None:
+                # log shape tensor
+                if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
+                    chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
+                continue
+            # loop every dim of outputs, try to find a legal one
+            for output_dim in range(len(get_node_shape(output))):
+                if not self.check_region_start_end(start_node, start_dim, start_idx, output, output_dim, output_idx):
+                    continue
+                new_all_node_info = self._get_all_node_info(output_dim, start_idx, output_idx)
+                if new_all_node_info is None:
+                    continue
+                # check node info legal
+                if self._update_chunk_info(chunk_info, new_all_node_info, output, output_dim) == True:
+                    output_legal = True
+                    break
+            # not legal
+            if output_legal == False:
+                return None
+        return chunk_info
+
+    def _update_chunk_info(self, chunk_info: Dict, new_all_node_info: Dict, output: Node, output_dim: int) -> bool:
+        """
+        check if there is conflict between new node info and old chunk info. If not, update old chunk info
+        """
+        # check if conflict
+        overlap_flag = False
+        for k, v in new_all_node_info.items():
+            if k in chunk_info["node_chunk_dim"]:
+                overlap_flag = True
+                if chunk_info["node_chunk_dim"][k]["chunk_dim"] != v["chunk_dim"]:
+                    return False
+        # if no overlap, we just consider them as prepose nodes, instead of new output
+        if overlap_flag == False:
+            return True
+        # update chunk info
+        for k, v in new_all_node_info.items():
+            if k in chunk_info["node_chunk_dim"]:
+                chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
+                    set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
+            else:
+                chunk_info["node_chunk_dim"][k] = v
+        chunk_info["outputs"].append(output)
+        chunk_info["outputs_dim"].append(output_dim)
+        return True
+
    def _reassgin_reshape_size(self, chunk_info):
        """
        Some shape args in reshape may have changed due to chunk
@@ -389,10 +446,17 @@ class TraceFlow(object):
        """
        chunk_region = chunk_info["region"]
        reshape_size = {}
-        chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
-        for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
+        chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"][0]]
+        for node in self.node_mgr.get_node_slice_by_idx(chunk_region[0], chunk_region[1] + 1):
            if any(i == get_node_name(node) for i in ["reshape", "view"]):
+                if node in chunk_info["args"]["prepose_nodes"]:
+                    continue
+                if node.args[0] in chunk_info["inputs_non_chunk"]:
+                    continue
                reshape_args = flat_list(node.args[1:])
+                if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
+                        reshape_args[0].meta['fwd_out']) > 1:
+                    continue
                chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
                new_shape = ""
                for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
@@ -409,45 +473,8 @@ class TraceFlow(object):
        chunk_info["reshape_size"] = reshape_size
        return chunk_info
 
-    def find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
-        """
-        Find chunk info for a region.
-
-        We are given the region start and region end, and need to find out all chunk info for it.
-        We first loop every dim of start node and end node, to see if we can find dim pair,
-        which is linked in a flow and not computed.
-        If found, we then search flow in the whole region to find out all chunk infos.
-
-        Args:
-            input_trace (List): node's input trace in region
-            output_trace (List): node's output trace in region
-            start_idx (int): region start node index
-            end_idx (int): region end node index
-
-        Returns:
-            chunk_infos: possible regions found
-        """
-        start_traces = input_trace[start_idx]
-        if len(start_traces) > 1:    # TODO need to be removed
-            return []
-        end_trace = output_trace[end_idx]
-        end_node = self.trace_indice.node_list[end_idx]
-
-        chunk_infos = []
-        for end_dim, _ in enumerate(end_trace["indice"]):
-            for start_node, start_trace in start_traces.items():
-                for start_dim, _ in enumerate(start_trace["indice"]):
-                    if not self._check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, end_idx):
-                        continue
-                    # flow search
-                    chunk_info = self.flow_search(start_idx, start_dim, end_idx, end_dim)
-                    if chunk_info is None:
-                        continue
-                    chunk_infos.append(chunk_info)
-        return chunk_infos
-
-    def _check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
-                                end_idx: int) -> bool:
+    def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
+                               end_idx: int) -> bool:
        """
        check if region start and end is legal
        """

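The acceptance rule in _update_chunk_info above can be restated as a tiny standalone check: a second output may join a region only when every node its flow shares with the existing flow agrees on chunk_dim; if the flows do not overlap at all, the nodes are preposed instead of becoming a new output. A toy version with hypothetical node names:

existing = {"matmul_1": {"chunk_dim": 1, "fix_dim": [0]}}
ok = {"matmul_1": {"chunk_dim": 1, "fix_dim": [2]}, "add_2": {"chunk_dim": 1, "fix_dim": []}}
bad = {"matmul_1": {"chunk_dim": 2, "fix_dim": []}}

def compatible(old, new):
    return all(old[k]["chunk_dim"] == v["chunk_dim"] for k, v in new.items() if k in old)

assert compatible(existing, ok)        # merged; shared nodes' fix_dims get unioned
assert not compatible(existing, bad)   # rejected: chunk_dim conflict on matmul_1
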
@@ -3,14 +3,7 @@ from typing import Dict, List, Tuple
 
 from torch.fx.node import Node
 
-from .utils import (
-    find_first_tensor_arg,
-    find_idx_by_name,
-    flat_list,
-    get_module_node_name,
-    get_node_name,
-    get_node_shape,
-)
+from .utils import NodeMgr, find_first_tensor_arg, flat_list, get_module_node_name, get_node_name, get_node_shape
 
 
 class TraceIndice(object):
@@ -35,8 +28,8 @@ class TraceIndice(object):
        node_list (List)
    """
 
-    def __init__(self, node_list: List[Node]) -> None:
-        self.node_list = node_list
+    def __init__(self, node_mgr: NodeMgr) -> None:
+        self.node_mgr = node_mgr
        self.indice_trace_list = self._init_indice_trace_list()
        self.indice_view_list = {}
        self.indice_count = -1
@@ -45,7 +38,7 @@ class TraceIndice(object):
 
    def _init_indice_trace_list(self) -> List:
        indice_trace_list = []
-        for n in self.node_list:
+        for n in self.node_mgr.get_node_list():
            if get_node_shape(n) != None:
                cur_trace = {
                    "indice": [None for _ in range(len(get_node_shape(n)))],
@@ -99,7 +92,7 @@ class TraceIndice(object):
        node_from_trace_source = self._find_source_trace_from_node(node_from)
        node_to_dim = self._transform_indice(node_to, node_to_dim)
        node_to_trace_source = self._find_source_trace_from_node(node_to)
-        node_from_idx = find_idx_by_name(node_from.name, self.node_list)
+        node_from_idx = self.node_mgr.find_node_idx(node_from)
        if init:
            node_to_trace_source[node_to_dim] = {}
        # add dim to cur new source
@@ -200,7 +193,7 @@ class TraceIndice(object):
            idx (list): idx of the node
            compute (list): computed idx of the node.
        """
-        node_idx = find_idx_by_name(node.name, self.node_list)
+        node_idx = self.node_mgr.find_node_idx(node)
        node_dict = self.indice_trace_list[node_idx]
        return node_dict
 
@@ -214,7 +207,7 @@ class TraceIndice(object):
            idx (list): idx of the node
            compute (list): computed idx of the node.
        """
-        node_idx = find_idx_by_name(node.name, self.node_list)
+        node_idx = self.node_mgr.find_node_idx(node)
        node_dict = self.indice_trace_list[node_idx]
        return node_dict["source"]
 
@@ -227,7 +220,7 @@ class TraceIndice(object):
        Returns:
            idx (list): idx of the node
        """
-        node_idx = find_idx_by_name(node.name, self.node_list)
+        node_idx = self.node_mgr.find_node_idx(node)
        return self.indice_trace_list[node_idx]["indice"]
 
    def _find_compute_trace_from_node(self, node: Node) -> List:
@@ -239,7 +232,7 @@ class TraceIndice(object):
        Returns:
            compute (list): computed idx of the node.
        """
-        node_idx = find_idx_by_name(node.name, self.node_list)
+        node_idx = self.node_mgr.find_node_idx(node)
        return self.indice_trace_list[node_idx]["compute"]
 
    def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None:
@@ -454,8 +447,6 @@ class TraceIndice(object):
            node (node)
            node_idx (int)
        """
-        for _ in range(len(get_node_shape(node.args[0]))):
-            self._add_dim(node_idx, 0)
        self._assign_indice_as_input(node, node_idx)
        dim_idx = node.kwargs["dim"]
        self._del_dim(node_idx, dim_idx)
@@ -702,21 +693,20 @@ class TraceIndice(object):
            if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
                    and view_dict["dim_from"] == dim_to):
-                if len_diff == 1:
-                    if origin_shape[dim_from[0]] == 1:
-                        self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
-                    elif origin_shape[dim_from[1]] == 1:
-                        self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
-                elif len_diff == -1:
-                    if target_shape[dim_to[0]] == 1:
-                        self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
-                    elif target_shape[dim_to[1]] == 1:
-                        self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
+                # inheirt indice from current node
+                for dim_to_i in dim_to:
+                    for dim_from_i in dim_from:
+                        self._inherit_indice(origin_node, dim_from_i, node, dim_to_i, init=False)
+                # inherid indice from input node of last view
+                for dim_to_i in dim_to:
+                    self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
 
        # inherit computation
        compute_log = self._find_compute_trace_from_node(origin_node)
        for i in dim_from:
            if origin_trace[i] in compute_log:
                for j in dim_to:
                    self._mark_computation(node, node_idx, [j])
                break
 
        # log view, not used now
        view_dict = {
            "idx_from": [origin_trace[i] for i in dim_from],
@@ -742,7 +732,7 @@ class TraceIndice(object):
 
        active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
        active_nodes = set(flat_list(active_nodes))
-        active_nodes = [find_idx_by_name(i, self.node_list) for i in active_nodes]
+        active_nodes = [self.node_mgr.find_node_idx_by_name(i) for i in active_nodes]
        for i in range(trace_range[0], trace_range[1] + 1):
            trace = self.indice_trace_list[i]
            # clear compute
@@ -758,7 +748,7 @@ class TraceIndice(object):
                    dim_source.pop(k)
 
    def trace_indice(self) -> None:
-        for idx, node in enumerate(self.node_list):
+        for idx, node in enumerate(self.node_mgr.get_node_list()):
            node_name = get_node_name(node)
            if node.op == "placeholder":
                self._assign_all_indice(node, idx)

@@ -9,6 +9,59 @@ NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "siz
 logger = get_dist_logger()
 
 
+class NodeMgr(object):
+
+    def __init__(self, gm) -> None:
+        self._node_list = list(gm.graph.nodes)
+        self._node_dict = {}
+        self._set_node_dict()
+
+    def _set_node_dict(self) -> None:
+        """
+        create a dict {node_name: node_idx}
+        """
+        self._node_dict.clear()
+        for idx, node in enumerate(self._node_list):
+            self._node_dict[node.name] = idx
+
+    def find_node_idx(self, node: Node) -> int:
+        """
+        find node's index
+        """
+        return self._node_dict[node.name]
+
+    def find_node_idx_by_name(self, node_name: str) -> int:
+        """
+        find node's index
+        """
+        return self._node_dict[node_name]
+
+    def get_node_by_idx(self, idx: int) -> Node:
+        """
+        get a node by index
+        """
+        return self._node_list[idx]
+
+    def get_node_slice_by_idx(self, start: int, end: int) -> List[Node]:
+        """
+        get a slice of node by index
+        """
+        return self._node_list[start:end]
+
+    def get_node_list(self) -> List:
+        """
+        get full node list
+        """
+        return self._node_list
+
+    def update_node_list(self, node_list: List) -> None:
+        """
+        update node list, reset node dict
+        """
+        self._node_list = node_list
+        self._set_node_dict()
+
+
 def get_logger() -> Any:
     return logger
 
@@ -42,6 +95,8 @@ def is_non_compute_node(node: Node) -> bool:
     if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME):
         return True
     if "getitem" in node.name:
+        if get_node_shape(node) is not None:
+            return False
         node_args = flat_list(node.args[1:])
         for node_arg in node_args:
             if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
@@ -53,6 +108,8 @@ def is_non_compute_node(node: Node) -> bool:
 
 
 def get_node_shape(node: Node) -> List:
+    if get_node_name(node) == "split":
+        return node.meta["tensor_meta"][0].shape
     if hasattr(node.meta["tensor_meta"], "shape"):
         return node.meta["tensor_meta"].shape
     return None
@@ -78,7 +135,7 @@ def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
     return is_non_compute_node_except_placeholder(node)
 
 
-def find_idx_by_name(name: str, nodes_list: List) -> int:
+def find_node_idx(name: str, nodes_list: List) -> int:
     for idx, node in enumerate(nodes_list):
         if node.name == name:
             return idx
@@ -162,3 +219,28 @@ def get_node_name(node: Node) -> str:
         else:
             break
     return node_name
+
+
+def find_tensor_node(node_list: List[Node]) -> List[Node]:
+    """
+    find tensor nodes from a node list
+    """
+    out = []
+    for node in node_list:
+        if get_node_shape(node) is not None:
+            out.append(node)
+    return out
+
+
+def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
+    """
+    find tensor and shape nodes from a node list
+    """
+    out = []
+    for node in node_list:
+        if get_node_shape(node) is not None:
+            out.append(node)
+        elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
+                node.meta['fwd_out'][0], int):
+            out.append(node)
+    return out

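The NodeMgr introduced above is the heart of this refactor's speedup: the old find_idx_by_name helper rescanned the node list on every call, while NodeMgr builds one name-to-index dict and answers lookups in constant time. A minimal standalone illustration, where FakeNode stands in for torch.fx.Node:

from dataclasses import dataclass

@dataclass
class FakeNode:
    name: str

nodes = [FakeNode("x"), FakeNode("linear_1"), FakeNode("output")]
node_dict = {n.name: i for i, n in enumerate(nodes)}    # what _set_node_dict builds
assert node_dict["linear_1"] == 1                       # find_node_idx: one dict hit
# the old per-call linear scan, for comparison:
assert next(i for i, n in enumerate(nodes) if n.name == "linear_1") == 1
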
@@ -23,6 +23,7 @@ def assert_codegen_run(
     concrete_args: List = None,
     max_memory: int = None,
     print_mem: bool = False,
+    print_est_mem: bool = False,
     print_progress: bool = False,
     print_code: bool = False,
 ) -> List[Dict]:
@@ -41,7 +42,7 @@ def assert_codegen_run(
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
-        print_mem=print_mem,
+        print_mem=print_est_mem,
        print_progress=print_progress,
    )
    chunks = codegen.chunk_infos
@@ -61,13 +62,20 @@ def assert_codegen_run(
    code = graph.python_code("self").src
    if print_code:
        print(code)
-    assert "chunk_result = None; chunk_size = None;" in code
+    assert "chunk_size = None; " in code
 
    # assert result
    inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
    model.cuda()
    with torch.no_grad():
-        out_gm = gm(*inputs)
+        if print_mem:
+            torch.cuda.reset_peak_memory_stats()
+            now_mem = torch.cuda.memory_allocated() / 1024**2
+        out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        if print_mem:
+            new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+            print("mem: %.2fMB" % (new_max_mem - now_mem))
        out_model = model(*inputs)
    out_gm = flat_list(out_gm)
    out_model = flat_list(out_model)
@@ -85,9 +93,10 @@ def run_test(
    max_memory: int,
    get_model: Any,
    get_data: Any,
-    print_code: bool,
-    print_mem: bool,
-    print_progress: bool,
+    print_code: bool = False,
+    print_mem: bool = False,
+    print_est_mem: bool = False,
+    print_progress: bool = False,
    get_chunk_target: Any = None,
 ) -> None:
    # launch colossalai
@@ -110,6 +119,7 @@ def run_test(
        max_memory=max_memory,
        print_code=print_code,
        print_mem=print_mem,
+        print_est_mem=print_est_mem,
        print_progress=print_progress,
    )

@@ -55,9 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
 
 def get_chunk_target() -> Dict:
     return {
-        None: [(118, 123), (219, 237), (264, 289), (302, 309), (97, 104), (144, 152), (185, 193), (241, 242), (21, 46)],
-        20: [(118, 123), (230, 237), (275, 282), (305, 306), (100, 101), (32, 39), (73, 79)],
-        24: [(118, 123)],
+        None: [(120, 123), (222, 237), (269, 289), (305, 311), (100, 105), (146, 152), (187, 193), (241, 242),
+               (25, 50)],
+        20: [(120, 123), (232, 237), (277, 282), (305, 306), (100, 101), (34, 39)],
+        24: [(120, 123)],
     }
 
 
@@ -75,9 +76,6 @@ def test_evoformer_block(data_args, max_memory):
        get_model=get_model,
        get_data=get_data,
        get_chunk_target=get_chunk_target,
-        print_code=False,
-        print_mem=False,
-        print_progress=False,
    )
    mp.spawn(run_func, nprocs=1)
 
@@ -86,10 +84,12 @@ if __name__ == "__main__":
    run_test(
        rank=0,
        data_args=(32, 64),
-        max_memory=20,
+        max_memory=24,
        get_model=get_model,
        get_data=get_data,
+        get_chunk_target=get_chunk_target,
        print_code=False,
        print_mem=False,
+        print_est_mem=False,
        print_progress=False,
    )

@@ -70,9 +70,6 @@ def test_evoformer_stack(data_args, max_memory):
         max_memory=max_memory,
         get_model=get_model,
         get_data=get_data,
-        print_code=False,
-        print_mem=False,
-        print_progress=False,
     )
     mp.spawn(run_func, nprocs=1)
 
@@ -81,7 +78,7 @@ if __name__ == "__main__":
    run_test(
        rank=0,
        data_args=(32, 64),
-        max_memory=20,
+        max_memory=None,
        get_model=get_model,
        get_data=get_data,
        print_code=False,

@@ -55,10 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
 
 def get_chunk_target() -> Dict:
     return {
-        None: [(126, 131), (227, 245), (272, 297), (310, 317), (105, 112), (152, 160), (193, 201), (249, 250),
-               (33, 46)],
-        20: [(126, 131), (238, 245), (283, 290), (313, 314), (108, 109), (35, 46)],
-        24: [(126, 131)],
+        None: [(128, 131), (230, 245), (277, 297), (313, 319), (108, 113), (154, 160), (195, 201), (249, 250),
+               (36, 46)],
+        20: [(128, 131), (240, 245), (285, 290), (313, 314), (108, 109), (41, 46)],
+        24: [(128, 131)],
     }
 
 
@@ -75,9 +75,7 @@ def test_extramsa_block(data_args, max_memory):
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
-        print_code=False,
-        print_mem=False,
-        print_progress=False,
+        get_chunk_target=get_chunk_target,
    )
    mp.spawn(run_func, nprocs=1)
 
@@ -86,7 +84,7 @@ if __name__ == "__main__":
    run_test(
        rank=0,
        data_args=(32, 64),
-        max_memory=20,
+        max_memory=None,
        get_model=get_model,
        get_data=get_data,
        get_chunk_target=get_chunk_target,

@@ -17,8 +17,8 @@ from test_transformer_utils import run_test
 
 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
 
-BATCH_SIZE = 2
-SEQ_LENGTH = 256
+BATCH_SIZE = 1
+SEQ_LENGTH = 512
 
 
 def get_data(shape: tuple) -> Tuple[List, List]:
@@ -37,17 +37,14 @@ def get_data(shape: tuple) -> Tuple[List, List]:
 )
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
-@pytest.mark.parametrize("max_memory", [None, 4.5, 5])
-def test_gpt(model, shape, max_memory):
+@pytest.mark.parametrize("max_memory", [None, 6, 8])
+def test_autochunk_gpt(model, shape, max_memory):
    run_func = partial(
        run_test,
        data=get_data(shape),
        max_memory=max_memory,
        model=model,
        config=GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4),
-        print_code=False,
-        print_mem=False,
-        print_progress=False,
    )
    mp.spawn(run_func, nprocs=1)
 
@@ -59,7 +56,8 @@ if __name__ == "__main__":
        max_memory=None,
        model=GPT2Model,
        config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
-        print_code=True,
-        print_mem=True,
+        print_code=False,
+        print_est_mem=False,
+        print_mem=False,
        print_progress=False,
    )

@@ -20,6 +20,7 @@ def assert_codegen_run(
     model: Any,
     data: tuple,
     max_memory: int = None,
+    print_est_mem: bool = False,
     print_mem: bool = False,
     print_progress: bool = False,
     print_code: bool = False,
@@ -41,7 +42,7 @@ def assert_codegen_run(
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
-        print_mem=print_mem,
+        print_mem=print_est_mem,
        print_progress=print_progress,
    )
    chunks = codegen.chunk_infos
@@ -61,7 +62,7 @@ def assert_codegen_run(
    code = graph.python_code("self").src
    if print_code:
        print(code)
-    assert "chunk_result = None; chunk_size = None;" in code
+    assert "chunk_size = None; " in code
 
    # assert result
    inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
@@ -69,26 +70,44 @@ def assert_codegen_run(
    model.cuda().eval()
    gm.eval()
    with torch.no_grad():
-        out_gm = gm(*inputs)
+        if print_mem:
+            torch.cuda.reset_peak_memory_stats()
+            now_mem = torch.cuda.memory_allocated() / 1024**2
+        out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        if print_mem:
+            new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+            print("mem: %.2fMB" % (new_max_mem - now_mem))
        out_model = model(*inputs)
-    for k in out_model.keys():
-        if torch.is_tensor(out_gm[k]):
-            assert torch.equal(
-                out_model[k], out_gm[k]
-            ), f'{model.__class__.__name__} has incorrect output {k}, expect {out_model[k]}, but got {out_gm[k]}'
-
+    assert_allclose(out_model, out_gm)
    return chunks
 
 
+def assert_allclose(out_model: Any, out_gm: Any) -> None:
+    """
+    assert allclose for out
+    """
+    if isinstance(out_model, torch.Tensor):
+        assert torch.allclose(out_model, out_gm,
+                              atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                                  torch.abs(out_model - out_gm))
+    elif isinstance(out_model, dict):
+        for k in out_model.keys():
+            assert_allclose(out_model[k], out_gm[k])
+    elif isinstance(out_model, tuple) or isinstance(out_model, list) or isinstance(out_model, set):
+        for i, j in zip(out_model, out_gm):
+            assert_allclose(i, j)
+
+
 def run_test(
    rank: int,
    model: Any,
    config: Any,
    data: tuple,
    max_memory: int,
-    print_code: bool,
-    print_mem: bool,
-    print_progress: bool,
+    print_code: bool = False,
+    print_est_mem: bool = False,
+    print_mem: bool = False,
+    print_progress: bool = False,
    get_chunk_target: Any = None,
 ) -> None:
    model = model(config=config)
@@ -108,6 +127,7 @@ def run_test(
        data=data,
        max_memory=max_memory,
        print_code=print_code,
+        print_est_mem=print_est_mem,
        print_mem=print_mem,
        print_progress=print_progress,
    )
@@ -119,5 +139,3 @@ def run_test(
            str(chunk_found),
            str(chunk_target),
        )
-
-    gpc.destroy()

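The tests above now measure real GPU memory with the same small pattern in two places; distilled into a hypothetical helper (CUDA required, not part of this diff), it reads:

import torch

def measure_peak_mb(fn, *args):
    torch.cuda.reset_peak_memory_stats()
    base = torch.cuda.memory_allocated() / 1024**2    # memory already held
    out = fn(*args)
    peak = torch.cuda.max_memory_allocated() / 1024**2
    print("mem: %.2fMB" % (peak - base))              # extra peak memory used by fn
    return out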