[autochunk] support multi outputs chunk search (#2538)

Support multi outputs chunk search. Previously we only support single output chunk search. It is more flexible and improve performance by a large margin. For transformer, we reduce memory by 40% than previous search strategy.

1. rewrite search strategy to support multi outputs chunk search
2. fix many, many bugs
3. update tests
pull/2540/head
oahzxl 2023-02-01 13:18:51 +08:00 committed by GitHub
parent f477a14f4a
commit 05671fcb42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 428 additions and 258 deletions

View File

@ -25,7 +25,7 @@ if AUTOCHUNK_AVAILABLE:
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
from .search_chunk import SearchChunk from .search_chunk import SearchChunk
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_logger, get_node_shape from .utils import delete_free_var_from_last_use, get_logger, get_node_name, get_node_shape
def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str: def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
@ -51,7 +51,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) ->
return new_shape return new_shape
def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2) -> str: def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_ouput_dim: int, chunk_size=2) -> str:
""" """
Generate chunk loop start Generate chunk loop start
@ -70,22 +70,28 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim
context (str): generated str context (str): generated str
""" """
input_node = chunk_input[0] input_node = chunk_input[0]
out_shape = get_node_shape(chunk_output)
out_str = str(list(out_shape)) context = ""
context = ( for i in range(len(chunk_output)):
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" % shape_str = str(list(get_node_shape(chunk_output[i])))
(out_str, input_node.name, input_node.name, chunk_size)) if get_node_name(chunk_output[i]) == "split":
context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim]) tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
input_node.name)
tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
tensor_str = "[" + tensor_str[:-2] + "]"
context += "%s = %s; " % (chunk_output[i].name, tensor_str)
else:
context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
input_node.name, input_node.name)
out_shape = get_node_shape(chunk_output[0])
chunk_shape = out_shape[chunk_ouput_dim[0]]
context += "chunk_size = %d\nfor chunk_idx in range(0, %d, chunk_size):\n" % (chunk_size, chunk_shape)
return context return context
def _gen_loop_end( def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
chunk_inputs: List[Node], chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str:
chunk_non_compute_inputs: List[Node],
chunk_outputs: Node,
chunk_outputs_dim: int,
node_list: List[Node],
) -> str:
""" """
Generate chunk loop end Generate chunk loop end
@ -102,22 +108,13 @@ def _gen_loop_end(
Returns: Returns:
context (str): generated str context (str): generated str
""" """
chunk_outputs_name = chunk_outputs.name context = "chunk_size = None"
chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
context = " chunk_result%s = %s; %s = None\n" % (
chunk_slice,
chunk_outputs_name,
chunk_outputs_name,
)
context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")
# determine if its the last use for chunk input # determine if its the last use for chunk input
for chunk_input in chunk_inputs + chunk_non_compute_inputs: for chunk_input in chunk_inputs + chunk_non_compute_inputs:
if all([find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in chunk_input.users.keys()]): if all([search_chunk.node_mgr.find_node_idx(user) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
context += "; %s = None" % chunk_input.name context += "; %s = None" % chunk_input.name
for chunk_output_non_tensor, chunk_output_non_tensor_val in chunk_outputs_non_tensor.items():
context += "; %s = %s" % (chunk_output_non_tensor.name, chunk_output_non_tensor_val)
context += "\n" context += "\n"
return context return context
@ -158,7 +155,7 @@ def _replace_ones_like(
add chunk slice for new tensor op such as ones like add chunk slice for new tensor op such as ones like
""" """
if "ones_like" in node.name: if "ones_like" in node.name:
meta_node = search_chunk.trace_indice.node_list[node_idx] meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"] chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1: if get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0] source_node = meta_node.args[0].args[0]
@ -169,21 +166,37 @@ def _replace_ones_like(
return body return body
def _replace_input_node( def _add_node_slice(
chunk_inputs: List[Node], chunk_nodes: List[Node],
region_idx: int, region_idx: int,
chunk_inputs_dim: Dict, chunk_nodes_dim: Dict,
node_idx: int, node_idx: int,
body: List[str], body: List[str],
node: Node,
) -> List[str]: ) -> List[str]:
""" """
add chunk slice for input nodes add chunk slice for input nodes
""" """
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): for chunk_node_idx, chunk_node in enumerate(chunk_nodes[region_idx]):
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): # inputs node
if isinstance(chunk_nodes_dim[region_idx][chunk_node_idx], dict):
for idx, dim in chunk_nodes_dim[region_idx][chunk_node_idx].items():
if idx == node_idx: if idx == node_idx:
chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(input_node)) chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(chunk_node))
body[-1] = _replace_name(body[-1], input_node.name, input_node.name + chunk_slice) body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
# outputs node
else:
if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
get_node_shape(chunk_node))
if get_node_name(chunk_node) == "split":
split_chunk_slice = ""
for i in range(len(chunk_node.meta['tensor_meta'])):
split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
split_chunk_slice = split_chunk_slice[:-2]
body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
else:
body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
return body return body
@ -222,7 +235,8 @@ def emit_code_with_chunk(
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i] chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
# chunk outputs # chunk outputs
chunk_outputs = [i["outputs"][0] for i in chunk_infos] chunk_outputs = [i["outputs"] for i in chunk_infos]
chunk_outputs_non_tensor = [i["outputs_non_tensor"] for i in chunk_infos]
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
node_list = search_chunk.reorder_graph.reorder_node_list(node_list) node_list = search_chunk.reorder_graph.reorder_node_list(node_list)
@ -248,7 +262,9 @@ def emit_code_with_chunk(
if within_chunk_region: if within_chunk_region:
emit_node_func(node, body) emit_node_func(node, body)
# replace input var with chunk var # replace input var with chunk var
body = _replace_input_node(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body) body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node)
# replace output var with chunk var
body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node)
# ones like # ones like
body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body) body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
# reassgin reshape size # reassgin reshape size
@ -263,13 +279,8 @@ def emit_code_with_chunk(
# generate chunk region end # generate chunk region end
if node_idx in chunk_ends: if node_idx in chunk_ends:
body.append( body.append(
_gen_loop_end( _gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
chunk_inputs[region_idx], chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
chunk_inputs_non_chunk[region_idx],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx],
node_list,
))
within_chunk_region = False within_chunk_region = False
node_idx += 1 node_idx += 1

View File

@ -6,7 +6,7 @@ from torch.fx.node import Node, map_arg
from colossalai.fx.profiler import activation_size, parameter_size from colossalai.fx.profiler import activation_size, parameter_size
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape, is_non_memory_node from .utils import NodeMgr, delete_free_var_from_last_use, get_node_shape, is_non_memory_node
class EstimateMemory(object): class EstimateMemory(object):
@ -14,8 +14,8 @@ class EstimateMemory(object):
Estimate memory with chunk Estimate memory with chunk
""" """
def __init__(self) -> None: def __init__(self, node_mgr: NodeMgr) -> None:
pass self.node_mgr = node_mgr
def _get_meta_node_size(self, x): def _get_meta_node_size(self, x):
x = x.meta["tensor_meta"] x = x.meta["tensor_meta"]
@ -78,7 +78,7 @@ class EstimateMemory(object):
nodes_to_delete = [] nodes_to_delete = []
for chunk_input in chunk_inputs + chunk_inputs_non_chunk: for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
chunk_input_users = chunk_input.users.keys() chunk_input_users = chunk_input.users.keys()
chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users] chunk_input_users_idx = [self.node_mgr.find_node_idx(i) for i in chunk_input_users]
if all(i <= chunk_end_idx for i in chunk_input_users_idx): if all(i <= chunk_end_idx for i in chunk_input_users_idx):
if chunk_input not in nodes_to_delete: if chunk_input not in nodes_to_delete:
nodes_to_delete.append(chunk_input) nodes_to_delete.append(chunk_input)
@ -212,7 +212,7 @@ class EstimateMemory(object):
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
chunk_inputs_names = [j.name for i in chunk_inputs for j in i chunk_inputs_names = [j.name for i in chunk_inputs for j in i
] + [j.name for i in chunk_inputs_non_chunk for j in i] ] + [j.name for i in chunk_inputs_non_chunk for j in i]
chunk_outputs = [i["outputs"][0] for i in chunk_infos] chunk_outputs = [i["outputs"] for i in chunk_infos]
chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos] chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos] chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
@ -221,7 +221,7 @@ class EstimateMemory(object):
if use_chunk and idx in chunk_starts: if use_chunk and idx in chunk_starts:
chunk_within = True chunk_within = True
chunk_region_idx = chunk_starts.index(idx) chunk_region_idx = chunk_starts.index(idx)
act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2) act_memory += sum(self._get_output_node_size(i) for i in chunk_outputs[chunk_region_idx]) / (1024**2)
# determine chunk ratio for current node # determine chunk ratio for current node
if chunk_within: if chunk_within:

View File

@ -1,5 +1,5 @@
from .trace_indice import TraceIndice from .trace_indice import TraceIndice
from .utils import find_idx_by_name from .utils import NodeMgr
class ReorderGraph(object): class ReorderGraph(object):
@ -7,31 +7,27 @@ class ReorderGraph(object):
Reorder node list and indice trace list Reorder node list and indice trace list
""" """
def __init__(self, trace_indice: TraceIndice) -> None: def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
self.trace_indice = trace_indice self.trace_indice = trace_indice
self.all_reorder_map = { self.node_mgr = node_mgr
i: i for i in range(len(self.trace_indice.indice_trace_list)) self.all_reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}
}
def _get_reorder_map(self, chunk_info): def _get_reorder_map(self, chunk_info):
reorder_map = {i: i for i in range(len(self.trace_indice.node_list))} reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}
chunk_region_start = chunk_info["region"][0] chunk_region_start = chunk_info["region"][0]
chunk_region_end = chunk_info["region"][1] chunk_region_end = chunk_info["region"][1]
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
chunk_prepose_nodes_idx = [ chunk_prepose_nodes_idx = [self.node_mgr.find_node_idx(i) for i in chunk_prepose_nodes]
find_idx_by_name(i.name, self.trace_indice.node_list)
for i in chunk_prepose_nodes
]
# put prepose nodes ahead # put prepose nodes ahead
for idx, n in enumerate(chunk_prepose_nodes): for idx, n in enumerate(chunk_prepose_nodes):
n_idx = chunk_prepose_nodes_idx[idx] n_idx = chunk_prepose_nodes_idx[idx]
reorder_map[n_idx] = chunk_region_start + idx reorder_map[n_idx] = chunk_region_start + idx
# put other nodes after prepose nodes # put other nodes after prepose nodes
for n in self.trace_indice.node_list[chunk_region_start : chunk_region_end + 1]: for n in self.node_mgr.get_node_slice_by_idx(chunk_region_start, chunk_region_end + 1):
if n in chunk_prepose_nodes: if n in chunk_prepose_nodes:
continue continue
n_idx = find_idx_by_name(n.name, self.trace_indice.node_list) n_idx = self.node_mgr.find_node_idx(n)
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
reorder_map[n_idx] = n_idx + pos reorder_map[n_idx] = n_idx + pos
@ -44,7 +40,7 @@ class ReorderGraph(object):
chunk_info["region"][1], chunk_info["region"][1],
) )
new_inputs_dim = [] new_inputs_dim = []
for idx, input_dim in enumerate(chunk_info["inputs_dim"]): for _, input_dim in enumerate(chunk_info["inputs_dim"]):
new_input_dim = {} new_input_dim = {}
for k, v in input_dim.items(): for k, v in input_dim.items():
new_input_dim[reorder_map[k]] = v new_input_dim[reorder_map[k]] = v
@ -57,16 +53,14 @@ class ReorderGraph(object):
self.all_reorder_map[origin_idx] = reorder_map[map_idx] self.all_reorder_map[origin_idx] = reorder_map[map_idx]
def _reorder_self_node_list(self, reorder_map): def _reorder_self_node_list(self, reorder_map):
new_node_list = [None for _ in range(len(self.trace_indice.node_list))] new_node_list = [None for _ in range(len(self.node_mgr.get_node_list()))]
for old_idx, new_idx in reorder_map.items(): for old_idx, new_idx in reorder_map.items():
new_node_list[new_idx] = self.trace_indice.node_list[old_idx] new_node_list[new_idx] = self.node_mgr.get_node_by_idx(old_idx)
self.trace_indice.node_list = new_node_list self.node_mgr.update_node_list(new_node_list)
def _reorder_idx_trace(self, reorder_map): def _reorder_idx_trace(self, reorder_map):
# reorder list # reorder list
new_idx_trace_list = [ new_idx_trace_list = [None for _ in range(len(self.trace_indice.indice_trace_list))]
None for _ in range(len(self.trace_indice.indice_trace_list))
]
for old_idx, new_idx in reorder_map.items(): for old_idx, new_idx in reorder_map.items():
new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx] new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
self.trace_indice.indice_trace_list = new_idx_trace_list self.trace_indice.indice_trace_list = new_idx_trace_list

View File

@ -9,6 +9,7 @@ from .select_chunk import SelectChunk
from .trace_flow import TraceFlow from .trace_flow import TraceFlow
from .trace_indice import TraceIndice from .trace_indice import TraceIndice
from .utils import ( from .utils import (
NodeMgr,
find_chunk_compute_input_and_output_nodes, find_chunk_compute_input_and_output_nodes,
get_logger, get_logger,
get_node_shape, get_node_shape,
@ -49,15 +50,17 @@ class SearchChunk(object):
def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None: def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
self.print_mem = print_mem self.print_mem = print_mem
self.print_progress = print_progress self.print_progress = print_progress
self.trace_indice = TraceIndice(list(gm.graph.nodes)) self.node_mgr = NodeMgr(gm)
self.estimate_memory = EstimateMemory() self.trace_indice = TraceIndice(self.node_mgr)
self.estimate_memory = EstimateMemory(self.node_mgr)
self._init_trace() self._init_trace()
self.trace_flow = TraceFlow(self.trace_indice) self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr)
self.reorder_graph = ReorderGraph(self.trace_indice) self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr)
self.select_chunk = SelectChunk( self.select_chunk = SelectChunk(
self.trace_indice, self.trace_indice,
self.estimate_memory, self.estimate_memory,
self.reorder_graph, self.reorder_graph,
self.node_mgr,
max_memory=max_memory, max_memory=max_memory,
) )
@ -67,7 +70,7 @@ class SearchChunk(object):
reduce the computation complexity of trace_indice reduce the computation complexity of trace_indice
""" """
# find all max ranges # find all max ranges
active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list) active_nodes = self.estimate_memory.get_active_nodes(self.node_mgr.get_node_list())
cur_node_idx = len(self._get_free_var_idx()) cur_node_idx = len(self._get_free_var_idx())
max_chunk_region_list = [] max_chunk_region_list = []
while True: while True:
@ -100,7 +103,7 @@ class SearchChunk(object):
free_var_idx (List): all indexs of free vars free_var_idx (List): all indexs of free vars
""" """
free_var_idx = [] free_var_idx = []
for idx, n in enumerate(self.trace_indice.node_list): for idx, n in enumerate(self.node_mgr.get_node_list()):
if n.op == "placeholder" and get_node_shape(n) is not None: if n.op == "placeholder" and get_node_shape(n) is not None:
free_var_idx.append(idx) free_var_idx.append(idx)
return free_var_idx return free_var_idx
@ -164,6 +167,44 @@ class SearchChunk(object):
chunk_region_end = region[0] - 1 chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end return chunk_region_start, chunk_region_end
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
"""
Find chunk info for a region.
We are given the region start and region end, and need to find out all chunk info for it.
We first loop every dim of start node and end node, to see if we can find dim pair,
which is linked in a flow and not computed.
If found, we then search flow in the whole region to find out all chunk infos.
Args:
input_trace (List): node's input trace in region
output_trace (List): node's output trace in region
start_idx (int): region start node index
end_idx (int): region end node index
Returns:
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
if len(start_traces) > 1: # TODO need to be removed
return []
end_trace = output_trace[end_idx]
end_node = self.node_mgr.get_node_by_idx(end_idx)
chunk_infos = []
for end_dim, _ in enumerate(end_trace["indice"]):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
end_idx):
continue
# flow search
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
if chunk_info is None:
continue
chunk_infos.append(chunk_info)
return chunk_infos
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List: def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
""" """
Search every possible region within the max chunk region. Search every possible region within the max chunk region.
@ -178,7 +219,7 @@ class SearchChunk(object):
possible_chunk_region = [] possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list) output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
input_trace = [] # trace of a node's input nodes input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.trace_indice.node_list): for _, n in enumerate(self.node_mgr.get_node_list()):
cur_trace = {} cur_trace = {}
for arg in n.args: for arg in n.args:
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg): if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
@ -188,11 +229,11 @@ class SearchChunk(object):
for start_idx in range(max_chunk_region[0], peak_node + 1): for start_idx in range(max_chunk_region[0], peak_node + 1):
for end_idx in range(peak_node, max_chunk_region[1] + 1): for end_idx in range(peak_node, max_chunk_region[1] + 1):
# skip non compute nodes # skip non compute nodes
if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node( if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
self.trace_indice.node_list[end_idx]): self.node_mgr.get_node_by_idx(end_idx)):
continue continue
# select free dim # select free dim
chunk_info = self.trace_flow.find_chunk_info(input_trace, output_trace, start_idx, end_idx) chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
if len(chunk_info) > 0: if len(chunk_info) > 0:
possible_chunk_region.extend(chunk_info) possible_chunk_region.extend(chunk_info)
return possible_chunk_region return possible_chunk_region
@ -254,7 +295,7 @@ class SearchChunk(object):
init_mem_peak, init_mem_peak,
_, _,
active_node, active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list) ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
mem_peak = init_mem_peak mem_peak = init_mem_peak
while True: while True:
@ -267,7 +308,7 @@ class SearchChunk(object):
mem_peak, mem_peak,
_, _,
active_node, active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos) ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), chunk_infos)
if self.print_progress: if self.print_progress:
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" % get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
@ -277,5 +318,7 @@ class SearchChunk(object):
break break
if self.print_mem: if self.print_mem:
self.print_mem = False self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True) self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
chunk_infos,
print_mem=True)
return chunk_infos return chunk_infos

View File

@ -1,7 +1,7 @@
from .estimate_memory import EstimateMemory from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph from .reorder_graph import ReorderGraph
from .trace_indice import TraceIndice from .trace_indice import TraceIndice
from .utils import is_non_compute_node from .utils import NodeMgr, is_non_compute_node
class SelectChunk(object): class SelectChunk(object):
@ -11,11 +11,13 @@ class SelectChunk(object):
trace_indice: TraceIndice, trace_indice: TraceIndice,
estimate_memory: EstimateMemory, estimate_memory: EstimateMemory,
reorder_graph: ReorderGraph, reorder_graph: ReorderGraph,
node_mgr: NodeMgr,
max_memory=None, max_memory=None,
): ):
self.trace_indice = trace_indice self.trace_indice = trace_indice
self.estimate_memory = estimate_memory self.estimate_memory = estimate_memory
self.reorder_graph = reorder_graph self.reorder_graph = reorder_graph
self.node_mgr = node_mgr
if max_memory is not None: if max_memory is not None:
self.stratge = "fit_memory" self.stratge = "fit_memory"
self.max_memory = max_memory # MB self.max_memory = max_memory # MB
@ -68,7 +70,7 @@ class SelectChunk(object):
regions_dict = [] regions_dict = []
for region in possible_chunk_regions: for region in possible_chunk_regions:
cur_region = region.copy() cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region) cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region] cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1] cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
@ -134,7 +136,7 @@ class SelectChunk(object):
def _get_compute_node_num(self, start, end): def _get_compute_node_num(self, start, end):
count = 0 count = 0
for i in self.trace_indice.node_list[start:end + 1]: for i in self.node_mgr.get_node_slice_by_idx(start, end + 1):
if not is_non_compute_node(i): if not is_non_compute_node(i):
count += 1 count += 1
return count return count
@ -161,7 +163,7 @@ class SelectChunk(object):
regions_dict_list = [] regions_dict_list = []
for region in possible_chunk_regions: for region in possible_chunk_regions:
cur_region = region.copy() cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region) cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region] cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1] cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]

View File

@ -4,9 +4,10 @@ from torch.fx.node import Node
from .trace_indice import TraceIndice from .trace_indice import TraceIndice
from .utils import ( from .utils import (
NodeMgr,
find_chunk_all_input_nodes, find_chunk_all_input_nodes,
find_chunk_compute_input_and_output_nodes, find_chunk_compute_input_and_output_nodes,
find_idx_by_name, find_tensor_shape_node,
flat_list, flat_list,
get_node_name, get_node_name,
get_node_shape, get_node_shape,
@ -16,8 +17,9 @@ from .utils import (
class TraceFlow(object): class TraceFlow(object):
def __init__(self, trace_indice: TraceIndice) -> None: def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
self.trace_indice = trace_indice self.trace_indice = trace_indice
self.node_mgr = node_mgr
def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node): def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
""" """
@ -31,7 +33,8 @@ class TraceFlow(object):
Returns: Returns:
bool: True if check pass bool: True if check pass
""" """
start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list) # we use start_node_idx instead of real chunk index
start_node_idx = self.node_mgr.find_node_idx(start_node)
end_node_trace = self.trace_indice._find_trace_from_node(end_node) end_node_trace = self.trace_indice._find_trace_from_node(end_node)
end_node_trace_source = end_node_trace["source"][end_dim] end_node_trace_source = end_node_trace["source"][end_dim]
sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True) sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
@ -39,7 +42,7 @@ class TraceFlow(object):
if node_idx == start_node_idx and start_dim in node_dim: if node_idx == start_node_idx and start_dim in node_dim:
return True return True
# it means we meet a node outside the loop, and the node is not input node # it means we meet a node outside the loop, and the node is not input node
if node_idx < start_idx: if node_idx < start_node_idx:
return False return False
return False return False
@ -61,29 +64,12 @@ class TraceFlow(object):
return False return False
return True return True
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
node_from_source = self.trace_indice._find_source_trace_from_node(node_from)
dim_source = node_from_source[node_from_dim]
node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list)
for k, v in dim_source.items():
if k == node_to_idx:
return v
return None
def _find_inherit_dim(self, input_node, input_dim, node):
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
for node_dim in range(len(get_node_shape(node))):
if (input_node_idx in node_trace_source[node_dim]
and input_dim[0] in node_trace_source[node_dim][input_node_idx]):
return node_dim
return None
def _assgin_single_node_flow( def _assgin_single_node_flow(
self, self,
arg_node: Node, arg_node: Node,
start_idx: int, start_idx: int,
end_idx: int, end_idx: int,
cur_node: Node,
cur_node_dim: int, cur_node_dim: int,
cur_node_compute: Dict, cur_node_compute: Dict,
cur_node_source: Dict, cur_node_source: Dict,
@ -109,7 +95,7 @@ class TraceFlow(object):
Returns: Returns:
bool: True if this node can be added to the flow, vice versa. bool: True if this node can be added to the flow, vice versa.
""" """
arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list) arg_idx = self.node_mgr.find_node_idx(arg_node)
# arg in chunk range or be inputs # arg in chunk range or be inputs
if not (start_idx <= arg_idx < end_idx): if not (start_idx <= arg_idx < end_idx):
return True return True
@ -126,6 +112,11 @@ class TraceFlow(object):
# chunk dim should be None if shape size is 1 # chunk dim should be None if shape size is 1
if get_node_shape(arg_node)[arg_dim] == 1: if get_node_shape(arg_node)[arg_dim] == 1:
arg_dim = None arg_dim = None
# chunk shape should equal cur node
elif get_node_shape(arg_node)[arg_dim] != 1:
if cur_node_dim is not None and get_node_shape(cur_node)[cur_node_dim] != 1:
if get_node_shape(arg_node)[arg_dim] != get_node_shape(cur_node)[cur_node_dim]:
return False
else: else:
arg_dim = None arg_dim = None
@ -150,7 +141,7 @@ class TraceFlow(object):
return True return True
def _get_all_node_info(self, end_dim, start_idx, end_idx): def _get_all_node_info(self, end_dim, start_idx, end_idx):
cur_node_list = [self.trace_indice.node_list[end_idx]] # start from the last node cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
while len(cur_node_list) > 0: while len(cur_node_list) > 0:
@ -178,6 +169,7 @@ class TraceFlow(object):
arg, arg,
start_idx, start_idx,
end_idx, end_idx,
cur_node,
cur_node_chunk_dim, cur_node_chunk_dim,
cur_node_compute, cur_node_compute,
cur_node_source, cur_node_source,
@ -194,7 +186,7 @@ class TraceFlow(object):
for arg in arg_list: for arg in arg_list:
if get_node_shape(arg) is None: if get_node_shape(arg) is None:
continue continue
if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx): if not (start_idx <= self.node_mgr.find_node_idx(arg) < end_idx):
continue continue
arg_chunk_dim = all_node_info[arg]["chunk_dim"] arg_chunk_dim = all_node_info[arg]["chunk_dim"]
arg_fix_dim = all_node_info[arg]["fix_dim"] arg_fix_dim = all_node_info[arg]["fix_dim"]
@ -232,7 +224,7 @@ class TraceFlow(object):
remove_inputs = [] remove_inputs = []
for input_node in inputs: for input_node in inputs:
input_dict = {} input_dict = {}
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list) input_node_idx = self.node_mgr.find_node_idx(input_node)
for user in input_node.users.keys(): for user in input_node.users.keys():
# skip non compute # skip non compute
if is_non_compute_node(user): if is_non_compute_node(user):
@ -240,7 +232,7 @@ class TraceFlow(object):
# untraced node, mostly non compute # untraced node, mostly non compute
if user not in all_node_info: if user not in all_node_info:
continue continue
user_idx = find_idx_by_name(user.name, self.trace_indice.node_list) user_idx = self.node_mgr.find_node_idx(user)
if start_idx <= user_idx <= end_idx: if start_idx <= user_idx <= end_idx:
chunk_dim = all_node_info[user]["chunk_dim"] chunk_dim = all_node_info[user]["chunk_dim"]
if chunk_dim is not None: if chunk_dim is not None:
@ -262,7 +254,7 @@ class TraceFlow(object):
inputs.remove(i) inputs.remove(i)
return inputs, inputs_dim return inputs, inputs_dim
def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]: def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int, chunk_info) -> List[Node]:
""" """
get all useless nodes in chunk region and prepose them get all useless nodes in chunk region and prepose them
@ -279,8 +271,11 @@ class TraceFlow(object):
for node, node_info in all_node_info.items(): for node, node_info in all_node_info.items():
if node_info["chunk_dim"] is None: if node_info["chunk_dim"] is None:
maybe_prepose_nodes.append(node) maybe_prepose_nodes.append(node)
for node in self.node_mgr.get_node_slice_by_idx(start_idx, end_idx):
if node not in all_node_info and node not in chunk_info["outputs"]:
maybe_prepose_nodes.append(node)
maybe_prepose_nodes.sort( maybe_prepose_nodes.sort(
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list), key=lambda x: self.node_mgr.find_node_idx(x),
reverse=True, reverse=True,
) # from last node to first node ) # from last node to first node
prepose_nodes = [] prepose_nodes = []
@ -303,8 +298,7 @@ class TraceFlow(object):
if type(cur_prepose_node_arg) != type(cur_prepose_node): if type(cur_prepose_node_arg) != type(cur_prepose_node):
continue continue
# out of loop # out of loop
if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) < if not (start_idx <= self.node_mgr.find_node_idx(cur_prepose_node_arg) < end_idx):
end_idx):
continue continue
# compute op in loop # compute op in loop
elif cur_prepose_node_arg in all_node_info: elif cur_prepose_node_arg in all_node_info:
@ -328,13 +322,12 @@ class TraceFlow(object):
if n in maybe_prepose_nodes: if n in maybe_prepose_nodes:
maybe_prepose_nodes.remove(n) maybe_prepose_nodes.remove(n)
# sort by index # sort by index
prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list)) prepose_nodes.sort(key=lambda x: self.node_mgr.find_node_idx(x))
chunk_info["args"]["prepose_nodes"] = prepose_nodes
return prepose_nodes
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
# we need to log input nodes to avoid deleteing them in the loop # we need to log input nodes to avoid deleteing them in the loop
chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1] chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
# also need to get some prepose node's arg out of non_chunk_inputs # also need to get some prepose node's arg out of non_chunk_inputs
for n in chunk_info["args"]["prepose_nodes"]: for n in chunk_info["args"]["prepose_nodes"]:
chunk_node_list.remove(n) chunk_node_list.remove(n)
@ -345,34 +338,41 @@ class TraceFlow(object):
return chunk_info return chunk_info
def flow_search(self, start_idx, start_dim, end_idx, end_dim): def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1]) inputs, outputs = find_chunk_compute_input_and_output_nodes(
# only single ouput self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))
if len(outputs) > 1:
return None
# get every node's chunk dim and fix dim # get every node's chunk dim and fix dim
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
if all_node_info is None: if all_node_info is None:
return None return None
# get input nodes' chunk dim
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
if inputs is None:
return None
chunk_info = { chunk_info = {
"region": (start_idx, end_idx), "region": (start_idx, end_idx),
"inputs": inputs, "inputs": [],
"inputs_non_chunk": [], "inputs_non_chunk": [],
"inputs_dim": inputs_dim, "inputs_dim": [],
"outputs": outputs, "outputs": [self.node_mgr.get_node_by_idx(end_idx)],
"outputs_dim": end_dim, "outputs_non_tensor": {},
"outputs_dim": [end_dim],
"node_chunk_dim": all_node_info, "node_chunk_dim": all_node_info,
"args": {}, "args": {},
} }
# find chunk info for other outputs
if len(find_tensor_shape_node(outputs)) > 1:
chunk_info = self._get_other_output_info(outputs, start_idx, start_dim, end_idx, end_dim, chunk_info)
if chunk_info is None:
return None
# get input nodes' chunk dim
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
if inputs is None:
return None
chunk_info["inputs"] = inputs
chunk_info["inputs_dim"] = inputs_dim
# move useless nodes ahead of loop # move useless nodes ahead of loop
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx) self._get_prepose_nodes(all_node_info, start_idx, end_idx, chunk_info)
# find non chunk inputs # find non chunk inputs
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx) chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
@ -382,6 +382,63 @@ class TraceFlow(object):
return chunk_info return chunk_info
def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
chunk_info: Dict):
start_node = self.node_mgr.get_node_by_idx(start_idx)
# loop all outputs
for output in outputs:
output_legal = False
output_idx = self.node_mgr.find_node_idx(output)
# skip the origin output
if output_idx == end_idx:
continue
# skip non tensor
if get_node_shape(output) is None:
# log shape tensor
if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
continue
# loop every dim of outputs, try to find a legal one
for output_dim in range(len(get_node_shape(output))):
if not self.check_region_start_end(start_node, start_dim, start_idx, output, output_dim, output_idx):
continue
new_all_node_info = self._get_all_node_info(output_dim, start_idx, output_idx)
if new_all_node_info is None:
continue
# check node info legal
if self._update_chunk_info(chunk_info, new_all_node_info, output, output_dim) == True:
output_legal = True
break
# not legal
if output_legal == False:
return None
return chunk_info
def _update_chunk_info(self, chunk_info: Dict, new_all_node_info: Dict, output: Node, output_dim: int) -> bool:
"""
check if there is conflict between new node info and old chunk info. If not, update old chunk info
"""
# check if conflict
overlap_flag = False
for k, v in new_all_node_info.items():
if k in chunk_info["node_chunk_dim"]:
overlap_flag = True
if chunk_info["node_chunk_dim"][k]["chunk_dim"] != v["chunk_dim"]:
return False
# if no overlap, we just consider them as prepose nodes, instead of new output
if overlap_flag == False:
return True
# update chunk info
for k, v in new_all_node_info.items():
if k in chunk_info["node_chunk_dim"]:
chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
else:
chunk_info["node_chunk_dim"][k] = v
chunk_info["outputs"].append(output)
chunk_info["outputs_dim"].append(output_dim)
return True
def _reassgin_reshape_size(self, chunk_info): def _reassgin_reshape_size(self, chunk_info):
""" """
Some shape args in reshape may have changed due to chunk Some shape args in reshape may have changed due to chunk
@ -389,10 +446,17 @@ class TraceFlow(object):
""" """
chunk_region = chunk_info["region"] chunk_region = chunk_info["region"]
reshape_size = {} reshape_size = {}
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]] chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"][0]]
for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]: for node in self.node_mgr.get_node_slice_by_idx(chunk_region[0], chunk_region[1] + 1):
if any(i == get_node_name(node) for i in ["reshape", "view"]): if any(i == get_node_name(node) for i in ["reshape", "view"]):
if node in chunk_info["args"]["prepose_nodes"]:
continue
if node.args[0] in chunk_info["inputs_non_chunk"]:
continue
reshape_args = flat_list(node.args[1:]) reshape_args = flat_list(node.args[1:])
if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
reshape_args[0].meta['fwd_out']) > 1:
continue
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
new_shape = "" new_shape = ""
for reshape_arg_dim, reshape_arg in enumerate(reshape_args): for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
@ -409,44 +473,7 @@ class TraceFlow(object):
chunk_info["reshape_size"] = reshape_size chunk_info["reshape_size"] = reshape_size
return chunk_info return chunk_info
def find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List: def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
"""
Find chunk info for a region.
We are given the region start and region end, and need to find out all chunk info for it.
We first loop every dim of start node and end node, to see if we can find dim pair,
which is linked in a flow and not computed.
If found, we then search flow in the whole region to find out all chunk infos.
Args:
input_trace (List): node's input trace in region
output_trace (List): node's output trace in region
start_idx (int): region start node index
end_idx (int): region end node index
Returns:
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
if len(start_traces) > 1: # TODO need to be removed
return []
end_trace = output_trace[end_idx]
end_node = self.trace_indice.node_list[end_idx]
chunk_infos = []
for end_dim, _ in enumerate(end_trace["indice"]):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
if not self._check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, end_idx):
continue
# flow search
chunk_info = self.flow_search(start_idx, start_dim, end_idx, end_dim)
if chunk_info is None:
continue
chunk_infos.append(chunk_info)
return chunk_infos
def _check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
end_idx: int) -> bool: end_idx: int) -> bool:
""" """
check if region start and end is legal check if region start and end is legal

View File

@ -3,14 +3,7 @@ from typing import Dict, List, Tuple
from torch.fx.node import Node from torch.fx.node import Node
from .utils import ( from .utils import NodeMgr, find_first_tensor_arg, flat_list, get_module_node_name, get_node_name, get_node_shape
find_first_tensor_arg,
find_idx_by_name,
flat_list,
get_module_node_name,
get_node_name,
get_node_shape,
)
class TraceIndice(object): class TraceIndice(object):
@ -35,8 +28,8 @@ class TraceIndice(object):
node_list (List) node_list (List)
""" """
def __init__(self, node_list: List[Node]) -> None: def __init__(self, node_mgr: NodeMgr) -> None:
self.node_list = node_list self.node_mgr = node_mgr
self.indice_trace_list = self._init_indice_trace_list() self.indice_trace_list = self._init_indice_trace_list()
self.indice_view_list = {} self.indice_view_list = {}
self.indice_count = -1 self.indice_count = -1
@ -45,7 +38,7 @@ class TraceIndice(object):
def _init_indice_trace_list(self) -> List: def _init_indice_trace_list(self) -> List:
indice_trace_list = [] indice_trace_list = []
for n in self.node_list: for n in self.node_mgr.get_node_list():
if get_node_shape(n) != None: if get_node_shape(n) != None:
cur_trace = { cur_trace = {
"indice": [None for _ in range(len(get_node_shape(n)))], "indice": [None for _ in range(len(get_node_shape(n)))],
@ -99,7 +92,7 @@ class TraceIndice(object):
node_from_trace_source = self._find_source_trace_from_node(node_from) node_from_trace_source = self._find_source_trace_from_node(node_from)
node_to_dim = self._transform_indice(node_to, node_to_dim) node_to_dim = self._transform_indice(node_to, node_to_dim)
node_to_trace_source = self._find_source_trace_from_node(node_to) node_to_trace_source = self._find_source_trace_from_node(node_to)
node_from_idx = find_idx_by_name(node_from.name, self.node_list) node_from_idx = self.node_mgr.find_node_idx(node_from)
if init: if init:
node_to_trace_source[node_to_dim] = {} node_to_trace_source[node_to_dim] = {}
# add dim to cur new source # add dim to cur new source
@ -200,7 +193,7 @@ class TraceIndice(object):
idx (list): idx of the node idx (list): idx of the node
compute (list): computed idx of the node. compute (list): computed idx of the node.
""" """
node_idx = find_idx_by_name(node.name, self.node_list) node_idx = self.node_mgr.find_node_idx(node)
node_dict = self.indice_trace_list[node_idx] node_dict = self.indice_trace_list[node_idx]
return node_dict return node_dict
@ -214,7 +207,7 @@ class TraceIndice(object):
idx (list): idx of the node idx (list): idx of the node
compute (list): computed idx of the node. compute (list): computed idx of the node.
""" """
node_idx = find_idx_by_name(node.name, self.node_list) node_idx = self.node_mgr.find_node_idx(node)
node_dict = self.indice_trace_list[node_idx] node_dict = self.indice_trace_list[node_idx]
return node_dict["source"] return node_dict["source"]
@ -227,7 +220,7 @@ class TraceIndice(object):
Returns: Returns:
idx (list): idx of the node idx (list): idx of the node
""" """
node_idx = find_idx_by_name(node.name, self.node_list) node_idx = self.node_mgr.find_node_idx(node)
return self.indice_trace_list[node_idx]["indice"] return self.indice_trace_list[node_idx]["indice"]
def _find_compute_trace_from_node(self, node: Node) -> List: def _find_compute_trace_from_node(self, node: Node) -> List:
@ -239,7 +232,7 @@ class TraceIndice(object):
Returns: Returns:
compute (list): computed idx of the node. compute (list): computed idx of the node.
""" """
node_idx = find_idx_by_name(node.name, self.node_list) node_idx = self.node_mgr.find_node_idx(node)
return self.indice_trace_list[node_idx]["compute"] return self.indice_trace_list[node_idx]["compute"]
def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None: def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None:
@ -454,8 +447,6 @@ class TraceIndice(object):
node (node) node (node)
node_idx (int) node_idx (int)
""" """
for _ in range(len(get_node_shape(node.args[0]))):
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx) self._assign_indice_as_input(node, node_idx)
dim_idx = node.kwargs["dim"] dim_idx = node.kwargs["dim"]
self._del_dim(node_idx, dim_idx) self._del_dim(node_idx, dim_idx)
@ -702,21 +693,20 @@ class TraceIndice(object):
if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
and view_dict["dim_from"] == dim_to): and view_dict["dim_from"] == dim_to):
# inheirt indice from current node # inheirt indice from current node
for dim_to_i in dim_to: if len_diff == 1:
for dim_from_i in dim_from: if origin_shape[dim_from[0]] == 1:
self._inherit_indice(origin_node, dim_from_i, node, dim_to_i, init=False) self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
elif origin_shape[dim_from[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
elif len_diff == -1:
if target_shape[dim_to[0]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
elif target_shape[dim_to[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
# inherid indice from input node of last view # inherid indice from input node of last view
for dim_to_i in dim_to: for dim_to_i in dim_to:
self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False) self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
# inherit computation
compute_log = self._find_compute_trace_from_node(origin_node)
for i in dim_from:
if origin_trace[i] in compute_log:
for j in dim_to:
self._mark_computation(node, node_idx, [j])
break
# log view, not used now # log view, not used now
view_dict = { view_dict = {
"idx_from": [origin_trace[i] for i in dim_from], "idx_from": [origin_trace[i] for i in dim_from],
@ -742,7 +732,7 @@ class TraceIndice(object):
active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1] active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
active_nodes = set(flat_list(active_nodes)) active_nodes = set(flat_list(active_nodes))
active_nodes = [find_idx_by_name(i, self.node_list) for i in active_nodes] active_nodes = [self.node_mgr.find_node_idx_by_name(i) for i in active_nodes]
for i in range(trace_range[0], trace_range[1] + 1): for i in range(trace_range[0], trace_range[1] + 1):
trace = self.indice_trace_list[i] trace = self.indice_trace_list[i]
# clear compute # clear compute
@ -758,7 +748,7 @@ class TraceIndice(object):
dim_source.pop(k) dim_source.pop(k)
def trace_indice(self) -> None: def trace_indice(self) -> None:
for idx, node in enumerate(self.node_list): for idx, node in enumerate(self.node_mgr.get_node_list()):
node_name = get_node_name(node) node_name = get_node_name(node)
if node.op == "placeholder": if node.op == "placeholder":
self._assign_all_indice(node, idx) self._assign_all_indice(node, idx)

View File

@ -9,6 +9,59 @@ NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "siz
logger = get_dist_logger() logger = get_dist_logger()
class NodeMgr(object):
def __init__(self, gm) -> None:
self._node_list = list(gm.graph.nodes)
self._node_dict = {}
self._set_node_dict()
def _set_node_dict(self) -> None:
"""
create a dict {node_name: node_idx}
"""
self._node_dict.clear()
for idx, node in enumerate(self._node_list):
self._node_dict[node.name] = idx
def find_node_idx(self, node: Node) -> int:
"""
find node's index
"""
return self._node_dict[node.name]
def find_node_idx_by_name(self, node_name: str) -> int:
"""
find node's index
"""
return self._node_dict[node_name]
def get_node_by_idx(self, idx: int) -> Node:
"""
get a node by index
"""
return self._node_list[idx]
def get_node_slice_by_idx(self, start: int, end: int) -> List[Node]:
"""
get a slice of node by index
"""
return self._node_list[start:end]
def get_node_list(self) -> List:
"""
get full node list
"""
return self._node_list
def update_node_list(self, node_list: List) -> None:
"""
update node list, reset node dict
"""
self._node_list = node_list
self._set_node_dict()
def get_logger() -> Any: def get_logger() -> Any:
return logger return logger
@ -42,6 +95,8 @@ def is_non_compute_node(node: Node) -> bool:
if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME): if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME):
return True return True
if "getitem" in node.name: if "getitem" in node.name:
if get_node_shape(node) is not None:
return False
node_args = flat_list(node.args[1:]) node_args = flat_list(node.args[1:])
for node_arg in node_args: for node_arg in node_args:
if any(i == str(node_arg) for i in ["None", "Ellipsis"]): if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
@ -53,6 +108,8 @@ def is_non_compute_node(node: Node) -> bool:
def get_node_shape(node: Node) -> List: def get_node_shape(node: Node) -> List:
if get_node_name(node) == "split":
return node.meta["tensor_meta"][0].shape
if hasattr(node.meta["tensor_meta"], "shape"): if hasattr(node.meta["tensor_meta"], "shape"):
return node.meta["tensor_meta"].shape return node.meta["tensor_meta"].shape
return None return None
@ -78,7 +135,7 @@ def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
return is_non_compute_node_except_placeholder(node) return is_non_compute_node_except_placeholder(node)
def find_idx_by_name(name: str, nodes_list: List) -> int: def find_node_idx(name: str, nodes_list: List) -> int:
for idx, node in enumerate(nodes_list): for idx, node in enumerate(nodes_list):
if node.name == name: if node.name == name:
return idx return idx
@ -162,3 +219,28 @@ def get_node_name(node: Node) -> str:
else: else:
break break
return node_name return node_name
def find_tensor_node(node_list: List[Node]) -> List[Node]:
"""
find tensor nodes from a node list
"""
out = []
for node in node_list:
if get_node_shape(node) is not None:
out.append(node)
return out
def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
"""
find tensor and shape nodes from a node list
"""
out = []
for node in node_list:
if get_node_shape(node) is not None:
out.append(node)
elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
node.meta['fwd_out'][0], int):
out.append(node)
return out

View File

@ -23,6 +23,7 @@ def assert_codegen_run(
concrete_args: List = None, concrete_args: List = None,
max_memory: int = None, max_memory: int = None,
print_mem: bool = False, print_mem: bool = False,
print_est_mem: bool = False,
print_progress: bool = False, print_progress: bool = False,
print_code: bool = False, print_code: bool = False,
) -> List[Dict]: ) -> List[Dict]:
@ -41,7 +42,7 @@ def assert_codegen_run(
codegen = AutoChunkCodeGen( codegen = AutoChunkCodeGen(
meta_graph, meta_graph,
max_memory=max_memory, max_memory=max_memory,
print_mem=print_mem, print_mem=print_est_mem,
print_progress=print_progress, print_progress=print_progress,
) )
chunks = codegen.chunk_infos chunks = codegen.chunk_infos
@ -61,13 +62,20 @@ def assert_codegen_run(
code = graph.python_code("self").src code = graph.python_code("self").src
if print_code: if print_code:
print(code) print(code)
assert "chunk_result = None; chunk_size = None;" in code assert "chunk_size = None; " in code
# assert result # assert result
inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args] inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
model.cuda() model.cuda()
with torch.no_grad(): with torch.no_grad():
out_gm = gm(*inputs) if print_mem:
torch.cuda.reset_peak_memory_stats()
now_mem = torch.cuda.memory_allocated() / 1024**2
out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
if print_mem:
new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
print("mem: %.2fMB" % (new_max_mem - now_mem))
out_model = model(*inputs) out_model = model(*inputs)
out_gm = flat_list(out_gm) out_gm = flat_list(out_gm)
out_model = flat_list(out_model) out_model = flat_list(out_model)
@ -85,9 +93,10 @@ def run_test(
max_memory: int, max_memory: int,
get_model: Any, get_model: Any,
get_data: Any, get_data: Any,
print_code: bool, print_code: bool = False,
print_mem: bool, print_mem: bool = False,
print_progress: bool, print_est_mem: bool = False,
print_progress: bool = False,
get_chunk_target: Any = None, get_chunk_target: Any = None,
) -> None: ) -> None:
# launch colossalai # launch colossalai
@ -110,6 +119,7 @@ def run_test(
max_memory=max_memory, max_memory=max_memory,
print_code=print_code, print_code=print_code,
print_mem=print_mem, print_mem=print_mem,
print_est_mem=print_est_mem,
print_progress=print_progress, print_progress=print_progress,
) )

View File

@ -55,9 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
def get_chunk_target() -> Dict: def get_chunk_target() -> Dict:
return { return {
None: [(118, 123), (219, 237), (264, 289), (302, 309), (97, 104), (144, 152), (185, 193), (241, 242), (21, 46)], None: [(120, 123), (222, 237), (269, 289), (305, 311), (100, 105), (146, 152), (187, 193), (241, 242),
20: [(118, 123), (230, 237), (275, 282), (305, 306), (100, 101), (32, 39), (73, 79)], (25, 50)],
24: [(118, 123)], 20: [(120, 123), (232, 237), (277, 282), (305, 306), (100, 101), (34, 39)],
24: [(120, 123)],
} }
@ -75,9 +76,6 @@ def test_evoformer_block(data_args, max_memory):
get_model=get_model, get_model=get_model,
get_data=get_data, get_data=get_data,
get_chunk_target=get_chunk_target, get_chunk_target=get_chunk_target,
print_code=False,
print_mem=False,
print_progress=False,
) )
mp.spawn(run_func, nprocs=1) mp.spawn(run_func, nprocs=1)
@ -86,10 +84,12 @@ if __name__ == "__main__":
run_test( run_test(
rank=0, rank=0,
data_args=(32, 64), data_args=(32, 64),
max_memory=20, max_memory=24,
get_model=get_model, get_model=get_model,
get_data=get_data, get_data=get_data,
get_chunk_target=get_chunk_target,
print_code=False, print_code=False,
print_mem=False, print_mem=False,
print_est_mem=False,
print_progress=False, print_progress=False,
) )

View File

@ -70,9 +70,6 @@ def test_evoformer_stack(data_args, max_memory):
max_memory=max_memory, max_memory=max_memory,
get_model=get_model, get_model=get_model,
get_data=get_data, get_data=get_data,
print_code=False,
print_mem=False,
print_progress=False,
) )
mp.spawn(run_func, nprocs=1) mp.spawn(run_func, nprocs=1)
@ -81,7 +78,7 @@ if __name__ == "__main__":
run_test( run_test(
rank=0, rank=0,
data_args=(32, 64), data_args=(32, 64),
max_memory=20, max_memory=None,
get_model=get_model, get_model=get_model,
get_data=get_data, get_data=get_data,
print_code=False, print_code=False,

View File

@ -55,10 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
def get_chunk_target() -> Dict: def get_chunk_target() -> Dict:
return { return {
None: [(126, 131), (227, 245), (272, 297), (310, 317), (105, 112), (152, 160), (193, 201), (249, 250), None: [(128, 131), (230, 245), (277, 297), (313, 319), (108, 113), (154, 160), (195, 201), (249, 250),
(33, 46)], (36, 46)],
20: [(126, 131), (238, 245), (283, 290), (313, 314), (108, 109), (35, 46)], 20: [(128, 131), (240, 245), (285, 290), (313, 314), (108, 109), (41, 46)],
24: [(126, 131)], 24: [(128, 131)],
} }
@ -75,9 +75,7 @@ def test_extramsa_block(data_args, max_memory):
max_memory=max_memory, max_memory=max_memory,
get_model=get_model, get_model=get_model,
get_data=get_data, get_data=get_data,
print_code=False, get_chunk_target=get_chunk_target,
print_mem=False,
print_progress=False,
) )
mp.spawn(run_func, nprocs=1) mp.spawn(run_func, nprocs=1)
@ -86,7 +84,7 @@ if __name__ == "__main__":
run_test( run_test(
rank=0, rank=0,
data_args=(32, 64), data_args=(32, 64),
max_memory=20, max_memory=None,
get_model=get_model, get_model=get_model,
get_data=get_data, get_data=get_data,
get_chunk_target=get_chunk_target, get_chunk_target=get_chunk_target,

View File

@ -17,8 +17,8 @@ from test_transformer_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
BATCH_SIZE = 2 BATCH_SIZE = 1
SEQ_LENGTH = 256 SEQ_LENGTH = 512
def get_data(shape: tuple) -> Tuple[List, List]: def get_data(shape: tuple) -> Tuple[List, List]:
@ -37,17 +37,14 @@ def get_data(shape: tuple) -> Tuple[List, List]:
) )
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)]) @pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
@pytest.mark.parametrize("max_memory", [None, 4.5, 5]) @pytest.mark.parametrize("max_memory", [None, 6, 8])
def test_gpt(model, shape, max_memory): def test_autochunk_gpt(model, shape, max_memory):
run_func = partial( run_func = partial(
run_test, run_test,
data=get_data(shape), data=get_data(shape),
max_memory=max_memory, max_memory=max_memory,
model=model, model=model,
config=GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4), config=GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4),
print_code=False,
print_mem=False,
print_progress=False,
) )
mp.spawn(run_func, nprocs=1) mp.spawn(run_func, nprocs=1)
@ -59,7 +56,8 @@ if __name__ == "__main__":
max_memory=None, max_memory=None,
model=GPT2Model, model=GPT2Model,
config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4), config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
print_code=True, print_code=False,
print_mem=True, print_est_mem=False,
print_mem=False,
print_progress=False, print_progress=False,
) )

View File

@ -20,6 +20,7 @@ def assert_codegen_run(
model: Any, model: Any,
data: tuple, data: tuple,
max_memory: int = None, max_memory: int = None,
print_est_mem: bool = False,
print_mem: bool = False, print_mem: bool = False,
print_progress: bool = False, print_progress: bool = False,
print_code: bool = False, print_code: bool = False,
@ -41,7 +42,7 @@ def assert_codegen_run(
codegen = AutoChunkCodeGen( codegen = AutoChunkCodeGen(
meta_graph, meta_graph,
max_memory=max_memory, max_memory=max_memory,
print_mem=print_mem, print_mem=print_est_mem,
print_progress=print_progress, print_progress=print_progress,
) )
chunks = codegen.chunk_infos chunks = codegen.chunk_infos
@ -61,7 +62,7 @@ def assert_codegen_run(
code = graph.python_code("self").src code = graph.python_code("self").src
if print_code: if print_code:
print(code) print(code)
assert "chunk_result = None; chunk_size = None;" in code assert "chunk_size = None; " in code
# assert result # assert result
inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
@ -69,26 +70,44 @@ def assert_codegen_run(
model.cuda().eval() model.cuda().eval()
gm.eval() gm.eval()
with torch.no_grad(): with torch.no_grad():
out_gm = gm(*inputs) if print_mem:
torch.cuda.reset_peak_memory_stats()
now_mem = torch.cuda.memory_allocated() / 1024**2
out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
if print_mem:
new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
print("mem: %.2fMB" % (new_max_mem - now_mem))
out_model = model(*inputs) out_model = model(*inputs)
for k in out_model.keys(): assert_allclose(out_model, out_gm)
if torch.is_tensor(out_gm[k]):
assert torch.equal(
out_model[k], out_gm[k]
), f'{model.__class__.__name__} has incorrect output {k}, expect {out_model[k]}, but got {out_gm[k]}'
return chunks return chunks
def assert_allclose(out_model: Any, out_gm: Any) -> None:
"""
assert allclose for out
"""
if isinstance(out_model, torch.Tensor):
assert torch.allclose(out_model, out_gm,
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(out_model - out_gm))
elif isinstance(out_model, dict):
for k in out_model.keys():
assert_allclose(out_model[k], out_gm[k])
elif isinstance(out_model, tuple) or isinstance(out_model, list) or isinstance(out_model, set):
for i, j in zip(out_model, out_gm):
assert_allclose(i, j)
def run_test( def run_test(
rank: int, rank: int,
model: Any, model: Any,
config: Any, config: Any,
data: tuple, data: tuple,
max_memory: int, max_memory: int,
print_code: bool, print_code: bool = False,
print_mem: bool, print_est_mem: bool = False,
print_progress: bool, print_mem: bool = False,
print_progress: bool = False,
get_chunk_target: Any = None, get_chunk_target: Any = None,
) -> None: ) -> None:
model = model(config=config) model = model(config=config)
@ -108,6 +127,7 @@ def run_test(
data=data, data=data,
max_memory=max_memory, max_memory=max_memory,
print_code=print_code, print_code=print_code,
print_est_mem=print_est_mem,
print_mem=print_mem, print_mem=print_mem,
print_progress=print_progress, print_progress=print_progress,
) )
@ -119,5 +139,3 @@ def run_test(
str(chunk_found), str(chunk_found),
str(chunk_target), str(chunk_target),
) )
gpc.destroy()