[autochunk] support autochunk on evoformer (#2497)

pull/2502/head
oahzxl 2023-01-19 11:41:00 +08:00 committed by GitHub
parent 304f1ba124
commit ecccc91f21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 200 additions and 188 deletions

View File

@ -123,12 +123,13 @@ def _replace_name(context: str, name_from: str, name_to: str) -> str:
""" """
replace node name replace node name
""" """
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")] patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"), (" ", ""), ("", " ")]
for p in patterns: for p in patterns:
source = p[0] + name_from + p[1] source = p[0] + name_from + p[1]
target = p[0] + name_to + p[1] target = p[0] + name_to + p[1]
if source in context: if source in context:
context = context.replace(source, target) context = context.replace(source, target)
break
return context return context
@ -138,8 +139,7 @@ def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict)
""" """
if node_name not in reshape_size_dict: if node_name not in reshape_size_dict:
return context return context
for size_name, size_value in reshape_size_dict[node_name].items(): context = context.replace(reshape_size_dict[node_name][0], reshape_size_dict[node_name][1])
context = context.replace(size_name, size_value)
return context return context

View File

@ -37,10 +37,10 @@ class EstimateMemory(object):
def _add_active_node(self, n, active_list): def _add_active_node(self, n, active_list):
new_active = self._get_output_node(n)[1] new_active = self._get_output_node(n)[1]
if n.op == "placeholder": if n.op == "placeholder" and get_node_shape(n) is not None:
new_active.append(n.name) new_active.append(n.name)
for i in new_active: for i in new_active:
if i not in active_list: if i not in active_list and get_node_shape(n) is not None:
active_list.append(i) active_list.append(i)
def _get_delete_node(self, user, user_to_last_uses, to_keep=None): def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
@ -77,15 +77,11 @@ class EstimateMemory(object):
if i in active_list: if i in active_list:
active_list.remove(i) active_list.remove(i)
def _get_chunk_inputs_size( def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
):
nodes_to_delete = [] nodes_to_delete = []
for chunk_input in chunk_inputs + chunk_inputs_non_chunk: for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
chunk_input_users = chunk_input.users.keys() chunk_input_users = chunk_input.users.keys()
chunk_input_users_idx = [ chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
find_idx_by_name(i.name, node_list) for i in chunk_input_users
]
if all(i <= chunk_end_idx for i in chunk_input_users_idx): if all(i <= chunk_end_idx for i in chunk_input_users_idx):
if chunk_input not in nodes_to_delete: if chunk_input not in nodes_to_delete:
nodes_to_delete.append(chunk_input) nodes_to_delete.append(chunk_input)
@ -112,9 +108,7 @@ class EstimateMemory(object):
not_contiguous_ops = ["permute"] not_contiguous_ops = ["permute"]
inherit_contiguous_ops = ["transpose", "view"] inherit_contiguous_ops = ["transpose", "view"]
if node.op == "call_function" and any( if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]):
n in node.name for n in ["matmul", "reshape"]
):
for n in node.args: for n in node.args:
if n in not_contiguous_list: if n in not_contiguous_list:
# matmul won't change origin tensor, but create a tmp copy # matmul won't change origin tensor, but create a tmp copy
@ -125,9 +119,7 @@ class EstimateMemory(object):
# module will just make origin tensor to contiguous # module will just make origin tensor to contiguous
if delete: if delete:
not_contiguous_list.remove(n) not_contiguous_list.remove(n)
elif node.op == "call_method" and any( elif node.op == "call_method" and any(i in node.name for i in not_contiguous_ops):
i in node.name for i in not_contiguous_ops
):
if node not in not_contiguous_list: if node not in not_contiguous_list:
not_contiguous_list.append(node) not_contiguous_list.append(node)
return mem return mem
@ -142,9 +134,7 @@ class EstimateMemory(object):
else: else:
return float(chunk_size) / node_shape[chunk_dim] return float(chunk_size) / node_shape[chunk_dim]
def _get_chunk_delete_node_size( def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names):
self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
):
# if any(j in user.name for j in ['transpose', 'permute', 'view']): # if any(j in user.name for j in ['transpose', 'permute', 'view']):
# return 0 # return 0
if user.op in ("placeholder", "output"): if user.op in ("placeholder", "output"):
@ -221,23 +211,18 @@ class EstimateMemory(object):
chunk_ends = [i[1] for i in chunk_regions] chunk_ends = [i[1] for i in chunk_regions]
chunk_inputs = [i["inputs"] for i in chunk_infos] chunk_inputs = [i["inputs"] for i in chunk_infos]
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ chunk_inputs_names = [j.name for i in chunk_inputs for j in i
j.name for i in chunk_inputs_non_chunk for j in i ] + [j.name for i in chunk_inputs_non_chunk for j in i]
]
chunk_outputs = [i["outputs"][0] for i in chunk_infos] chunk_outputs = [i["outputs"][0] for i in chunk_infos]
chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos] chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
chunk_sizes = [ chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos
]
for idx, node in enumerate(node_list): for idx, node in enumerate(node_list):
# if node in chunk start nodes, change chunk ratio and add chunk_tensor # if node in chunk start nodes, change chunk ratio and add chunk_tensor
if use_chunk and idx in chunk_starts: if use_chunk and idx in chunk_starts:
chunk_within = True chunk_within = True
chunk_region_idx = chunk_starts.index(idx) chunk_region_idx = chunk_starts.index(idx)
act_memory += self._get_output_node_size( act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)
chunk_outputs[chunk_region_idx]
) / (1024**2)
# determine chunk ratio for current node # determine chunk ratio for current node
if chunk_within: if chunk_within:
@ -262,22 +247,13 @@ class EstimateMemory(object):
else: else:
# forward memory # forward memory
# TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
act_memory += ( act_memory += (self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2))
self._get_contiguous_memory(node, not_contiguous_list) act_memory += (self._get_output_node_size(node) * chunk_ratio / (1024**2))
* chunk_ratio
/ (1024**2)
)
act_memory += (
self._get_output_node_size(node) * chunk_ratio / (1024**2)
)
# record max act memory # record max act memory
act_memory_peak_log.append(act_memory) act_memory_peak_log.append(act_memory)
# delete useless memory # delete useless memory
act_memory -= ( act_memory -= (self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio /
self._get_contiguous_memory(node, not_contiguous_list, delete=True) (1024**2))
* chunk_ratio
/ (1024**2)
)
# delete unused vars not in chunk_input_list # delete unused vars not in chunk_input_list
# we can't delete input nodes until chunk ends # we can't delete input nodes until chunk ends
if chunk_within: if chunk_within:
@ -288,9 +264,8 @@ class EstimateMemory(object):
chunk_inputs_names, chunk_inputs_names,
) / (1024**2) ) / (1024**2)
else: else:
act_memory -= self._get_delete_node_size( act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var,
node, user_to_last_uses_no_free_var, chunk_inputs_names chunk_inputs_names) / (1024**2)
) / (1024**2)
# log active node, only effective without chunk # log active node, only effective without chunk
self._add_active_node(node, active_node_list) self._add_active_node(node, active_node_list)
@ -298,9 +273,7 @@ class EstimateMemory(object):
# if node in chunk end nodes, restore chunk settings # if node in chunk end nodes, restore chunk settings
if use_chunk and idx in chunk_ends: if use_chunk and idx in chunk_ends:
act_memory -= ( act_memory -= (self._get_output_node_size(node) * chunk_ratio / (1024**2))
self._get_output_node_size(node) * chunk_ratio / (1024**2)
)
act_memory -= self._get_chunk_inputs_size( act_memory -= self._get_chunk_inputs_size(
chunk_inputs[chunk_region_idx], chunk_inputs[chunk_region_idx],
chunk_inputs_non_chunk[chunk_region_idx], chunk_inputs_non_chunk[chunk_region_idx],

View File

@ -8,11 +8,7 @@ from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk from .select_chunk import SelectChunk
from .trace_flow import TraceFlow from .trace_flow import TraceFlow
from .trace_indice import TraceIndice from .trace_indice import TraceIndice
from .utils import ( from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
get_node_shape,
is_non_compute_node,
is_non_compute_node_except_placeholder,
)
class SearchChunk(object): class SearchChunk(object):
@ -73,13 +69,11 @@ class SearchChunk(object):
""" """
free_var_idx = [] free_var_idx = []
for idx, n in enumerate(self.trace_indice.node_list): for idx, n in enumerate(self.trace_indice.node_list):
if n.op == "placeholder": if n.op == "placeholder" and get_node_shape(n) is not None:
free_var_idx.append(idx) free_var_idx.append(idx)
return free_var_idx return free_var_idx
def _search_max_chunk_region( def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
self, active_node: List, peak_node: Node, chunk_regions: List
) -> Tuple:
""" """
Search max chunk region according to peak memory node Search max chunk region according to peak memory node
@ -124,15 +118,9 @@ class SearchChunk(object):
region = i["region"] region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]: if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None return None
elif ( elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
region[0] <= chunk_region_start <= region[1]
and chunk_region_end > region[1]
):
chunk_region_start = region[1] + 1 chunk_region_start = region[1] + 1
elif ( elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
region[0] <= chunk_region_end <= region[1]
and chunk_region_start < region[0]
):
chunk_region_end = region[0] - 1 chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end return chunk_region_start, chunk_region_end
@ -164,25 +152,16 @@ class SearchChunk(object):
for start_node, start_trace in start_traces.items(): for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]): for start_dim, _ in enumerate(start_trace["indice"]):
# dim size cannot be 1 # dim size cannot be 1
if ( if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
get_node_shape(end_node)[end_dim] == 1
or get_node_shape(start_node)[start_dim] == 1
):
continue continue
# check index source align # check index source align
if not self.trace_flow.check_index_source( if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
start_dim, start_node, start_idx, end_dim, end_node
):
continue continue
# check index copmute # check index copmute
if not self.trace_flow.check_index_compute( if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
start_idx, end_dim, end_node, end_idx
):
continue continue
# flow search # flow search
chunk_info = self.trace_flow.flow_search( chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
start_idx, start_dim, end_idx, end_dim
)
if chunk_info is None: if chunk_info is None:
continue continue
# check index copmute # check index copmute
@ -191,9 +170,7 @@ class SearchChunk(object):
chunk_infos.append(chunk_info) chunk_infos.append(chunk_info)
return chunk_infos return chunk_infos
def _search_possible_chunk_regions( def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
self, max_chunk_region: Tuple, peak_node: Node
) -> List:
""" """
Search every possible region within the max chunk region. Search every possible region within the max chunk region.
@ -210,24 +187,19 @@ class SearchChunk(object):
for _, n in enumerate(self.trace_indice.node_list): for _, n in enumerate(self.trace_indice.node_list):
cur_trace = {} cur_trace = {}
for arg in n.args: for arg in n.args:
if type(arg) == type(n) and not is_non_compute_node_except_placeholder( if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
arg
):
cur_trace[arg] = self.trace_indice._find_trace_from_node(arg) cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
input_trace.append(cur_trace) input_trace.append(cur_trace)
for start_idx in range(max_chunk_region[0], peak_node + 1): for start_idx in range(max_chunk_region[0], peak_node + 1):
for end_idx in range(peak_node, max_chunk_region[1] + 1): for end_idx in range(peak_node, max_chunk_region[1] + 1):
# skip non compute nodes # skip non compute nodes
if is_non_compute_node( if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
self.trace_indice.node_list[start_idx] self.trace_indice.node_list[end_idx]):
) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
continue continue
# select free dim # select free dim
chunk_info = self._find_chunk_info( chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
input_trace, output_trace, start_idx, end_idx
)
if len(chunk_info) > 0: if len(chunk_info) > 0:
possible_chunk_region.extend(chunk_info) possible_chunk_region.extend(chunk_info)
return possible_chunk_region return possible_chunk_region
@ -256,17 +228,12 @@ class SearchChunk(object):
best_chunk_region (Dict) best_chunk_region (Dict)
""" """
peak_node = self._find_peak_node(mem_peak) peak_node = self._find_peak_node(mem_peak)
max_chunk_region = self._search_max_chunk_region( max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
active_node, peak_node, chunk_infos
)
if max_chunk_region == None: if max_chunk_region == None:
return None return None
possible_chunk_regions = self._search_possible_chunk_regions( possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
max_chunk_region, peak_node best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node,
) max_chunk_region, mem_peak)
best_chunk_region = self.select_chunk._select_best_chunk_region(
possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
)
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region) best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
return best_chunk_region return best_chunk_region
@ -291,9 +258,7 @@ class SearchChunk(object):
init_mem_peak, init_mem_peak,
_, _,
active_node, active_node,
) = self.estimate_memory.estimate_chunk_inference_mem( ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
self.trace_indice.node_list
)
mem_peak = init_mem_peak mem_peak = init_mem_peak
while True: while True:
@ -306,14 +271,10 @@ class SearchChunk(object):
mem_peak, mem_peak,
_, _,
active_node, active_node,
) = self.estimate_memory.estimate_chunk_inference_mem( ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
self.trace_indice.node_list, chunk_infos
)
if self._stop_search(init_mem_peak, mem_peak): if self._stop_search(init_mem_peak, mem_peak):
break break
if self.print_mem: if self.print_mem:
self.print_mem = False self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem( self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
self.trace_indice.node_list, chunk_infos, print_mem=True
)
return chunk_infos return chunk_infos

View File

@ -1,8 +1,13 @@
from typing import Dict, List, Tuple
from torch.fx.node import Node
from .trace_indice import TraceIndice from .trace_indice import TraceIndice
from .utils import ( from .utils import (
find_chunk_all_input_nodes, find_chunk_all_input_nodes,
find_chunk_compute_input_and_output_nodes, find_chunk_compute_input_and_output_nodes,
find_idx_by_name, find_idx_by_name,
flat_list,
get_node_shape, get_node_shape,
is_non_compute_node, is_non_compute_node,
is_non_compute_node_except_placeholder, is_non_compute_node_except_placeholder,
@ -171,7 +176,7 @@ class TraceFlow(object):
# get cur node info # get cur node info
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
if cur_node_chunk_dim: if cur_node_chunk_dim is not None:
cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node) cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node) cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
else: else:
@ -223,15 +228,32 @@ class TraceFlow(object):
cur_node_list = next_node_list cur_node_list = next_node_list
return all_node_info return all_node_info
def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info): def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, all_node_info: Dict) -> Tuple:
"""
Get chunk dim for every input node for their every entry, remove unchunked nodes
Args:
inputs (List[Node]): input nodes
all_node_info (Dict): describe all node's chunk dim and fix dim
start_idx (int): chunk start idx
end_idx (int): chunk end idx
Returns:
inputs (List(Node)): new inputs
inputs_dim (List): chunk dim for inputs
"""
inputs_dim = [] inputs_dim = []
remove_inputs = [] remove_inputs = []
for input_node in inputs: for input_node in inputs:
input_dict = {} input_dict = {}
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list) input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
for user in input_node.users.keys(): for user in input_node.users.keys():
# skip non compute
if is_non_compute_node(user): if is_non_compute_node(user):
continue continue
# untraced node, mostly non compute
if user not in all_node_info:
continue
user_idx = find_idx_by_name(user.name, self.trace_indice.node_list) user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
if start_idx <= user_idx <= end_idx: if start_idx <= user_idx <= end_idx:
chunk_dim = all_node_info[user]["chunk_dim"] chunk_dim = all_node_info[user]["chunk_dim"]
@ -245,12 +267,24 @@ class TraceFlow(object):
remove_inputs.append(input_node) remove_inputs.append(input_node)
else: else:
inputs_dim.append(input_dict) inputs_dim.append(input_dict)
# remove unchunked inputs
for i in remove_inputs: for i in remove_inputs:
if i in inputs: if i in inputs:
inputs.remove(i) inputs.remove(i)
return inputs, inputs_dim return inputs, inputs_dim
def _get_prepose_nodes(self, all_node_info, start_idx, end_idx): def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
"""
get all useless nodes in chunk region and prepose them
Args:
all_node_info (Dict): describe all node's chunk dim and fix dim
start_idx (int): chunk start idx
end_idx (int): chunk end idx
Returns:
List[Node]: all nodes to be preposed
"""
# get all possible prepose nodes # get all possible prepose nodes
maybe_prepose_nodes = [] maybe_prepose_nodes = []
for node, node_info in all_node_info.items(): for node, node_info in all_node_info.items():
@ -276,7 +310,7 @@ class TraceFlow(object):
for cur_prepose_node in tmp_cur_prepose_nodes: for cur_prepose_node in tmp_cur_prepose_nodes:
if prepose_flag == False: if prepose_flag == False:
break break
for cur_prepose_node_arg in cur_prepose_node.args: for cur_prepose_node_arg in cur_prepose_node.all_input_nodes:
if type(cur_prepose_node_arg) != type(cur_prepose_node): if type(cur_prepose_node_arg) != type(cur_prepose_node):
continue continue
# out of loop # out of loop
@ -360,19 +394,28 @@ class TraceFlow(object):
return chunk_info return chunk_info
def _reassgin_reshape_size(self, chunk_info): def _reassgin_reshape_size(self, chunk_info):
"""
Some shape args in reshape may have changed due to chunk
reassgin those changed shape
"""
chunk_region = chunk_info["region"] chunk_region = chunk_info["region"]
reshape_size = {} reshape_size = {}
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]] chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]: for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
if any(i in node.name for i in ["reshape", "view"]): if any(i in node.name for i in ["reshape", "view"]):
reshape_args = node.args[1:] reshape_args = flat_list(node.args[1:])
reshape_log = self.trace_indice.indice_view_list[node]
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
reshape_size[node.name] = {} new_shape = ""
for reshape_arg_dim, reshape_arg in enumerate(reshape_args): for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
if reshape_arg_dim in reshape_log["dim_to"]:
continue
if reshape_arg_dim == chunk_dim: if reshape_arg_dim == chunk_dim:
reshape_size[node.name][reshape_arg.name] = ("min(chunk_size, %d - chunk_idx)" % chunk_shape) new_shape += "min(chunk_size, %d - chunk_idx), " % chunk_shape
else:
if isinstance(reshape_arg, int):
new_shape += "%s, " % str(reshape_arg)
else:
new_shape += "%s, " % reshape_arg.name
new_shape = new_shape[:-2]
origin_shape = str(reshape_args)[1:-1]
reshape_size[node.name] = [origin_shape, new_shape]
chunk_info["reshape_size"] = reshape_size chunk_info["reshape_size"] = reshape_size
return chunk_info return chunk_info

View File

@ -3,7 +3,7 @@ from typing import Dict, List, Tuple
from torch.fx.node import Node from torch.fx.node import Node
from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list from .utils import find_first_tensor_arg, find_idx_by_name, flat_list, get_node_shape
class TraceIndice(object): class TraceIndice(object):
@ -28,7 +28,7 @@ class TraceIndice(object):
node_list (List) node_list (List)
""" """
def __init__(self, node_list: List) -> None: def __init__(self, node_list: List[Node]) -> None:
self.node_list = node_list self.node_list = node_list
self.indice_trace_list = self._init_indice_trace_list() self.indice_trace_list = self._init_indice_trace_list()
self.indice_view_list = {} self.indice_view_list = {}
@ -198,7 +198,7 @@ class TraceIndice(object):
node_idx = find_idx_by_name(node.name, self.node_list) node_idx = find_idx_by_name(node.name, self.node_list)
return self.indice_trace_list[node_idx]["compute"] return self.indice_trace_list[node_idx]["compute"]
def _assign_indice_as_input(self, node, node_idx, input_node=None): def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None):
""" """
Assign node's trace as its input node. Assign node's trace as its input node.
@ -216,7 +216,7 @@ class TraceIndice(object):
self._inherit_all_computation(input_node, node) self._inherit_all_computation(input_node, node)
def _assign_all_indice(self, node, node_idx): def _assign_all_indice(self, node: Node, node_idx: int):
""" """
Add new indice for all node's dims. Add new indice for all node's dims.
@ -232,7 +232,7 @@ class TraceIndice(object):
new_trace.append(self._add_indice()) new_trace.append(self._add_indice())
self.indice_trace_list[node_idx]["indice"] = new_trace self.indice_trace_list[node_idx]["indice"] = new_trace
def _assign_transpose_indice(self, node, node_idx): def _assign_transpose_indice(self, node: Node, node_idx: int):
""" """
Assign indice for transpose op. Assign indice for transpose op.
1. swap input's dim according to transpose args 1. swap input's dim according to transpose args
@ -249,7 +249,7 @@ class TraceIndice(object):
self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0]) self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1]) self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
def _assign_permute_indice(self, node, node_idx): def _assign_permute_indice(self, node: Node, node_idx: int):
""" """
Assign indice for permute op. Assign indice for permute op.
1. swap input's dim according to permute args 1. swap input's dim according to permute args
@ -259,14 +259,14 @@ class TraceIndice(object):
node (node) node (node)
node_idx (int) node_idx (int)
""" """
permute_dim = unflat_list(node.args[1:]) permute_dim = flat_list(node.args[1:])
input_node = node.args[0] input_node = node.args[0]
self._assign_indice_as_input(node, node_idx, input_node) self._assign_indice_as_input(node, node_idx, input_node)
for idx, d in enumerate(permute_dim): for idx, d in enumerate(permute_dim):
self._inherit_indice(input_node, d, node, idx) self._inherit_indice(input_node, d, node, idx)
def _assign_linear_indice(self, node, node_idx): def _assign_linear_indice(self, node: Node, node_idx: int):
""" """
Assign indice for linear op. Assign indice for linear op.
1. copy trace from input node and change last indice accroding to weight 1. copy trace from input node and change last indice accroding to weight
@ -287,7 +287,7 @@ class TraceIndice(object):
self._mark_computation(node, node_idx, [-1]) self._mark_computation(node, node_idx, [-1])
def _assign_matmul_indice(self, node, node_idx): def _assign_matmul_indice(self, node: Node, node_idx: int):
""" """
Assign indice for matmul op. Assign indice for matmul op.
1. copy trace from matmul_left and change last indice accroding to matmul_right. (assert they have same length) 1. copy trace from matmul_left and change last indice accroding to matmul_right. (assert they have same length)
@ -393,7 +393,7 @@ class TraceIndice(object):
self._assign_indice_as_input(node, idx) self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [node.kwargs["dim"]]) self._mark_computation(node, idx, [node.kwargs["dim"]])
def _assign_unsqueeze_indice(self, node, node_idx): def _assign_unsqueeze_indice(self, node: Node, node_idx: int):
""" """
Assign indice for unsqueeze op. Assign indice for unsqueeze op.
1. assign new indice for unsqueeze dim 1. assign new indice for unsqueeze dim
@ -404,9 +404,13 @@ class TraceIndice(object):
""" """
self._del_dim(node_idx, -1) self._del_dim(node_idx, -1)
self._assign_indice_as_input(node, node_idx) self._assign_indice_as_input(node, node_idx)
self._add_dim(node_idx, node.args[1]) dim_idx = node.args[1]
# unsqueeze(-1) = unsqueeze(shape_num + 1)
if dim_idx < 0:
dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
self._add_dim(node_idx, dim_idx)
def _assign_dropout_indice(self, node, node_idx): def _assign_dropout_indice(self, node: Node, node_idx: int):
""" """
Assign indice for unsqueeze op. Assign indice for unsqueeze op.
1. assign new indice for unsqueeze dim 1. assign new indice for unsqueeze dim
@ -417,7 +421,7 @@ class TraceIndice(object):
""" """
self._assign_indice_as_input(node, node_idx) self._assign_indice_as_input(node, node_idx)
def _assign_ones_like_indice(self, node, node_idx): def _assign_ones_like_indice(self, node: Node, node_idx: int):
""" """
Assign indice for oneslike op. Assign indice for oneslike op.
1. assign new indice for all dim 1. assign new indice for all dim
@ -428,7 +432,47 @@ class TraceIndice(object):
""" """
self._assign_all_indice(node, node_idx) self._assign_all_indice(node, node_idx)
def _assign_view_reshape_indice(self, node, node_idx): def _assign_getitem_indice(self, node: Node, node_idx: int):
"""
Assign indice for getitem.
getitem can act like slice sometimes
Args:
node (node)
node_idx (int)
"""
node_args = flat_list(node.args[1:])
if not any(i == str(node_arg) for i in ["None", "Ellipsis"] for node_arg in node_args):
return
# node args should be like [Ellipsis, slice(start, step, end), None]
node_shape = get_node_shape(node)
origin_idx_count = 0
new_idx_count = 0
new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
for _ in range(new_dim_num):
self._del_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx)
for _, node_arg in enumerate(node_args):
node_arg_str = str(node_arg)
# Ellipsis means [..., ]
if "Ellipsis" == node_arg_str:
shape_gap = len(node_shape) - len(node_args) + 1
origin_idx_count += shape_gap
new_idx_count += shape_gap
# slice(None, None, None) means all indexes, doesn't support other slice
elif "slice(None, None, None)" == node_arg_str:
origin_idx_count += 1
new_idx_count += 1
# None means a new dim
elif "None" == node_arg_str:
self._add_dim(node_idx, new_idx_count)
new_idx_count += 1
else:
raise NotImplementedError()
def _assign_view_reshape_indice(self, node: Node, node_idx: int):
""" """
Assign indice for view and reshape op. Assign indice for view and reshape op.
1. get origin shape and target shape by meta info. 1. get origin shape and target shape by meta info.
@ -447,7 +491,7 @@ class TraceIndice(object):
origin_node = node.args[0] origin_node = node.args[0]
origin_shape = origin_node.meta["tensor_meta"].shape origin_shape = origin_node.meta["tensor_meta"].shape
target_shape = [] target_shape = []
unflated_args = unflat_list(node.args) unflated_args = flat_list(node.args)
for i in range(1, len(unflated_args)): for i in range(1, len(unflated_args)):
if isinstance(unflated_args[i], int): if isinstance(unflated_args[i], int):
target_shape.append(unflated_args[i]) target_shape.append(unflated_args[i])
@ -544,6 +588,8 @@ class TraceIndice(object):
self._assign_einsum_indice(node, idx) self._assign_einsum_indice(node, idx)
elif "layer_norm" in node.name: elif "layer_norm" in node.name:
self._assign_layernorm_indice(node, idx) self._assign_layernorm_indice(node, idx)
elif "getitem" in node.name:
self._assign_getitem_indice(node, idx)
elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]): elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
continue continue
else: else:

View File

@ -3,14 +3,14 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
from torch.fx.node import Node from torch.fx.node import Node
def unflat_list(inputs): def flat_list(inputs):
""" """
unflat a list by recursion flat a list by recursion
""" """
res = [] res = []
for i in inputs: for i in inputs:
if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple): if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
res.extend(unflat_list(i)) res.extend(flat_list(i))
else: else:
res.append(i) res.append(i)
return res return res
@ -27,8 +27,13 @@ def find_first_tensor_arg(node):
def is_non_compute_node(node): def is_non_compute_node(node):
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any( if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
i in node.name for i in ["getitem", "getattr"]): return True
if "getitem" in node.name:
node_args = flat_list(node.args[1:])
for node_arg in node_args:
if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
return False
return True return True
return False return False
@ -40,15 +45,15 @@ def get_node_shape(node):
def is_non_compute_node_except_placeholder(node): def is_non_compute_node_except_placeholder(node):
if any(i in node.op for i in ["get_attr", "output"]) or any(i in node.name for i in ["getitem", "getattr"]): if "placeholder" in node.op:
return True
return False return False
return is_non_compute_node(node)
def is_non_compute_node_except_placeholder_output(node): def is_non_compute_node_except_placeholder_output(node):
if any(i in node.op for i in ["get_attr"]) or any(i in node.name for i in ["getitem", "getattr"]): if "output" in node.op:
return True
return False return False
return is_non_compute_node_except_placeholder(node)
def find_idx_by_name(name, nodes_list): def find_idx_by_name(name, nodes_list):

View File

@ -27,18 +27,17 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask): def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
# for memory test # for memory test
# model = model.cuda()
# torch.cuda.reset_peak_memory_stats() # torch.cuda.reset_peak_memory_stats()
# now_mem = torch.cuda.memory_allocated() / 1024**2 # now_mem = torch.cuda.memory_allocated() / 1024**2
# with torch.no_grad(): # with torch.no_grad():
# node1 = node.clone() # node1 = node.clone()
# pair1 = pair.clone() # pair1 = pair.clone()
# gm(node1, pair1) # node_mask1 = node_mask.clone()
# new_now_mem = torch.cuda.memory_allocated() / 1024**2 # pair_mask1 = pair_mask.clone()
# gm(node1, pair1, node_mask1, pair_mask1)
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print( # print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
# "autochunk now mem:%.2f max mem:%.2f"
# % (new_now_mem - now_mem, new_max_mem - now_mem)
# )
# test forward # test forward
model = model.cuda() model = model.cuda()
@ -113,7 +112,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
MetaTensor(node_mask, fake_device="cuda:0"), MetaTensor(node_mask, fake_device="cuda:0"),
MetaTensor(pair_mask, fake_device="cuda:0"), MetaTensor(pair_mask, fake_device="cuda:0"),
) )
# codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory) codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)
# trace and recompile # trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
@ -130,14 +129,14 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
"_mask_trans": True, "_mask_trans": True,
}, },
) )
# graph.set_codegen(codegen) graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
gm.recompile() gm.recompile()
# assert we have inserted chunk # assert we have inserted chunk
code = graph.python_code("self").src code = graph.python_code("self").src
assert "chunk_size" in code
# print(code) # print(code)
assert "chunk_result = None; chunk_size = None;" in code
_test_fwd(model, gm, node, pair, node_mask, pair_mask) _test_fwd(model, gm, node, pair, node_mask, pair_mask)
gpc.destroy() gpc.destroy()
@ -147,7 +146,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO), not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
reason="torch version is lower than 1.12.0", reason="torch version is lower than 1.12.0",
) )
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) @pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32]) @pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64]) @pytest.mark.parametrize("pair_len", [64])
def test_evoformer_codegen(msa_len, pair_len, max_memory): def test_evoformer_codegen(msa_len, pair_len, max_memory):
@ -161,4 +160,4 @@ def test_evoformer_codegen(msa_len, pair_len, max_memory):
if __name__ == "__main__": if __name__ == "__main__":
_test_evoformer_codegen(0, 32, 64, 25) _test_evoformer_codegen(0, 32, 64, 24)

View File

@ -13,7 +13,7 @@ except:
import colossalai import colossalai
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer, symbolic_trace
from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
@ -26,21 +26,6 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
# for memory test
# torch.cuda.reset_peak_memory_stats()
# now_mem = torch.cuda.memory_allocated() / 1024**2
# with torch.no_grad():
# node1 = node.clone()
# pair1 = pair.clone()
# gm(node1, pair1)
# new_now_mem = torch.cuda.memory_allocated() / 1024**2
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print(
# "autochunk now mem:%.2f max mem:%.2f"
# % (new_now_mem - now_mem, new_max_mem - now_mem)
# )
# test forward
with torch.no_grad(): with torch.no_grad():
non_fx_out = model(node, pair) non_fx_out = model(node, pair)
fx_out = gm(node, pair) fx_out = gm(node, pair)
@ -69,6 +54,16 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
node = torch.randn(1, msa_len, pair_len, 256).cuda() node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda()
# meta info prop
meta_graph = symbolic_trace(model,
meta_args={
"node": node.to(torch.device("meta")),
"pair": pair.to(torch.device("meta")),
}) # must use symbolic_trace
interp = MetaInfoProp(meta_graph)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
# trace the module and replace codegen # trace the module and replace codegen
graph = ColoTracer().trace( graph = ColoTracer().trace(
model, model,
@ -77,24 +72,14 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
"pair": pair.to(torch.device("meta")), "pair": pair.to(torch.device("meta")),
}, },
) )
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
interp = MetaInfoProp(gm_prop)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
# now run it twice to get meta info in graph module, not necessary
gm = torch.fx.GraphModule(model, graph)
interp = MetaInfoProp(gm)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
graph.set_codegen(codegen) graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
gm.recompile() gm.recompile()
# assert we have inserted chunk # assert we have inserted chunk
code = graph.python_code("self").src code = graph.python_code("self").src
assert "chunk_size" in code
# print(code) # print(code)
assert "chunk_result = None; chunk_size = None;" in code
_test_fwd(model, gm, node, pair) _test_fwd(model, gm, node, pair)
gpc.destroy() gpc.destroy()

View File

@ -47,18 +47,18 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
str(target_regions), str(target_regions),
) )
for region in target_regions: for region in target_regions:
assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % ( assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%s" % (
str(region), str(region),
msa_len, msa_len,
pair_len, pair_len,
max_memory, str(max_memory),
) )
for region in found_regions: for region in found_regions:
assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % ( assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
str(region), str(region),
msa_len, msa_len,
pair_len, pair_len,
max_memory, str(max_memory),
) )