[autochunk] support autochunk on evoformer (#2497)

pull/2502/head
oahzxl 2023-01-19 11:41:00 +08:00 committed by GitHub
parent 304f1ba124
commit ecccc91f21
9 changed files with 200 additions and 188 deletions

View File

@ -123,12 +123,13 @@ def _replace_name(context: str, name_from: str, name_to: str) -> str:
"""
replace node name
"""
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"), (" ", ""), ("", " ")]
for p in patterns:
source = p[0] + name_from + p[1]
target = p[0] + name_to + p[1]
if source in context:
context = context.replace(source, target)
break
return context
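A minimal standalone sketch of how the extended delimiter patterns behave; the demo helper below mirrors the logic above, and the example line is hypothetical:

def _replace_name_demo(context: str, name_from: str, name_to: str) -> str:
    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"), (" ", ""), ("", " ")]
    for left, right in patterns:
        source = left + name_from + right
        target = left + name_to + right
        if source in context:
            context = context.replace(source, target)
            break
    return context

# "add" is only renamed where a delimiter pair matches, so "add_1" keeps its name here
print(_replace_name_demo("out = add(add_1, x)", "add", "chunk_add"))    # out = chunk_add(add_1, x)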
@ -138,8 +139,7 @@ def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict)
"""
if node_name not in reshape_size_dict:
return context
for size_name, size_value in reshape_size_dict[node_name].items():
context = context.replace(size_name, size_value)
context = context.replace(reshape_size_dict[node_name][0], reshape_size_dict[node_name][1])
return context
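A sketch of the new reshape_size_dict format consumed above: a single [origin_args, new_args] string pair per reshape/view node, replaced in one pass (the values below are hypothetical):

reshape_size_dict = {"reshape_1": ["size_1, size_2, 256", "min(chunk_size, 64 - chunk_idx), size_2, 256"]}

def _replace_reshape_size_demo(context: str, node_name: str, reshape_size_dict: dict) -> str:
    if node_name not in reshape_size_dict:
        return context
    return context.replace(reshape_size_dict[node_name][0], reshape_size_dict[node_name][1])

print(_replace_reshape_size_demo("reshape_1 = x.reshape(size_1, size_2, 256)", "reshape_1", reshape_size_dict))
# reshape_1 = x.reshape(min(chunk_size, 64 - chunk_idx), size_2, 256)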

View File

@ -37,10 +37,10 @@ class EstimateMemory(object):
def _add_active_node(self, n, active_list):
new_active = self._get_output_node(n)[1]
if n.op == "placeholder":
if n.op == "placeholder" and get_node_shape(n) is not None:
new_active.append(n.name)
for i in new_active:
if i not in active_list:
if i not in active_list and get_node_shape(n) is not None:
active_list.append(i)
def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
@ -77,15 +77,11 @@ class EstimateMemory(object):
if i in active_list:
active_list.remove(i)
def _get_chunk_inputs_size(
self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
):
def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
nodes_to_delete = []
for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
chunk_input_users = chunk_input.users.keys()
chunk_input_users_idx = [
find_idx_by_name(i.name, node_list) for i in chunk_input_users
]
chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
if all(i <= chunk_end_idx for i in chunk_input_users_idx):
if chunk_input not in nodes_to_delete:
nodes_to_delete.append(chunk_input)
@ -112,9 +108,7 @@ class EstimateMemory(object):
not_contiguous_ops = ["permute"]
inherit_contiguous_ops = ["transpose", "view"]
if node.op == "call_function" and any(
n in node.name for n in ["matmul", "reshape"]
):
if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]):
for n in node.args:
if n in not_contiguous_list:
# matmul won't change origin tensor, but create a tmp copy
@ -125,9 +119,7 @@ class EstimateMemory(object):
# module will just make origin tensor to contiguous
if delete:
not_contiguous_list.remove(n)
elif node.op == "call_method" and any(
i in node.name for i in not_contiguous_ops
):
elif node.op == "call_method" and any(i in node.name for i in not_contiguous_ops):
if node not in not_contiguous_list:
not_contiguous_list.append(node)
return mem
@ -142,9 +134,7 @@ class EstimateMemory(object):
else:
return float(chunk_size) / node_shape[chunk_dim]
def _get_chunk_delete_node_size(
self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
):
def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names):
# if any(j in user.name for j in ['transpose', 'permute', 'view']):
# return 0
if user.op in ("placeholder", "output"):
@ -196,7 +186,7 @@ class EstimateMemory(object):
Returns:
act_memory_peak_log (List): peak memory of every node
act_memory_after_node_log (List): memory after executing every node
active_node_list_log (List): active nodes of every node. active nodes refer to
active_node_list_log (List): active nodes of every node. active nodes refer to
nodes generated but not deleted.
"""
act_memory = 0.0
@ -212,7 +202,7 @@ class EstimateMemory(object):
use_chunk = True if chunk_infos is not None else False
chunk_within = False
chunk_region_idx = None
chunk_ratio = 1 # use it to estimate chunk mem
chunk_ratio = 1 # use it to estimate chunk mem
chunk_inputs_names = []
if use_chunk:
@ -221,23 +211,18 @@ class EstimateMemory(object):
chunk_ends = [i[1] for i in chunk_regions]
chunk_inputs = [i["inputs"] for i in chunk_infos]
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
j.name for i in chunk_inputs_non_chunk for j in i
]
chunk_inputs_names = [j.name for i in chunk_inputs for j in i
] + [j.name for i in chunk_inputs_non_chunk for j in i]
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
chunk_sizes = [
i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos
]
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
for idx, node in enumerate(node_list):
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
if use_chunk and idx in chunk_starts:
chunk_within = True
chunk_region_idx = chunk_starts.index(idx)
act_memory += self._get_output_node_size(
chunk_outputs[chunk_region_idx]
) / (1024**2)
act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)
# determine chunk ratio for current node
if chunk_within:
@ -262,22 +247,13 @@ class EstimateMemory(object):
else:
# forward memory
# TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
act_memory += (
self._get_contiguous_memory(node, not_contiguous_list)
* chunk_ratio
/ (1024**2)
)
act_memory += (
self._get_output_node_size(node) * chunk_ratio / (1024**2)
)
act_memory += (self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2))
act_memory += (self._get_output_node_size(node) * chunk_ratio / (1024**2))
# record max act memory
act_memory_peak_log.append(act_memory)
# delete useless memory
act_memory -= (
self._get_contiguous_memory(node, not_contiguous_list, delete=True)
* chunk_ratio
/ (1024**2)
)
act_memory -= (self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio /
(1024**2))
# delete unused vars not in chunk_input_list
# we can't delete input nodes until chunk ends
if chunk_within:
@ -288,9 +264,8 @@ class EstimateMemory(object):
chunk_inputs_names,
) / (1024**2)
else:
act_memory -= self._get_delete_node_size(
node, user_to_last_uses_no_free_var, chunk_inputs_names
) / (1024**2)
act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var,
chunk_inputs_names) / (1024**2)
# log active node, only effective without chunk
self._add_active_node(node, active_node_list)
@ -298,9 +273,7 @@ class EstimateMemory(object):
# if node in chunk end nodes, restore chunk settings
if use_chunk and idx in chunk_ends:
act_memory -= (
self._get_output_node_size(node) * chunk_ratio / (1024**2)
)
act_memory -= (self._get_output_node_size(node) * chunk_ratio / (1024**2))
act_memory -= self._get_chunk_inputs_size(
chunk_inputs[chunk_region_idx],
chunk_inputs_non_chunk[chunk_region_idx],
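A simplified sketch of the bookkeeping pattern used by estimate_chunk_inference_mem, assuming per-node output and freeable sizes are already known in MB (the names below are illustrative, not the real API):

def estimate_peak_demo(output_mb, freeable_mb, chunk_ratio=1.0):
    act_memory = 0.0
    act_memory_peak_log, act_memory_after_node_log = [], []
    for out_mb, free_mb in zip(output_mb, freeable_mb):
        act_memory += out_mb * chunk_ratio          # forward output of this node
        act_memory_peak_log.append(act_memory)      # peak is logged before any frees
        act_memory -= free_mb * chunk_ratio         # tensors at their last use are released
        act_memory_after_node_log.append(act_memory)
    return act_memory_peak_log, act_memory_after_node_log

print(estimate_peak_demo([4.0, 8.0, 2.0], [0.0, 4.0, 8.0]))
# ([4.0, 12.0, 10.0], [4.0, 8.0, 2.0])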

View File

@ -8,11 +8,7 @@ from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import (
get_node_shape,
is_non_compute_node,
is_non_compute_node_except_placeholder,
)
from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
class SearchChunk(object):
@ -73,13 +69,11 @@ class SearchChunk(object):
"""
free_var_idx = []
for idx, n in enumerate(self.trace_indice.node_list):
if n.op == "placeholder":
if n.op == "placeholder" and get_node_shape(n) is not None:
free_var_idx.append(idx)
return free_var_idx
def _search_max_chunk_region(
self, active_node: List, peak_node: Node, chunk_regions: List
) -> Tuple:
def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
"""
Search max chunk region according to peak memory node
@ -124,15 +118,9 @@ class SearchChunk(object):
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
elif (
region[0] <= chunk_region_start <= region[1]
and chunk_region_end > region[1]
):
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
chunk_region_start = region[1] + 1
elif (
region[0] <= chunk_region_end <= region[1]
and chunk_region_start < region[0]
):
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end
@ -164,25 +152,16 @@ class SearchChunk(object):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
# dim size cannot be 1
if (
get_node_shape(end_node)[end_dim] == 1
or get_node_shape(start_node)[start_dim] == 1
):
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
continue
# check index source align
if not self.trace_flow.check_index_source(
start_dim, start_node, start_idx, end_dim, end_node
):
if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
continue
# check index compute
if not self.trace_flow.check_index_compute(
start_idx, end_dim, end_node, end_idx
):
if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
continue
# flow search
chunk_info = self.trace_flow.flow_search(
start_idx, start_dim, end_idx, end_dim
)
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
if chunk_info is None:
continue
# check index compute
@ -191,9 +170,7 @@ class SearchChunk(object):
chunk_infos.append(chunk_info)
return chunk_infos
def _search_possible_chunk_regions(
self, max_chunk_region: Tuple, peak_node: Node
) -> List:
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
"""
Search every possible region within the max chunk region.
@ -206,28 +183,23 @@ class SearchChunk(object):
"""
possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
input_trace = [] # trace of a node's input nodes
input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.trace_indice.node_list):
cur_trace = {}
for arg in n.args:
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
arg
):
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
input_trace.append(cur_trace)
for start_idx in range(max_chunk_region[0], peak_node + 1):
for end_idx in range(peak_node, max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(
self.trace_indice.node_list[start_idx]
) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
self.trace_indice.node_list[end_idx]):
continue
# select free dim
chunk_info = self._find_chunk_info(
input_trace, output_trace, start_idx, end_idx
)
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
if len(chunk_info) > 0:
possible_chunk_region.extend(chunk_info)
return possible_chunk_region
@ -256,17 +228,12 @@ class SearchChunk(object):
best_chunk_region (Dict)
"""
peak_node = self._find_peak_node(mem_peak)
max_chunk_region = self._search_max_chunk_region(
active_node, peak_node, chunk_infos
)
max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
if max_chunk_region == None:
return None
possible_chunk_regions = self._search_possible_chunk_regions(
max_chunk_region, peak_node
)
best_chunk_region = self.select_chunk._select_best_chunk_region(
possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
)
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node,
max_chunk_region, mem_peak)
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
return best_chunk_region
@ -291,9 +258,7 @@ class SearchChunk(object):
init_mem_peak,
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(
self.trace_indice.node_list
)
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
mem_peak = init_mem_peak
while True:
@ -306,14 +271,10 @@ class SearchChunk(object):
mem_peak,
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(
self.trace_indice.node_list, chunk_infos
)
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
if self._stop_search(init_mem_peak, mem_peak):
break
if self.print_mem:
self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem(
self.trace_indice.node_list, chunk_infos, print_mem=True
)
self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
return chunk_infos
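The outer search loop above, paraphrased as a standalone sketch; the callables stand in for the estimator and region-selection methods and are not the actual API:

def search_all_regions_demo(estimate_mem, find_best_region, stop_search, node_list):
    chunk_infos = []
    init_mem_peak, _, active_node = estimate_mem(node_list)
    mem_peak = init_mem_peak
    while True:
        chunk_info = find_best_region(mem_peak, active_node, chunk_infos)
        if chunk_info is None:
            break
        chunk_infos.append(chunk_info)
        # re-estimate peak memory with the new chunk region applied
        mem_peak, _, active_node = estimate_mem(node_list, chunk_infos)
        if stop_search(init_mem_peak, mem_peak):
            break
    return chunk_infos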

View File

@ -1,8 +1,13 @@
from typing import Dict, List, Tuple
from torch.fx.node import Node
from .trace_indice import TraceIndice
from .utils import (
find_chunk_all_input_nodes,
find_chunk_compute_input_and_output_nodes,
find_idx_by_name,
flat_list,
get_node_shape,
is_non_compute_node,
is_non_compute_node_except_placeholder,
@ -171,7 +176,7 @@ class TraceFlow(object):
# get cur node info
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
if cur_node_chunk_dim:
if cur_node_chunk_dim is not None:
cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
else:
@ -223,15 +228,32 @@ class TraceFlow(object):
cur_node_list = next_node_list
return all_node_info
def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, all_node_info: Dict) -> Tuple:
"""
Get the chunk dim for every entry of every input node, and remove inputs that are not chunked
Args:
inputs (List[Node]): input nodes
all_node_info (Dict): describe all node's chunk dim and fix dim
start_idx (int): chunk start idx
end_idx (int): chunk end idx
Returns:
inputs (List(Node)): new inputs
inputs_dim (List): chunk dim for inputs
"""
inputs_dim = []
remove_inputs = []
for input_node in inputs:
input_dict = {}
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
for user in input_node.users.keys():
# skip non compute
if is_non_compute_node(user):
continue
# untraced node, mostly non compute
if user not in all_node_info:
continue
user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
if start_idx <= user_idx <= end_idx:
chunk_dim = all_node_info[user]["chunk_dim"]
@ -245,12 +267,24 @@ class TraceFlow(object):
remove_inputs.append(input_node)
else:
inputs_dim.append(input_dict)
# remove unchunked inputs
for i in remove_inputs:
if i in inputs:
inputs.remove(i)
return inputs, inputs_dim
def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
"""
get all useless nodes in chunk region and prepose them
Args:
all_node_info (Dict): describe all node's chunk dim and fix dim
start_idx (int): chunk start idx
end_idx (int): chunk end idx
Returns:
List[Node]: all nodes to be preposed
"""
# get all possible prepose nodes
maybe_prepose_nodes = []
for node, node_info in all_node_info.items():
@ -276,7 +310,7 @@ class TraceFlow(object):
for cur_prepose_node in tmp_cur_prepose_nodes:
if prepose_flag == False:
break
for cur_prepose_node_arg in cur_prepose_node.args:
for cur_prepose_node_arg in cur_prepose_node.all_input_nodes:
if type(cur_prepose_node_arg) != type(cur_prepose_node):
continue
# out of loop
@ -360,19 +394,28 @@ class TraceFlow(object):
return chunk_info
def _reassgin_reshape_size(self, chunk_info):
"""
Some shape args in reshape may have changed due to chunking;
reassign those changed shapes
"""
chunk_region = chunk_info["region"]
reshape_size = {}
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
if any(i in node.name for i in ["reshape", "view"]):
reshape_args = node.args[1:]
reshape_log = self.trace_indice.indice_view_list[node]
reshape_args = flat_list(node.args[1:])
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
reshape_size[node.name] = {}
new_shape = ""
for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
if reshape_arg_dim in reshape_log["dim_to"]:
continue
if reshape_arg_dim == chunk_dim:
reshape_size[node.name][reshape_arg.name] = ("min(chunk_size, %d - chunk_idx)" % chunk_shape)
new_shape += "min(chunk_size, %d - chunk_idx), " % chunk_shape
else:
if isinstance(reshape_arg, int):
new_shape += "%s, " % str(reshape_arg)
else:
new_shape += "%s, " % reshape_arg.name
new_shape = new_shape[:-2]
origin_shape = str(reshape_args)[1:-1]
reshape_size[node.name] = [origin_shape, new_shape]
chunk_info["reshape_size"] = reshape_size
return chunk_info
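An illustration of the [origin_shape, new_shape] pair this builds for one reshape node, with hypothetical symbolic args and the dim_to bookkeeping omitted; the pair is later consumed by _replace_reshape_size in the codegen (first file above):

chunk_shape = 64                        # size of the chunked output dim
reshape_args = ["size_1", "size_2"]     # stand-ins for node.args[1:]
chunk_dim = 0

new_shape = []
for dim, arg in enumerate(reshape_args):
    if dim == chunk_dim:
        new_shape.append("min(chunk_size, %d - chunk_idx)" % chunk_shape)
    else:
        new_shape.append(str(arg))
print([", ".join(reshape_args), ", ".join(new_shape)])
# ['size_1, size_2', 'min(chunk_size, 64 - chunk_idx), size_2']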

View File

@ -3,7 +3,7 @@ from typing import Dict, List, Tuple
from torch.fx.node import Node
from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list
from .utils import find_first_tensor_arg, find_idx_by_name, flat_list, get_node_shape
class TraceIndice(object):
@ -28,7 +28,7 @@ class TraceIndice(object):
node_list (List)
"""
def __init__(self, node_list: List) -> None:
def __init__(self, node_list: List[Node]) -> None:
self.node_list = node_list
self.indice_trace_list = self._init_indice_trace_list()
self.indice_view_list = {}
@ -198,7 +198,7 @@ class TraceIndice(object):
node_idx = find_idx_by_name(node.name, self.node_list)
return self.indice_trace_list[node_idx]["compute"]
def _assign_indice_as_input(self, node, node_idx, input_node=None):
def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None):
"""
Assign node's trace as its input node.
@ -216,7 +216,7 @@ class TraceIndice(object):
self._inherit_all_computation(input_node, node)
def _assign_all_indice(self, node, node_idx):
def _assign_all_indice(self, node: Node, node_idx: int):
"""
Add new indice for all node's dims.
@ -232,7 +232,7 @@ class TraceIndice(object):
new_trace.append(self._add_indice())
self.indice_trace_list[node_idx]["indice"] = new_trace
def _assign_transpose_indice(self, node, node_idx):
def _assign_transpose_indice(self, node: Node, node_idx: int):
"""
Assign indice for transpose op.
1. swap input's dim according to transpose args
@ -249,7 +249,7 @@ class TraceIndice(object):
self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
def _assign_permute_indice(self, node, node_idx):
def _assign_permute_indice(self, node: Node, node_idx: int):
"""
Assign indice for permute op.
1. swap input's dim according to permute args
@ -259,14 +259,14 @@ class TraceIndice(object):
node (node)
node_idx (int)
"""
permute_dim = unflat_list(node.args[1:])
permute_dim = flat_list(node.args[1:])
input_node = node.args[0]
self._assign_indice_as_input(node, node_idx, input_node)
for idx, d in enumerate(permute_dim):
self._inherit_indice(input_node, d, node, idx)
def _assign_linear_indice(self, node, node_idx):
def _assign_linear_indice(self, node: Node, node_idx: int):
"""
Assign indice for linear op.
1. copy trace from input node and change last indice according to weight
@ -287,7 +287,7 @@ class TraceIndice(object):
self._mark_computation(node, node_idx, [-1])
def _assign_matmul_indice(self, node, node_idx):
def _assign_matmul_indice(self, node: Node, node_idx: int):
"""
Assign indice for matmul op.
1. copy trace from matmul_left and change last indice according to matmul_right. (assert they have same length)
@ -393,7 +393,7 @@ class TraceIndice(object):
self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [node.kwargs["dim"]])
def _assign_unsqueeze_indice(self, node, node_idx):
def _assign_unsqueeze_indice(self, node: Node, node_idx: int):
"""
Assign indice for unsqueeze op.
1. assign new indice for unsqueeze dim
@ -404,9 +404,13 @@ class TraceIndice(object):
"""
self._del_dim(node_idx, -1)
self._assign_indice_as_input(node, node_idx)
self._add_dim(node_idx, node.args[1])
dim_idx = node.args[1]
# map a negative dim to its positive index in the output shape, e.g. unsqueeze(-1) -> len(get_node_shape(node)) - 1
if dim_idx < 0:
dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
self._add_dim(node_idx, dim_idx)
def _assign_dropout_indice(self, node, node_idx):
def _assign_dropout_indice(self, node: Node, node_idx: int):
"""
Assign indice for dropout op.
1. inherit the input's indice (dropout does not change dims)
@ -417,7 +421,7 @@ class TraceIndice(object):
"""
self._assign_indice_as_input(node, node_idx)
def _assign_ones_like_indice(self, node, node_idx):
def _assign_ones_like_indice(self, node: Node, node_idx: int):
"""
Assign indice for oneslike op.
1. assign new indice for all dim
@ -428,7 +432,47 @@ class TraceIndice(object):
"""
self._assign_all_indice(node, node_idx)
def _assign_view_reshape_indice(self, node, node_idx):
def _assign_getitem_indice(self, node: Node, node_idx: int):
"""
Assign indice for getitem.
getitem can act like a slice sometimes
Args:
node (node)
node_idx (int)
"""
node_args = flat_list(node.args[1:])
if not any(i == str(node_arg) for i in ["None", "Ellipsis"] for node_arg in node_args):
return
# node args should be like [Ellipsis, slice(start, step, end), None]
node_shape = get_node_shape(node)
origin_idx_count = 0
new_idx_count = 0
new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
for _ in range(new_dim_num):
self._del_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx)
for _, node_arg in enumerate(node_args):
node_arg_str = str(node_arg)
# Ellipsis means [..., ]
if "Ellipsis" == node_arg_str:
shape_gap = len(node_shape) - len(node_args) + 1
origin_idx_count += shape_gap
new_idx_count += shape_gap
# slice(None, None, None) means all indexes, doesn't support other slice
elif "slice(None, None, None)" == node_arg_str:
origin_idx_count += 1
new_idx_count += 1
# None means a new dim
elif "None" == node_arg_str:
self._add_dim(node_idx, new_idx_count)
new_idx_count += 1
else:
raise NotImplementedError()
def _assign_view_reshape_indice(self, node: Node, node_idx: int):
"""
Assign indice for view and reshape op.
1. get origin shape and target shape by meta info.
@ -447,7 +491,7 @@ class TraceIndice(object):
origin_node = node.args[0]
origin_shape = origin_node.meta["tensor_meta"].shape
target_shape = []
unflated_args = unflat_list(node.args)
unflated_args = flat_list(node.args)
for i in range(1, len(unflated_args)):
if isinstance(unflated_args[i], int):
target_shape.append(unflated_args[i])
@ -544,6 +588,8 @@ class TraceIndice(object):
self._assign_einsum_indice(node, idx)
elif "layer_norm" in node.name:
self._assign_layernorm_indice(node, idx)
elif "getitem" in node.name:
self._assign_getitem_indice(node, idx)
elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
continue
else:
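An illustration of the getitem patterns the new _assign_getitem_indice covers, using a toy tensor for intuition only:

import torch

x = torch.randn(2, 8, 8, 32)
y = x[..., None]            # flattened args [Ellipsis, None]: adds one trailing dim
z = x[:, None]              # flattened args [slice(None, None, None), None]
print(y.shape, z.shape)     # torch.Size([2, 8, 8, 32, 1]) torch.Size([2, 1, 8, 8, 32])
# getitem without None/Ellipsis is left to the non-compute path; a non-trivial slice
# mixed with None/Ellipsis raises NotImplementedError in the tracer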

View File

@ -3,14 +3,14 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
from torch.fx.node import Node
def unflat_list(inputs):
def flat_list(inputs):
"""
unflat a list by recursion
flatten a list recursively
"""
res = []
for i in inputs:
if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
res.extend(unflat_list(i))
res.extend(flat_list(i))
else:
res.append(i)
return res
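A short usage example for the renamed helper (the import path is assumed):

from colossalai.autochunk.utils import flat_list    # path assumed

print(flat_list([1, (2, 3), [[4], 5]]))    # [1, 2, 3, 4, 5]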
@ -27,8 +27,13 @@ def find_first_tensor_arg(node):
def is_non_compute_node(node):
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
i in node.name for i in ["getitem", "getattr"]):
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
return True
if "getitem" in node.name:
node_args = flat_list(node.args[1:])
for node_arg in node_args:
if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
return False
return True
return False
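The effect of the new rule in short, using a hypothetical helper over already-flattened getitem args:

def getitem_is_compute(flat_args):
    # a getitem whose args contain None or Ellipsis changes dims, so it is traced as compute
    return any(str(a) in ("None", "Ellipsis") for a in flat_args)

print(getitem_is_compute([Ellipsis, None]))    # True  -> kept for indice tracing
print(getitem_is_compute([0]))                 # False -> still skipped as non-compute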
@ -40,15 +45,15 @@ def get_node_shape(node):
def is_non_compute_node_except_placeholder(node):
if any(i in node.op for i in ["get_attr", "output"]) or any(i in node.name for i in ["getitem", "getattr"]):
return True
return False
if "placeholder" in node.op:
return False
return is_non_compute_node(node)
def is_non_compute_node_except_placeholder_output(node):
if any(i in node.op for i in ["get_attr"]) or any(i in node.name for i in ["getitem", "getattr"]):
return True
return False
if "output" in node.op:
return False
return is_non_compute_node_except_placeholder(node)
def find_idx_by_name(name, nodes_list):

View File

@ -27,18 +27,17 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
# for memory test
# model = model.cuda()
# torch.cuda.reset_peak_memory_stats()
# now_mem = torch.cuda.memory_allocated() / 1024**2
# with torch.no_grad():
# node1 = node.clone()
# pair1 = pair.clone()
# gm(node1, pair1)
# new_now_mem = torch.cuda.memory_allocated() / 1024**2
# node_mask1 = node_mask.clone()
# pair_mask1 = pair_mask.clone()
# gm(node1, pair1, node_mask1, pair_mask1)
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print(
# "autochunk now mem:%.2f max mem:%.2f"
# % (new_now_mem - now_mem, new_max_mem - now_mem)
# )
# print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
# test forward
model = model.cuda()
@ -113,7 +112,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
MetaTensor(node_mask, fake_device="cuda:0"),
MetaTensor(pair_mask, fake_device="cuda:0"),
)
# codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
@ -130,14 +129,14 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
"_mask_trans": True,
},
)
# graph.set_codegen(codegen)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph)
gm.recompile()
# assert we have inserted chunk
code = graph.python_code("self").src
assert "chunk_size" in code
# print(code)
assert "chunk_result = None; chunk_size = None;" in code
_test_fwd(model, gm, node, pair, node_mask, pair_mask)
gpc.destroy()
@ -147,7 +146,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_codegen(msa_len, pair_len, max_memory):
@ -161,4 +160,4 @@ def test_evoformer_codegen(msa_len, pair_len, max_memory):
if __name__ == "__main__":
_test_evoformer_codegen(0, 32, 64, 25)
_test_evoformer_codegen(0, 32, 64, 24)

View File

@ -13,7 +13,7 @@ except:
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx import ColoTracer, symbolic_trace
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
@ -26,21 +26,6 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
# for memory test
# torch.cuda.reset_peak_memory_stats()
# now_mem = torch.cuda.memory_allocated() / 1024**2
# with torch.no_grad():
# node1 = node.clone()
# pair1 = pair.clone()
# gm(node1, pair1)
# new_now_mem = torch.cuda.memory_allocated() / 1024**2
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print(
# "autochunk now mem:%.2f max mem:%.2f"
# % (new_now_mem - now_mem, new_max_mem - now_mem)
# )
# test forward
with torch.no_grad():
non_fx_out = model(node, pair)
fx_out = gm(node, pair)
@ -69,6 +54,16 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
# meta info prop
meta_graph = symbolic_trace(model,
meta_args={
"node": node.to(torch.device("meta")),
"pair": pair.to(torch.device("meta")),
}) # must use symbolic_trace
interp = MetaInfoProp(meta_graph)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
# trace the module and replace codegen
graph = ColoTracer().trace(
model,
@ -77,24 +72,14 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
"pair": pair.to(torch.device("meta")),
},
)
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
interp = MetaInfoProp(gm_prop)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
# now run it twice to get meta info in graph module, not necessary
gm = torch.fx.GraphModule(model, graph)
interp = MetaInfoProp(gm)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph)
gm.recompile()
# assert we have inserted chunk
code = graph.python_code("self").src
assert "chunk_size" in code
# print(code)
assert "chunk_result = None; chunk_size = None;" in code
_test_fwd(model, gm, node, pair)
gpc.destroy()

View File

@ -47,18 +47,18 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
str(target_regions),
)
for region in target_regions:
assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (
assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%s" % (
str(region),
msa_len,
pair_len,
max_memory,
str(max_memory),
)
for region in found_regions:
assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
str(region),
msa_len,
pair_len,
max_memory,
str(max_memory),
)