mirror of https://github.com/hpcaitech/ColossalAI
[autochunk] support autochunk on evoformer (#2497)
parent 304f1ba124
commit ecccc91f21
@@ -123,12 +123,13 @@ def _replace_name(context: str, name_from: str, name_to: str) -> str:
    """
    replace node name
    """
    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"), (" ", ""), ("", " ")]
    for p in patterns:
        source = p[0] + name_from + p[1]
        target = p[0] + name_to + p[1]
        if source in context:
            context = context.replace(source, target)
            break
    return context
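Note (reviewer sketch, not part of this commit): the two new patterns (" ", "") and ("", " ") let _replace_name also match a node name at the start or end of a generated line. A minimal illustration, assuming the function above is in scope; the context string and names are hypothetical:

    # hypothetical line of generated forward code
    context = "add_1 = torch.add(x, y)\n"
    # the new ("", " ") pattern matches "add_1 " at the start of the line
    context = _replace_name(context, "add_1", "chunk_add_1")
    # -> "chunk_add_1 = torch.add(x, y)\n"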
@@ -138,8 +139,7 @@ def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict)
    """
    if node_name not in reshape_size_dict:
        return context
    for size_name, size_value in reshape_size_dict[node_name].items():
        context = context.replace(size_name, size_value)
    context = context.replace(reshape_size_dict[node_name][0], reshape_size_dict[node_name][1])
    return context
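Note (reviewer sketch, not part of this commit): each reshape_size_dict entry is now a two-element list [origin_shape_str, new_shape_str] built by TraceFlow._reassgin_reshape_size (see the trace_flow.py hunk further down), instead of a name-to-name dict. With purely illustrative names and sizes:

    # hypothetical entry: the whole original shape text and its chunked replacement
    reshape_size_dict = {"reshape_1": ["size_1, size_2, 128",
                                       "min(chunk_size, 64 - chunk_idx), size_2, 128"]}
    context = "reshape_1 = x.reshape(size_1, size_2, 128)\n"
    context = context.replace(reshape_size_dict["reshape_1"][0], reshape_size_dict["reshape_1"][1])
    # -> "reshape_1 = x.reshape(min(chunk_size, 64 - chunk_idx), size_2, 128)\n"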
@@ -37,10 +37,10 @@ class EstimateMemory(object):

    def _add_active_node(self, n, active_list):
        new_active = self._get_output_node(n)[1]
        if n.op == "placeholder":
        if n.op == "placeholder" and get_node_shape(n) is not None:
            new_active.append(n.name)
        for i in new_active:
            if i not in active_list:
            if i not in active_list and get_node_shape(n) is not None:
                active_list.append(i)

    def _get_delete_node(self, user, user_to_last_uses, to_keep=None):

@@ -77,15 +77,11 @@ class EstimateMemory(object):
            if i in active_list:
                active_list.remove(i)

    def _get_chunk_inputs_size(
        self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
    ):
    def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
        nodes_to_delete = []
        for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
            chunk_input_users = chunk_input.users.keys()
            chunk_input_users_idx = [
                find_idx_by_name(i.name, node_list) for i in chunk_input_users
            ]
            chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
            if all(i <= chunk_end_idx for i in chunk_input_users_idx):
                if chunk_input not in nodes_to_delete:
                    nodes_to_delete.append(chunk_input)

@@ -112,9 +108,7 @@ class EstimateMemory(object):
        not_contiguous_ops = ["permute"]
        inherit_contiguous_ops = ["transpose", "view"]

        if node.op == "call_function" and any(
            n in node.name for n in ["matmul", "reshape"]
        ):
        if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]):
            for n in node.args:
                if n in not_contiguous_list:
                    # matmul won't change origin tensor, but create a tmp copy

@@ -125,9 +119,7 @@ class EstimateMemory(object):
                    # module will just make origin tensor to contiguous
                    if delete:
                        not_contiguous_list.remove(n)
        elif node.op == "call_method" and any(
            i in node.name for i in not_contiguous_ops
        ):
        elif node.op == "call_method" and any(i in node.name for i in not_contiguous_ops):
            if node not in not_contiguous_list:
                not_contiguous_list.append(node)
        return mem

@@ -142,9 +134,7 @@ class EstimateMemory(object):
        else:
            return float(chunk_size) / node_shape[chunk_dim]

    def _get_chunk_delete_node_size(
        self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
    ):
    def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names):
        # if any(j in user.name for j in ['transpose', 'permute', 'view']):
        #     return 0
        if user.op in ("placeholder", "output"):
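Note (reviewer sketch, not part of this commit): the chunk ratio returned just above is float(chunk_size) / node_shape[chunk_dim], and the estimator multiplies each node's activation size by it inside a chunk region. With illustrative numbers only:

    # a node whose chunked dim has length 64, executed with chunk_size 16
    chunk_size, dim_len = 16, 64
    chunk_ratio = float(chunk_size) / dim_len       # 0.25
    full_mb = 64 * 1024 * 256 * 4 / (1024**2)       # fp32 tensor of shape (64, 1024, 256): 64 MB
    est_mb = full_mb * chunk_ratio                  # the estimator charges 16 MB for this node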
@@ -196,7 +186,7 @@ class EstimateMemory(object):
        Returns:
            act_memory_peak_log (List): peak memory of every node
            act_memory_after_node_log (List): memory after excuting every node
            active_node_list_log (List): active nodes of every node. active nodes refer to
                nodes generated but not deleted.
        """
        act_memory = 0.0
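Note (reviewer sketch, not part of this commit): the three logs described above are consumed by SearchChunk. A minimal way to locate the peak node from them, mirroring what _find_peak_node is assumed to do (the variable names below are illustrative):

    mem_peak, _, active_node = estimate_memory.estimate_chunk_inference_mem(node_list)
    peak_node_idx = mem_peak.index(max(mem_peak))    # node with the highest estimated activation memory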
@@ -212,7 +202,7 @@ class EstimateMemory(object):
        use_chunk = True if chunk_infos is not None else False
        chunk_within = False
        chunk_region_idx = None
        chunk_ratio = 1    # use it to estimate chunk mem
        chunk_inputs_names = []

        if use_chunk:

@@ -221,23 +211,18 @@ class EstimateMemory(object):
            chunk_ends = [i[1] for i in chunk_regions]
            chunk_inputs = [i["inputs"] for i in chunk_infos]
            chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
            chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
                j.name for i in chunk_inputs_non_chunk for j in i
            ]
            chunk_inputs_names = [j.name for i in chunk_inputs for j in i
                                 ] + [j.name for i in chunk_inputs_non_chunk for j in i]
            chunk_outputs = [i["outputs"][0] for i in chunk_infos]
            chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
            chunk_sizes = [
                i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos
            ]
            chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]

        for idx, node in enumerate(node_list):
            # if node in chunk start nodes, change chunk ratio and add chunk_tensor
            if use_chunk and idx in chunk_starts:
                chunk_within = True
                chunk_region_idx = chunk_starts.index(idx)
                act_memory += self._get_output_node_size(
                    chunk_outputs[chunk_region_idx]
                ) / (1024**2)
                act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)

            # determine chunk ratio for current node
            if chunk_within:

@@ -262,22 +247,13 @@ class EstimateMemory(object):
            else:
                # forward memory
                # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
                act_memory += (
                    self._get_contiguous_memory(node, not_contiguous_list)
                    * chunk_ratio
                    / (1024**2)
                )
                act_memory += (
                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
                )
                act_memory += (self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2))
                act_memory += (self._get_output_node_size(node) * chunk_ratio / (1024**2))
                # record max act memory
                act_memory_peak_log.append(act_memory)
                # delete useless memory
                act_memory -= (
                    self._get_contiguous_memory(node, not_contiguous_list, delete=True)
                    * chunk_ratio
                    / (1024**2)
                )
                act_memory -= (self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio /
                               (1024**2))
                # delete unused vars not in chunk_input_list
                # we can't delete input nodes until chunk ends
                if chunk_within:

@@ -288,9 +264,8 @@ class EstimateMemory(object):
                        chunk_inputs_names,
                    ) / (1024**2)
                else:
                    act_memory -= self._get_delete_node_size(
                        node, user_to_last_uses_no_free_var, chunk_inputs_names
                    ) / (1024**2)
                    act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var,
                                                             chunk_inputs_names) / (1024**2)

            # log active node, only effective without chunk
            self._add_active_node(node, active_node_list)

@@ -298,9 +273,7 @@ class EstimateMemory(object):

            # if node in chunk end nodes, restore chunk settings
            if use_chunk and idx in chunk_ends:
                act_memory -= (
                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
                )
                act_memory -= (self._get_output_node_size(node) * chunk_ratio / (1024**2))
                act_memory -= self._get_chunk_inputs_size(
                    chunk_inputs[chunk_region_idx],
                    chunk_inputs_non_chunk[chunk_region_idx],
@@ -8,11 +8,7 @@ from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import (
    get_node_shape,
    is_non_compute_node,
    is_non_compute_node_except_placeholder,
)
from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder


class SearchChunk(object):

@@ -73,13 +69,11 @@ class SearchChunk(object):
        """
        free_var_idx = []
        for idx, n in enumerate(self.trace_indice.node_list):
            if n.op == "placeholder":
            if n.op == "placeholder" and get_node_shape(n) is not None:
                free_var_idx.append(idx)
        return free_var_idx

    def _search_max_chunk_region(
        self, active_node: List, peak_node: Node, chunk_regions: List
    ) -> Tuple:
    def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
        """
        Search max chunk region according to peak memory node

@@ -124,15 +118,9 @@ class SearchChunk(object):
            region = i["region"]
            if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
                return None
            elif (
                region[0] <= chunk_region_start <= region[1]
                and chunk_region_end > region[1]
            ):
            elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
                chunk_region_start = region[1] + 1
            elif (
                region[0] <= chunk_region_end <= region[1]
                and chunk_region_start < region[0]
            ):
            elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
                chunk_region_end = region[0] - 1
        return chunk_region_start, chunk_region_end
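Note (reviewer sketch, not part of this commit): the trimming above keeps a new candidate region from overlapping an already-chunked one. Worked through with illustrative indices only:

    # an existing chunk covers nodes 10..20; the candidate region is 15..30
    region = (10, 20)
    chunk_region_start, chunk_region_end = 15, 30
    # 10 <= 15 <= 20 and 30 > 20, so the start is pushed past the existing region
    chunk_region_start = region[1] + 1    # -> 21, candidate becomes (21, 30)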
@@ -164,25 +152,16 @@ class SearchChunk(object):
            for start_node, start_trace in start_traces.items():
                for start_dim, _ in enumerate(start_trace["indice"]):
                    # dim size cannot be 1
                    if (
                        get_node_shape(end_node)[end_dim] == 1
                        or get_node_shape(start_node)[start_dim] == 1
                    ):
                    if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
                        continue
                    # check index source align
                    if not self.trace_flow.check_index_source(
                        start_dim, start_node, start_idx, end_dim, end_node
                    ):
                    if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
                        continue
                    # check index copmute
                    if not self.trace_flow.check_index_compute(
                        start_idx, end_dim, end_node, end_idx
                    ):
                    if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
                        continue
                    # flow search
                    chunk_info = self.trace_flow.flow_search(
                        start_idx, start_dim, end_idx, end_dim
                    )
                    chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
                    if chunk_info is None:
                        continue
                    # check index copmute

@@ -191,9 +170,7 @@ class SearchChunk(object):
                    chunk_infos.append(chunk_info)
        return chunk_infos

    def _search_possible_chunk_regions(
        self, max_chunk_region: Tuple, peak_node: Node
    ) -> List:
    def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
        """
        Search every possible region within the max chunk region.

@@ -206,28 +183,23 @@ class SearchChunk(object):
        """
        possible_chunk_region = []
        output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
        input_trace = []    # trace of a node's input nodes
        for _, n in enumerate(self.trace_indice.node_list):
            cur_trace = {}
            for arg in n.args:
                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
                    arg
                ):
                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
                    cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
            input_trace.append(cur_trace)

        for start_idx in range(max_chunk_region[0], peak_node + 1):
            for end_idx in range(peak_node, max_chunk_region[1] + 1):
                # skip non compute nodes
                if is_non_compute_node(
                    self.trace_indice.node_list[start_idx]
                ) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
                if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
                        self.trace_indice.node_list[end_idx]):
                    continue

                # select free dim
                chunk_info = self._find_chunk_info(
                    input_trace, output_trace, start_idx, end_idx
                )
                chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
                if len(chunk_info) > 0:
                    possible_chunk_region.extend(chunk_info)
        return possible_chunk_region

@@ -256,17 +228,12 @@ class SearchChunk(object):
            best_chunk_region (Dict)
        """
        peak_node = self._find_peak_node(mem_peak)
        max_chunk_region = self._search_max_chunk_region(
            active_node, peak_node, chunk_infos
        )
        max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
        if max_chunk_region == None:
            return None
        possible_chunk_regions = self._search_possible_chunk_regions(
            max_chunk_region, peak_node
        )
        best_chunk_region = self.select_chunk._select_best_chunk_region(
            possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
        )
        possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
        best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node,
                                                                        max_chunk_region, mem_peak)
        best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
        return best_chunk_region

@@ -291,9 +258,7 @@ class SearchChunk(object):
            init_mem_peak,
            _,
            active_node,
        ) = self.estimate_memory.estimate_chunk_inference_mem(
            self.trace_indice.node_list
        )
        ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
        mem_peak = init_mem_peak

        while True:

@@ -306,14 +271,10 @@ class SearchChunk(object):
                mem_peak,
                _,
                active_node,
            ) = self.estimate_memory.estimate_chunk_inference_mem(
                self.trace_indice.node_list, chunk_infos
            )
            ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
            if self._stop_search(init_mem_peak, mem_peak):
                break
        if self.print_mem:
            self.print_mem = False
            self.estimate_memory.estimate_chunk_inference_mem(
                self.trace_indice.node_list, chunk_infos, print_mem=True
            )
            self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
        return chunk_infos
@@ -1,8 +1,13 @@
from typing import Dict, List, Tuple

from torch.fx.node import Node

from .trace_indice import TraceIndice
from .utils import (
    find_chunk_all_input_nodes,
    find_chunk_compute_input_and_output_nodes,
    find_idx_by_name,
    flat_list,
    get_node_shape,
    is_non_compute_node,
    is_non_compute_node_except_placeholder,

@@ -171,7 +176,7 @@ class TraceFlow(object):
            # get cur node info
            cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
            cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
            if cur_node_chunk_dim:
            if cur_node_chunk_dim is not None:
                cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
                cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
            else:

@@ -223,15 +228,32 @@ class TraceFlow(object):
            cur_node_list = next_node_list
        return all_node_info

    def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
    def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, all_node_info: Dict) -> Tuple:
        """
        Get chunk dim for every input node for their every entry, remove unchunked nodes

        Args:
            inputs (List[Node]): input nodes
            all_node_info (Dict): describe all node's chunk dim and fix dim
            start_idx (int): chunk start idx
            end_idx (int): chunk end idx

        Returns:
            inputs (List(Node)): new inputs
            inputs_dim (List): chunk dim for inputs
        """
        inputs_dim = []
        remove_inputs = []
        for input_node in inputs:
            input_dict = {}
            input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
            for user in input_node.users.keys():
                # skip non compute
                if is_non_compute_node(user):
                    continue
                # untraced node, mostly non compute
                if user not in all_node_info:
                    continue
                user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
                if start_idx <= user_idx <= end_idx:
                    chunk_dim = all_node_info[user]["chunk_dim"]

@@ -245,12 +267,24 @@ class TraceFlow(object):
                remove_inputs.append(input_node)
            else:
                inputs_dim.append(input_dict)
        # remove unchunked inputs
        for i in remove_inputs:
            if i in inputs:
                inputs.remove(i)
        return inputs, inputs_dim

    def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
    def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
        """
        get all useless nodes in chunk region and prepose them

        Args:
            all_node_info (Dict): describe all node's chunk dim and fix dim
            start_idx (int): chunk start idx
            end_idx (int): chunk end idx

        Returns:
            List[Node]: all nodes to be preposed
        """
        # get all possible prepose nodes
        maybe_prepose_nodes = []
        for node, node_info in all_node_info.items():

@@ -276,7 +310,7 @@ class TraceFlow(object):
            for cur_prepose_node in tmp_cur_prepose_nodes:
                if prepose_flag == False:
                    break
                for cur_prepose_node_arg in cur_prepose_node.args:
                for cur_prepose_node_arg in cur_prepose_node.all_input_nodes:
                    if type(cur_prepose_node_arg) != type(cur_prepose_node):
                        continue
                    # out of loop

@@ -360,19 +394,28 @@ class TraceFlow(object):
        return chunk_info

    def _reassgin_reshape_size(self, chunk_info):
        """
        Some shape args in reshape may have changed due to chunk
        reassgin those changed shape
        """
        chunk_region = chunk_info["region"]
        reshape_size = {}
        chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
        for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
            if any(i in node.name for i in ["reshape", "view"]):
                reshape_args = node.args[1:]
                reshape_log = self.trace_indice.indice_view_list[node]
                reshape_args = flat_list(node.args[1:])
                chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
                reshape_size[node.name] = {}
                new_shape = ""
                for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
                    if reshape_arg_dim in reshape_log["dim_to"]:
                        continue
                    if reshape_arg_dim == chunk_dim:
                        reshape_size[node.name][reshape_arg.name] = ("min(chunk_size, %d - chunk_idx)" % chunk_shape)
                        new_shape += "min(chunk_size, %d - chunk_idx), " % chunk_shape
                    else:
                        if isinstance(reshape_arg, int):
                            new_shape += "%s, " % str(reshape_arg)
                        else:
                            new_shape += "%s, " % reshape_arg.name
                new_shape = new_shape[:-2]
                origin_shape = str(reshape_args)[1:-1]
                reshape_size[node.name] = [origin_shape, new_shape]
        chunk_info["reshape_size"] = reshape_size
        return chunk_info
@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple

from torch.fx.node import Node

from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list
from .utils import find_first_tensor_arg, find_idx_by_name, flat_list, get_node_shape


class TraceIndice(object):

@@ -28,7 +28,7 @@ class TraceIndice(object):
        node_list (List)
    """

    def __init__(self, node_list: List) -> None:
    def __init__(self, node_list: List[Node]) -> None:
        self.node_list = node_list
        self.indice_trace_list = self._init_indice_trace_list()
        self.indice_view_list = {}

@@ -198,7 +198,7 @@ class TraceIndice(object):
        node_idx = find_idx_by_name(node.name, self.node_list)
        return self.indice_trace_list[node_idx]["compute"]

    def _assign_indice_as_input(self, node, node_idx, input_node=None):
    def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None):
        """
        Assign node's trace as its input node.

@@ -216,7 +216,7 @@ class TraceIndice(object):

        self._inherit_all_computation(input_node, node)

    def _assign_all_indice(self, node, node_idx):
    def _assign_all_indice(self, node: Node, node_idx: int):
        """
        Add new indice for all node's dims.

@@ -232,7 +232,7 @@ class TraceIndice(object):
            new_trace.append(self._add_indice())
        self.indice_trace_list[node_idx]["indice"] = new_trace

    def _assign_transpose_indice(self, node, node_idx):
    def _assign_transpose_indice(self, node: Node, node_idx: int):
        """
        Assign indice for transpose op.
        1. swap input's dim according to transpose args

@@ -249,7 +249,7 @@ class TraceIndice(object):
        self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
        self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])

    def _assign_permute_indice(self, node, node_idx):
    def _assign_permute_indice(self, node: Node, node_idx: int):
        """
        Assign indice for permute op.
        1. swap input's dim according to permute args

@@ -259,14 +259,14 @@ class TraceIndice(object):
            node (node)
            node_idx (int)
        """
        permute_dim = unflat_list(node.args[1:])
        permute_dim = flat_list(node.args[1:])
        input_node = node.args[0]

        self._assign_indice_as_input(node, node_idx, input_node)
        for idx, d in enumerate(permute_dim):
            self._inherit_indice(input_node, d, node, idx)

    def _assign_linear_indice(self, node, node_idx):
    def _assign_linear_indice(self, node: Node, node_idx: int):
        """
        Assign indice for linear op.
        1. copy trace from input node and change last indice accroding to weight

@@ -287,7 +287,7 @@ class TraceIndice(object):

        self._mark_computation(node, node_idx, [-1])

    def _assign_matmul_indice(self, node, node_idx):
    def _assign_matmul_indice(self, node: Node, node_idx: int):
        """
        Assign indice for matmul op.
        1. copy trace from matmul_left and change last indice accroding to matmul_right. (assert they have same length)

@@ -393,7 +393,7 @@ class TraceIndice(object):
        self._assign_indice_as_input(node, idx)
        self._mark_computation(node, idx, [node.kwargs["dim"]])

    def _assign_unsqueeze_indice(self, node, node_idx):
    def _assign_unsqueeze_indice(self, node: Node, node_idx: int):
        """
        Assign indice for unsqueeze op.
        1. assign new indice for unsqueeze dim

@@ -404,9 +404,13 @@ class TraceIndice(object):
        """
        self._del_dim(node_idx, -1)
        self._assign_indice_as_input(node, node_idx)
        self._add_dim(node_idx, node.args[1])
        dim_idx = node.args[1]
        # unsqueeze(-1) = unsqueeze(shape_num + 1)
        if dim_idx < 0:
            dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
        self._add_dim(node_idx, dim_idx)

    def _assign_dropout_indice(self, node, node_idx):
    def _assign_dropout_indice(self, node: Node, node_idx: int):
        """
        Assign indice for unsqueeze op.
        1. assign new indice for unsqueeze dim
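Note (reviewer sketch, not part of this commit): the new branch above normalizes a negative unsqueeze dim before recording it. With illustrative values:

    # a node of shape (2, 3, 4, 1) produced by unsqueeze(-1)
    node_shape = (2, 3, 4, 1)
    dim_idx = -1
    if dim_idx < 0:
        dim_idx = list(range(len(node_shape)))[dim_idx]    # -> 3, i.e. the last output dim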
@@ -417,7 +421,7 @@ class TraceIndice(object):
        """
        self._assign_indice_as_input(node, node_idx)

    def _assign_ones_like_indice(self, node, node_idx):
    def _assign_ones_like_indice(self, node: Node, node_idx: int):
        """
        Assign indice for oneslike op.
        1. assign new indice for all dim

@@ -428,7 +432,47 @@ class TraceIndice(object):
        """
        self._assign_all_indice(node, node_idx)

    def _assign_view_reshape_indice(self, node, node_idx):
    def _assign_getitem_indice(self, node: Node, node_idx: int):
        """
        Assign indice for getitem.
        getitem can act like slice sometimes

        Args:
            node (node)
            node_idx (int)
        """
        node_args = flat_list(node.args[1:])
        if not any(i == str(node_arg) for i in ["None", "Ellipsis"] for node_arg in node_args):
            return

        # node args should be like [Ellipsis, slice(start, step, end), None]
        node_shape = get_node_shape(node)
        origin_idx_count = 0
        new_idx_count = 0
        new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
        for _ in range(new_dim_num):
            self._del_dim(node_idx, 0)
        self._assign_indice_as_input(node, node_idx)

        for _, node_arg in enumerate(node_args):
            node_arg_str = str(node_arg)
            # Ellipsis means [..., ]
            if "Ellipsis" == node_arg_str:
                shape_gap = len(node_shape) - len(node_args) + 1
                origin_idx_count += shape_gap
                new_idx_count += shape_gap
            # slice(None, None, None) means all indexes, doesn't support other slice
            elif "slice(None, None, None)" == node_arg_str:
                origin_idx_count += 1
                new_idx_count += 1
            # None means a new dim
            elif "None" == node_arg_str:
                self._add_dim(node_idx, new_idx_count)
                new_idx_count += 1
            else:
                raise NotImplementedError()

    def _assign_view_reshape_indice(self, node: Node, node_idx: int):
        """
        Assign indice for view and reshape op.
        1. get origin shape and target shape by meta info.
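Note (reviewer sketch, not part of this commit): _assign_getitem_indice only handles slicing-style getitem. An fx node produced by an expression like x[..., :, None] would carry subscript args of the assumed form below, which the new branch maps to "keep dims for Ellipsis and the full slice, add a dim for None":

    node_args = [Ellipsis, slice(None, None, None), None]
    # str() of each arg is what the method compares against
    [str(a) for a in node_args]    # -> ['Ellipsis', 'slice(None, None, None)', 'None']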
@@ -447,7 +491,7 @@ class TraceIndice(object):
        origin_node = node.args[0]
        origin_shape = origin_node.meta["tensor_meta"].shape
        target_shape = []
        unflated_args = unflat_list(node.args)
        unflated_args = flat_list(node.args)
        for i in range(1, len(unflated_args)):
            if isinstance(unflated_args[i], int):
                target_shape.append(unflated_args[i])

@@ -544,6 +588,8 @@ class TraceIndice(object):
                self._assign_einsum_indice(node, idx)
            elif "layer_norm" in node.name:
                self._assign_layernorm_indice(node, idx)
            elif "getitem" in node.name:
                self._assign_getitem_indice(node, idx)
            elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
                continue
            else:
@@ -3,14 +3,14 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
from torch.fx.node import Node


def unflat_list(inputs):
def flat_list(inputs):
    """
    unflat a list by recursion
    flat a list by recursion
    """
    res = []
    for i in inputs:
        if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
            res.extend(unflat_list(i))
            res.extend(flat_list(i))
        else:
            res.append(i)
    return res
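Note (reviewer sketch, not part of this commit): the renamed helper flattens arbitrarily nested argument containers recursively, for example:

    flat_list([1, (2, 3), [4, [5]]])    # -> [1, 2, 3, 4, 5]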
@@ -27,8 +27,13 @@ def find_first_tensor_arg(node):


def is_non_compute_node(node):
    if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
        i in node.name for i in ["getitem", "getattr"]):
    if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
        return True
    if "getitem" in node.name:
        node_args = flat_list(node.args[1:])
        for node_arg in node_args:
            if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
                return False
        return True
    return False
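Note (reviewer sketch, not part of this commit): with this change a getitem node only counts as non-compute when none of its subscript args are None or Ellipsis, so a slicing getitem such as x[..., None] is kept as a compute node and can be traced by _assign_getitem_indice. Illustrative check on the args alone:

    args = (Ellipsis, None)
    is_slice_like = any(str(a) in ("None", "Ellipsis") for a in args)    # True -> treated as a compute node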
@@ -40,15 +45,15 @@ def get_node_shape(node):


def is_non_compute_node_except_placeholder(node):
    if any(i in node.op for i in ["get_attr", "output"]) or any(i in node.name for i in ["getitem", "getattr"]):
        return True
    return False
    if "placeholder" in node.op:
        return False
    return is_non_compute_node(node)


def is_non_compute_node_except_placeholder_output(node):
    if any(i in node.op for i in ["get_attr"]) or any(i in node.name for i in ["getitem", "getattr"]):
        return True
    return False
    if "output" in node.op:
        return False
    return is_non_compute_node_except_placeholder(node)


def find_idx_by_name(name, nodes_list):
@@ -27,18 +27,17 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():

def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
    # for memory test
    # model = model.cuda()
    # torch.cuda.reset_peak_memory_stats()
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node1 = node.clone()
    #     pair1 = pair.clone()
    #     gm(node1, pair1)
    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
    #     node_mask1 = node_mask.clone()
    #     pair_mask1 = pair_mask.clone()
    #     gm(node1, pair1, node_mask1, pair_mask1)
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print(
    #     "autochunk now mem:%.2f max mem:%.2f"
    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
    # )
    # print("autochunk max mem:%.2f"% (new_max_mem - now_mem))

    # test forward
    model = model.cuda()

@@ -113,7 +112,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
        MetaTensor(node_mask, fake_device="cuda:0"),
        MetaTensor(pair_mask, fake_device="cuda:0"),
    )
    # codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer

@@ -130,14 +129,14 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
            "_mask_trans": True,
        },
    )
    # graph.set_codegen(codegen)
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph)
    gm.recompile()

    # assert we have inserted chunk
    code = graph.python_code("self").src
    assert "chunk_size" in code
    # print(code)
    assert "chunk_result = None; chunk_size = None;" in code

    _test_fwd(model, gm, node, pair, node_mask, pair_mask)
    gpc.destroy()

@@ -147,7 +146,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_codegen(msa_len, pair_len, max_memory):

@@ -161,4 +160,4 @@ def test_evoformer_codegen(msa_len, pair_len, max_memory):


if __name__ == "__main__":
    _test_evoformer_codegen(0, 32, 64, 25)
    _test_evoformer_codegen(0, 32, 64, 24)
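Note (reviewer sketch, not part of this commit): the tests above and below follow one pipeline, namely trace with symbolic_trace so MetaInfoProp can run, build AutoChunkCodeGen from that meta graph, then retrace with ColoTracer and attach the codegen before recompiling. A condensed sketch using only names that appear in these test files; tensor shapes and meta_args are as in the simple evoformer test:

    meta_graph = symbolic_trace(model, meta_args={"node": node.to(torch.device("meta")),
                                                  "pair": pair.to(torch.device("meta"))})
    MetaInfoProp(meta_graph).propagate(MetaTensor(node, fake_device="cuda:0"),
                                       MetaTensor(pair, fake_device="cuda:0"))
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)

    graph = ColoTracer().trace(model, meta_args={"node": node.to(torch.device("meta")),
                                                 "pair": pair.to(torch.device("meta"))})
    graph.set_codegen(codegen)          # generated forward now carries chunked loops
    gm = ColoGraphModule(model, graph)
    gm.recompile()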
@@ -13,7 +13,7 @@ except:

import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx import ColoTracer, symbolic_trace
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule

@@ -26,21 +26,6 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():


def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
    # for memory test
    # torch.cuda.reset_peak_memory_stats()
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node1 = node.clone()
    #     pair1 = pair.clone()
    #     gm(node1, pair1)
    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print(
    #     "autochunk now mem:%.2f max mem:%.2f"
    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
    # )

    # test forward
    with torch.no_grad():
        non_fx_out = model(node, pair)
        fx_out = gm(node, pair)

@@ -69,6 +54,16 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()

    # meta info prop
    meta_graph = symbolic_trace(model,
                                meta_args={
                                    "node": node.to(torch.device("meta")),
                                    "pair": pair.to(torch.device("meta")),
                                })    # must use symbolic_trace
    interp = MetaInfoProp(meta_graph)
    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)

    # trace the module and replace codegen
    graph = ColoTracer().trace(
        model,

@@ -77,24 +72,14 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
            "pair": pair.to(torch.device("meta")),
        },
    )
    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
    interp = MetaInfoProp(gm_prop)
    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))

    # now run it twice to get meta info in graph module, not necessary
    gm = torch.fx.GraphModule(model, graph)
    interp = MetaInfoProp(gm)
    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))

    codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph)
    gm.recompile()

    # assert we have inserted chunk
    code = graph.python_code("self").src
    assert "chunk_size" in code
    # print(code)
    assert "chunk_result = None; chunk_size = None;" in code

    _test_fwd(model, gm, node, pair)
    gpc.destroy()
@@ -47,18 +47,18 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
        str(target_regions),
    )
    for region in target_regions:
        assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (
        assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%s" % (
            str(region),
            msa_len,
            pair_len,
            max_memory,
            str(max_memory),
        )
    for region in found_regions:
        assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
            str(region),
            msa_len,
            pair_len,
            max_memory,
            str(max_memory),
        )