[autochunk] refactor chunk memory estimation (#2762)

* refactor memory code

* don't log free var memory

* add memory align

* update chunk target

* update setting for new memory

* finish test

* update tracer

* fix typo

* update test
Xuanlei Zhao 2023-03-08 16:22:30 +08:00 committed by GitHub
parent b51bfec357
commit 2ca9728cbb
12 changed files with 294 additions and 422 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, Iterable, List, Tuple
from typing import Any, Callable, Dict, Iterable, List, Tuple
import torch
@ -216,14 +216,13 @@ def _add_node_slice(
return body
def emit_code_with_chunk(
body: List[str],
nodes: Iterable[Node],
emit_node_func,
delete_unused_value_func,
search_chunk: SearchChunk,
chunk_infos: List,
):
def emit_code_with_chunk(body: List[str],
nodes: Iterable[Node],
emit_node_func: Callable,
delete_unused_value_func: Callable,
search_chunk: SearchChunk,
chunk_infos: List,
eval_mem: bool = False):
"""
Emit code with chunk according to chunk_infos.
@ -260,6 +259,9 @@ def emit_code_with_chunk(
region_idx = 0
within_chunk_region = False
if eval_mem:
body.append("init_memory = torch.cuda.memory_allocated() / 1024**2\n")
while node_idx < len(node_list):
node = node_list[node_idx]
@ -289,10 +291,18 @@ def emit_code_with_chunk(
body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
body[-1] = " " + body[-1]
delete_unused_value_func(node, body, chunk_inputs_names)
if eval_mem:
body.append(
" if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
% (node.name))
else:
emit_node_func(node, body)
if node_idx not in chunk_inputs:
delete_unused_value_func(node, body, chunk_inputs_names)
if eval_mem:
body.append(
"print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
% (node.name))
# generate chunk region end
if node_idx in chunk_ends:
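
When eval_mem is on, the codegen weaves CUDA memory probes into the generated forward source: one baseline line at the top of the function, and after each emitted node a print of the peak-allocation delta followed by a peak-stat reset (inside a chunk region, only on the first chunk iteration). A minimal sketch of the same measurement pattern outside the codegen, assuming a CUDA device and a hypothetical run_node callable:

    import torch

    def profile_per_node(nodes, run_node):
        # mirrors the emitted "init_memory = torch.cuda.memory_allocated() / 1024**2" line
        init_memory = torch.cuda.memory_allocated() / 1024**2
        for node in nodes:
            run_node(node)
            # peak allocation since the last reset, relative to the baseline, in MB
            peak = torch.cuda.max_memory_allocated() / 1024**2 - init_memory
            print(node.name, peak)
            torch.cuda.reset_peak_memory_stats()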
@ -312,8 +322,10 @@ if AUTOCHUNK_AVAILABLE:
meta_graph,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False) -> None:
print_progress: bool = False,
eval_mem: bool = False) -> None:
super().__init__()
self.eval_mem = eval_mem
# find the chunk regions
self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress)
self.chunk_infos = self.search_chunk.search_region()
@ -511,14 +523,8 @@ if AUTOCHUNK_AVAILABLE:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
emit_code_with_chunk(
body,
nodes,
emit_node,
delete_unused_values,
self.search_chunk,
self.chunk_infos,
)
emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos,
self.eval_mem)
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body

View File

@ -2,11 +2,11 @@ import copy
from typing import Any, Callable, Dict, Iterable, List, Tuple
import torch
from torch.fx.node import Node, map_arg
from torch.fx.node import Node
from colossalai.fx.profiler import activation_size, parameter_size
from .utils import NodeMgr, delete_free_var_from_last_use, get_node_shape, is_non_memory_node
from .utils import NodeMgr, get_node_shape, is_non_memory_node
class EstimateMemory(object):
@ -14,102 +14,85 @@ class EstimateMemory(object):
Estimate memory with chunk
"""
def __init__(self, node_mgr: NodeMgr) -> None:
self.node_mgr = node_mgr
def __init__(self) -> None:
pass
def _get_meta_node_size(self, x):
def _get_node_size(self, x: Node) -> float:
"""
return node size in MB
"""
x = x.meta["tensor_meta"]
x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
return x
if not hasattr(x, "numel"):
out = sum([i.numel * torch.tensor([], dtype=i.dtype).element_size() for i in x])
else:
out = x.numel * torch.tensor([], dtype=x.dtype).element_size()
out = float(out) / 1024**2
return out
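
_get_node_size converts the traced tensor metadata to MB via numel * element_size, summing when a node produces several tensors. As a quick sanity check of the arithmetic on a plain tensor (not FX metadata):

    import torch

    x = torch.empty(1, 512, 256, dtype=torch.float32)
    size_mb = x.numel() * x.element_size() / 1024**2
    print(size_mb)    # 512 * 256 elements * 4 bytes = 0.5 MB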
def _get_output_node(self, n):
out_size = activation_size(n.meta["fwd_out"])
out_node = [n.name] if out_size > 0 else []
return out_size, out_node
def _add_active_node(self, n: Node, active_nodes: Dict, chunk_ratio: float) -> None:
"""
add an active node and its size to the active node dict
"""
if get_node_shape(n) is None:
return
if n.op == "placeholder":
return
if n not in active_nodes:
node_size = self._get_node_size(n) * chunk_ratio
active_nodes[n] = node_size
def _get_output_node_size(self, n):
return self._get_output_node(n)[0]
def _build_delete_node_dict(self, node_mgr: NodeMgr) -> Dict:
"""
build the delete-node dict, i.e. record at which node index each node can be deleted
"""
delete_node_dict = {}
for idx, node in enumerate(node_mgr.get_node_list()):
# skip non shape node
if get_node_shape(node) is None:
continue
# dont remove free nodes
elif node.op == "placeholder":
delete_node_dict[node] = len(node_mgr.get_node_list())
# node no user
elif len(node.users) == 0:
delete_node_dict[node] = idx
# log max use
else:
node_user_idx = [node_mgr.find_node_idx(i) for i in node.users.keys()]
delete_node_dict[node] = max(node_user_idx)
return delete_node_dict
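
_build_delete_node_dict turns the FX graph into a deletion schedule: free vars (placeholders) are pinned past the end of the list, nodes with no users die at their own index, and everything else at the largest index among its users. A hedged sketch over a plain node list (find_idx stands in for NodeMgr.find_node_idx; the original also skips nodes without a shape):

    def build_delete_schedule(nodes, find_idx):
        schedule = {}
        for idx, node in enumerate(nodes):
            if node.op == "placeholder":
                schedule[node] = len(nodes)    # free vars are never released
            elif len(node.users) == 0:
                schedule[node] = idx           # unused output, release immediately
            else:
                schedule[node] = max(find_idx(u) for u in node.users)
        return schedule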
def _add_active_node(self, n, active_list):
new_active = self._get_output_node(n)[1]
if n.op == "placeholder" and get_node_shape(n) is not None:
new_active.append(n.name)
for i in new_active:
if i not in active_list and get_node_shape(n) is not None:
active_list.append(i)
def _remove_deactive_node(self,
user_idx: int,
user: Node,
active_nodes: List,
delete_node_dict: List,
kept_nodes: List = None) -> None:
"""
remove deactivate nodes from active nodes
"""
if kept_nodes is None:
kept_nodes = []
if user.op in ("output",):
return
def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
delete_size = 0
delete_node = []
if user.op not in ("output",):
nodes_to_delete = user_to_last_uses.get(user, [])
if len(user.users) == 0:
nodes_to_delete.append(user)
if to_keep is not None:
keep_list = []
for n in nodes_to_delete:
if n.name in to_keep:
keep_list.append(n)
for n in keep_list:
if n in nodes_to_delete:
nodes_to_delete.remove(n)
if len(nodes_to_delete):
out_node = [self._get_output_node(i) for i in nodes_to_delete]
delete_size = sum([i[0] for i in out_node])
for i in range(len(out_node)):
if out_node[i][0] > 0:
delete_node.append(out_node[i][1][0])
elif nodes_to_delete[i].op == "placeholder":
delete_node.append(nodes_to_delete[i].name)
# elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']):
# delete_node.append(nodes_to_delete[i].name)
return delete_size, delete_node
for node in list(active_nodes.keys()):
# dont delete kept nodes
if node in kept_nodes:
continue
# should be deleted
if delete_node_dict[node] <= user_idx:
active_nodes.pop(node)
def _get_delete_node_size(self, user, user_to_last_uses, to_keep):
return self._get_delete_node(user, user_to_last_uses, to_keep)[0]
def _remove_deactive_node(self, user, user_to_last_uses, active_list):
delete_node = self._get_delete_node(user, user_to_last_uses)[1]
for i in delete_node:
if i in active_list:
active_list.remove(i)
def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
nodes_to_delete = []
for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
chunk_input_users = chunk_input.users.keys()
chunk_input_users_idx = [self.node_mgr.find_node_idx(i) for i in chunk_input_users]
if all(i <= chunk_end_idx for i in chunk_input_users_idx):
if chunk_input not in nodes_to_delete:
nodes_to_delete.append(chunk_input)
out_node = [self._get_output_node(i) for i in nodes_to_delete]
delete_size = sum([i[0] for i in out_node])
return delete_size
def _get_last_usr(self, nodes):
node_to_last_use: Dict[Node, Node] = {}
user_to_last_uses: Dict[Node, List[Node]] = {}
def register_last_uses(n: Node, user: Node):
if n not in node_to_last_use:
node_to_last_use[n] = user
user_to_last_uses.setdefault(user, []).append(n)
for node in reversed(nodes):
map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
return user_to_last_uses
def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
def _get_tmp_memory(self, node, not_contiguous_list, delete=False):
mem = 0
not_contiguous_ops = ["permute"]
inherit_contiguous_ops = ["transpose", "view"]
if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]):
for n in node.args:
if n in not_contiguous_list:
# matmul won't change origin tensor, but create a tmp copy
mem += self._get_output_node_size(n)
mem += self._get_node_size(n)
elif node.op == "call_module":
for n in node.args:
if n in not_contiguous_list:
@ -129,31 +112,7 @@ class EstimateMemory(object):
if chunk_dim is None:
return 1.0
else:
return float(chunk_size) / node_shape[chunk_dim]
def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names):
# if any(j in user.name for j in ['transpose', 'permute', 'view']):
# return 0
if user.op in ("placeholder", "output"):
return 0
nodes_to_delete = user_to_last_uses.get(user, [])
if len(user.users) == 0:
nodes_to_delete.append(user)
delete_size = 0
for n in nodes_to_delete:
if n.name in chunk_inputs_names:
continue
delete_size += self._get_output_node_size(n) * chunk_ratio
return delete_size
def _print_mem_log(self, log, nodes, title=None):
if title:
print(title)
for idx, (l, n) in enumerate(zip(log, nodes)):
print("%s:%.2f \t" % (n.name, l), end="")
if (idx + 1) % 3 == 0:
print("")
print("\n")
return chunk_size / float(node_shape[chunk_dim])
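
_get_chunk_ratio scales a node's memory by the fraction of the chunked dimension that is materialised at once: chunk_size / node_shape[chunk_dim]. For instance, with a hypothetical node of shape [1, 2048, 256] chunked along dim 1 in slices of 256:

    chunk_ratio = 256 / float(2048)    # 0.125: only one eighth of the tensor is live per chunk step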
def _print_compute_op_mem_log(self, log, nodes, title=None):
if title:
@ -168,12 +127,22 @@ class EstimateMemory(object):
print("")
print("\n")
def estimate_chunk_inference_mem(
self,
node_list: List,
chunk_infos=None,
print_mem=False,
):
def _add_active_nodes_from_list(self, active_nodes: List, nodes: List) -> List:
"""
add the given nodes to the active node dict
"""
for n in nodes:
self._add_active_node(n, active_nodes, 1)
def _get_memory_from_active_nodes(self, active_nodes: Dict) -> float:
"""
sum all memory of active nodes
"""
out = [i for i in active_nodes.values()]
out = sum(out)
return out
def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None, print_mem: bool = False):
"""
Estimate inference memory with chunk
@ -191,18 +160,17 @@ class EstimateMemory(object):
act_memory = 0.0
act_memory_peak_log = []
act_memory_after_node_log = []
active_node_list = []
active_node_list_log = []
active_nodes = {}
active_nodes_log = []
not_contiguous_list = []
user_to_last_uses = self._get_last_usr(node_list)
user_to_last_uses_no_free_var = self._get_last_usr(node_list)
delete_free_var_from_last_use(user_to_last_uses_no_free_var)
node_mgr = NodeMgr(node_list)
delete_node_dict = self._build_delete_node_dict(node_mgr)
use_chunk = True if chunk_infos is not None else False
chunk_within = False
chunk_region_idx = None
chunk_ratio = 1 # use it to estimate chunk mem
chunk_inputs_names = []
chunk_inputs_all = []
if use_chunk:
chunk_regions = [i["region"] for i in chunk_infos]
@ -210,30 +178,30 @@ class EstimateMemory(object):
chunk_ends = [i[1] for i in chunk_regions]
chunk_inputs = [i["inputs"] for i in chunk_infos]
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
chunk_inputs_names = [j.name for i in chunk_inputs for j in i
] + [j.name for i in chunk_inputs_non_chunk for j in i]
chunk_inputs_all = [j for i in chunk_inputs for j in i] + [j for i in chunk_inputs_non_chunk for j in i]
chunk_outputs = [i["outputs"] for i in chunk_infos]
chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
for idx, node in enumerate(node_list):
for idx, node in enumerate(node_mgr.get_node_list()):
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
if use_chunk and idx in chunk_starts:
chunk_within = True
chunk_region_idx = chunk_starts.index(idx)
act_memory += sum(self._get_output_node_size(i) for i in chunk_outputs[chunk_region_idx]) / (1024**2)
self._add_active_nodes_from_list(active_nodes, chunk_outputs[chunk_region_idx])
# determine chunk ratio for current node
if chunk_within:
chunk_ratio = self._get_chunk_ratio(
node,
chunk_node_dim[chunk_region_idx],
chunk_sizes[chunk_region_idx],
)
chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx],
chunk_sizes[chunk_region_idx])
# add current node as active node
self._add_active_node(node, active_nodes, chunk_ratio)
act_memory = self._get_memory_from_active_nodes(active_nodes)
# if node is placeholder, just add the size of the node
if node.op == "placeholder":
act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2)
act_memory_peak_log.append(act_memory)
# skip output
elif node.op == "output":
@ -241,83 +209,32 @@ class EstimateMemory(object):
# no change for non compute node
elif is_non_memory_node(node):
act_memory_peak_log.append(act_memory)
# node is a compute op
# calculate tmp, output node and delete node memory
# node is a compute op, calculate tmp
else:
# forward memory
# TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
act_memory += (self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2))
act_memory += (self._get_output_node_size(node) * chunk_ratio / (1024**2))
tmp_memory = self._get_tmp_memory(node, not_contiguous_list, delete=True) * chunk_ratio
# record max act memory
act_memory_peak_log.append(act_memory)
# delete useless memory
act_memory -= (self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio /
(1024**2))
# delete unused vars not in chunk_input_list
# we can't delete input nodes until chunk ends
if chunk_within:
act_memory -= self._get_chunk_delete_node_size(
node,
user_to_last_uses_no_free_var,
chunk_ratio,
chunk_inputs_names,
) / (1024**2)
else:
act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var,
chunk_inputs_names) / (1024**2)
act_memory_peak_log.append(act_memory + tmp_memory)
# log active node, only effective without chunk
self._add_active_node(node, active_node_list)
self._remove_deactive_node(node, user_to_last_uses, active_node_list)
# remove_deactive_node
self._remove_deactive_node(idx, node, active_nodes, delete_node_dict, kept_nodes=chunk_inputs_all)
# if node in chunk end nodes, restore chunk settings
if use_chunk and idx in chunk_ends:
act_memory -= (self._get_output_node_size(node) * chunk_ratio / (1024**2))
act_memory -= self._get_chunk_inputs_size(
chunk_inputs[chunk_region_idx],
chunk_inputs_non_chunk[chunk_region_idx],
node_list,
chunk_regions[chunk_region_idx][1],
) / (1024**2)
self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now
chunk_within = False
chunk_ratio = 1
chunk_region_idx = None
act_memory = self._get_memory_from_active_nodes(active_nodes)
act_memory_after_node_log.append(act_memory)
active_node_list_log.append(copy.deepcopy(active_node_list))
active_nodes_log.append(active_nodes.copy())
if print_mem:
print("with chunk" if use_chunk else "without chunk")
# self._print_mem_log(act_memory_peak_log, node_list, "peak")
# self._print_mem_log(act_memory_after_node_log, node_list, "after")
self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
# self._print_compute_op_mem_log(
# act_memory_after_node_log, node_list, "after"
# )
self._print_compute_op_mem_log(act_memory_peak_log, node_mgr.get_node_list(), "peak")
# param_memory = parameter_size(gm)
# all_memory = act_memory + param_memory
return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
def get_active_nodes(self, node_list: List) -> List:
"""
Get active nodes for every node
Args:
node_list (List): _description_
Returns:
active_node_list_log (List): active nodes of every node. active nodes refer to
nodes generated but not deleted.
"""
active_node_list = []
active_node_list_log = []
user_to_last_uses = self._get_last_usr(node_list)
user_to_last_uses_no_free_var = self._get_last_usr(node_list)
delete_free_var_from_last_use(user_to_last_uses_no_free_var)
for _, node in enumerate(node_list):
# log active node, only effective without chunk
self._add_active_node(node, active_node_list)
self._remove_deactive_node(node, user_to_last_uses, active_node_list)
active_node_list_log.append(copy.deepcopy(active_node_list))
return active_node_list_log
return act_memory_peak_log, act_memory_after_node_log, active_nodes_log
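
Taken together, the refactored estimator drops the separate add/remove size bookkeeping and keeps a single node -> size dict: each node's chunk-ratio-scaled size is inserted when it is produced, activation memory is re-derived as the sum of that dict, a transient term covers temporary copies, and entries are popped once their last user has run. A condensed sketch of the loop, ignoring chunk regions and reusing the hypothetical helpers sketched above:

    def estimate_peaks(nodes, node_size, delete_schedule):
        # node_size and delete_schedule are the stand-ins from the earlier sketches
        active, peaks = {}, []
        for idx, node in enumerate(nodes):
            if node.op != "placeholder":          # free var memory is no longer logged
                active.setdefault(node, node_size(node))
            peaks.append(sum(active.values()))    # the real loop also adds tmp-copy memory here
            for n in list(active):                # release nodes whose last user has passed
                if delete_schedule[n] <= idx:
                    active.pop(n)
        return peaks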

View File

@ -42,10 +42,11 @@ class SearchChunk(object):
def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
self.print_mem = print_mem
self.max_memory = max_memory
self.print_progress = print_progress
self.node_mgr = NodeMgr(gm)
self.node_mgr = NodeMgr(list(gm.graph.nodes))
self.trace_indice = TraceIndice(self.node_mgr)
self.estimate_memory = EstimateMemory(self.node_mgr)
self.estimate_memory = EstimateMemory()
self._init_trace()
self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr)
self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr)
@ -63,45 +64,46 @@ class SearchChunk(object):
reduce the computation complexity of trace_indice
"""
# find all max ranges
active_nodes = self.estimate_memory.get_active_nodes(self.node_mgr.get_node_list())
cur_node_idx = len(self._get_free_var_idx())
max_chunk_region_list = []
while True:
max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
cur_node_idx = max_chunk_region[1] + 1
if cur_node_idx >= len(active_nodes) - 1:
break
max_chunk_region_list.append(max_chunk_region)
# nothing to limit for the first range
max_chunk_region_list = max_chunk_region_list[1:]
max_chunk_region_list[0] = (0, max_chunk_region_list[0][1])
active_nodes = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())[2]
# set trace range and do the trace
if self.print_progress:
get_logger().info("AutoChunk start tracing indice")
self.trace_indice.set_trace_range(max_chunk_region_list, active_nodes)
self.trace_indice.set_active_nodes(active_nodes)
self.trace_indice.trace_indice()
def _find_peak_node(self, mem_peak: List) -> int:
def _find_peak_region(self, mem_peak: List) -> int:
"""
find the peak node together with the neighbouring nodes whose memory exceeds max mem
"""
max_value = max(mem_peak)
max_idx = mem_peak.index(max_value)
return max_idx
peak_region = [max_idx, max_idx]
if self.max_memory is None:
return peak_region
def _get_free_var_idx(self) -> List:
"""
Get free var index
# to left
count = 0
for i in range(max_idx - 1, -1, -1):
if mem_peak[i] > self.max_memory:
peak_region[0] = i
else:
count += 1
if count >= 3:
break
# to right
count = 0
for i in range(max_idx + 1, len(mem_peak) - 1):
if mem_peak[i] > self.max_memory:
peak_region[1] = i
count = 0
else:
count += 1
if count >= 3:
break
Returns:
free_var_idx (List): all indexs of free vars
"""
free_var_idx = []
for idx, n in enumerate(self.node_mgr.get_node_list()):
if n.op == "placeholder" and get_node_shape(n) is not None:
free_var_idx.append(idx)
return free_var_idx
return peak_region
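
_find_peak_region generalises the old single peak index to a region: starting from the global maximum of the peak log, it extends left and right over every node whose estimated peak still exceeds max_memory and gives up in a direction after three consecutive nodes under the limit. A slightly simplified, symmetric sketch (the actual left scan never resets its counter):

    def find_peak_region(mem_peak, max_memory, tolerance=3):
        max_idx = mem_peak.index(max(mem_peak))
        if max_memory is None:
            return [max_idx, max_idx]

        def extend(indices):
            edge, misses = max_idx, 0
            for i in indices:
                if mem_peak[i] > max_memory:
                    edge, misses = i, 0
                else:
                    misses += 1
                    if misses >= tolerance:
                        break
            return edge

        left = extend(range(max_idx - 1, -1, -1))               # walk towards the start
        right = extend(range(max_idx + 1, len(mem_peak) - 1))    # walk towards the end
        return [left, right]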
def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_regions: List = None) -> Tuple:
def _search_max_chunk_region(self, active_node: List, peak_region: int, chunk_regions: List = None) -> Tuple:
"""
Search max chunk region according to peak memory node
@ -119,50 +121,24 @@ class SearchChunk(object):
# check if peak node already in chunkinfo
if chunk_regions is not None:
for i in chunk_regions:
if i["region"][0] < peak_node_idx <= i["region"][1]:
if i["region"][0] < peak_region[0] <= i["region"][1] or \
i["region"][0] < peak_region[1] <= i["region"][1]:
return None
free_vars = self._get_free_var_idx()
free_var_num = len(free_vars)
active_node_num = [len(i) for i in active_node]
min_active_node_num = min(active_node_num[free_var_num:])
threshold = max(free_var_num, min_active_node_num)
# normal search
# from peak_node to free_var
inside_flag = False
chunk_region_start = free_var_num
for i in range(peak_node_idx, -1, -1):
if active_node_num[i] <= threshold:
inside_flag = True
if inside_flag and active_node_num[i] > threshold:
chunk_region_start = i + 1
break
# from peak_node to len-2
inside_flag = False
chunk_region_end = len(active_node) - 1
for i in range(peak_node_idx, len(active_node)):
if active_node_num[i] <= threshold:
inside_flag = True
if inside_flag and active_node_num[i] > threshold:
window_size = 100
# search min for start
min_num = 1e4
for i in range(peak_region[0], max(peak_region[0] - window_size, -1), -1):
if active_node_num[i] < min_num:
min_num = active_node_num[i]
chunk_region_start = i
# search min for end
min_num = 1e4
for i in range(peak_region[1], min(peak_region[1] + window_size, len(active_node_num))):
if active_node_num[i] < min_num:
min_num = active_node_num[i]
chunk_region_end = i
break
# if normal search fails, use approximate search
if (chunk_region_end - chunk_region_start) > 250:
window_size = 100
# search min for start
min_num = 1e3
for i in range(max(peak_node_idx - window_size, 0), peak_node_idx + 1):
if active_node_num[i] < min_num:
min_num = active_node_num[i]
chunk_region_start = i
# search min for end
min_num = 1e3
for i in range(min(peak_node_idx + window_size, len(active_node_num) - 1), peak_node_idx - 1, -1):
if active_node_num[i] < min_num:
min_num = active_node_num[i]
chunk_region_end = i
# avoid chunk regions overlap
if chunk_regions is not None:
@ -214,7 +190,7 @@ class SearchChunk(object):
chunk_infos.append(chunk_info)
return chunk_infos
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: Node) -> List:
"""
Search every possible region within the max chunk region.
@ -235,8 +211,8 @@ class SearchChunk(object):
cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
input_trace.append(cur_trace)
for start_idx in range(max_chunk_region[0], peak_node + 1):
for end_idx in range(peak_node, max_chunk_region[1] + 1):
for start_idx in range(max_chunk_region[0], peak_region[0] + 1):
for end_idx in range(peak_region[1], max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
self.node_mgr.get_node_by_idx(end_idx)):
@ -270,13 +246,12 @@ class SearchChunk(object):
Returns:
best_chunk_region (Dict)
"""
peak_node = self._find_peak_node(mem_peak)
max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
peak_region = self._find_peak_region(mem_peak)
max_chunk_region = self._search_max_chunk_region(active_node, peak_region, chunk_infos)
if max_chunk_region == None:
return None
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node,
max_chunk_region, mem_peak)
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_region)
best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, mem_peak)
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
return best_chunk_region
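
_search_max_chunk_region likewise replaces the old active-node-count threshold walk (and its 250-node fallback) with a plain windowed minimum: on each side of the peak region, the boundary is the index with the fewest active nodes within 100 positions. A small sketch of that selection, with active_node_counts standing in for [len(i) for i in active_node]:

    def window_boundaries(active_node_counts, peak_region, window_size=100):
        lo, hi = peak_region
        # ties resolve to the index closest to the peak region, matching the strict '<' scan
        start = min(range(lo, max(lo - window_size, -1), -1),
                    key=lambda i: active_node_counts[i])
        end = min(range(hi, min(hi + window_size, len(active_node_counts))),
                  key=lambda i: active_node_counts[i])
        return start, end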

View File

@ -24,29 +24,16 @@ class SelectChunk(object):
else:
self.stratge = "min_memory"
def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak):
def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak):
if self.stratge == "min_memory":
best_region = self._select_min_memory_chunk_region(
possible_chunk_regions,
chunk_infos,
peak_node,
max_chunk_region,
mem_peak,
)
best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos)
elif self.stratge == "fit_memory":
best_region = self._select_fit_memory_chunk_region(
possible_chunk_regions,
chunk_infos,
peak_node,
max_chunk_region,
mem_peak,
)
best_region = self._select_fit_memory_chunk_region(possible_chunk_regions, chunk_infos, mem_peak)
else:
raise RuntimeError()
return best_region
def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
mem_peak):
def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak):
# stop chunk if max memory satisfy memory limit
if max(mem_peak) < self.max_memory:
return None
@ -63,17 +50,14 @@ class SelectChunk(object):
if len(possible_chunk_regions) == 0:
return None
max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
max([i["region"][1] for i in possible_chunk_regions]))
# get mem for chunk region
regions_dict = []
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
if cur_chunk_region_max_peak < self.max_memory:
regions_dict.append({
@ -141,8 +125,7 @@ class SelectChunk(object):
count += 1
return count
def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
mem_peak):
def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
# remove illegal regions
illegal_regions = []
for i in possible_chunk_regions:

View File

@ -33,7 +33,6 @@ class TraceIndice(object):
self.indice_trace_list = self._init_indice_trace_list()
self.indice_view_list = {}
self.indice_count = -1
self.trace_range = []
self.active_node_list = []
def _init_indice_trace_list(self) -> List:
@ -50,8 +49,7 @@ class TraceIndice(object):
indice_trace_list.append(cur_trace)
return indice_trace_list
def set_trace_range(self, trace_range: List, active_node_list: List) -> None:
self.trace_range = trace_range
def set_active_nodes(self, active_node_list: List) -> None:
self.active_node_list = active_node_list
def _add_indice(self) -> int:
@ -731,23 +729,35 @@ class TraceIndice(object):
dim_from.reverse()
# search view list
for view_node, view_dict in self.indice_view_list.items():
if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
and view_dict["dim_from"] == dim_to):
# inheirt indice from current node
if len_diff == 1:
if origin_shape[dim_from[0]] == 1:
self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
elif origin_shape[dim_from[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
elif len_diff == -1:
if target_shape[dim_to[0]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
elif target_shape[dim_to[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
# inherid indice from input node of last view
for dim_to_i in dim_to:
self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
# for view_node, view_dict in self.indice_view_list.items():
# if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
# and view_dict["dim_from"] == dim_to):
# # inheirt indice from current node
# if len_diff == 1:
# if origin_shape[dim_from[0]] == 1:
# self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
# elif origin_shape[dim_from[1]] == 1:
# self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
# elif len_diff == -1:
# if target_shape[dim_to[0]] == 1:
# self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
# elif target_shape[dim_to[1]] == 1:
# self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
# # inherid indice from input node of last view
# for dim_to_i in dim_to:
# self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
# inherit indice from current node
if len_diff == 1:
if origin_shape[dim_from[0]] == 1:
self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
elif origin_shape[dim_from[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
elif len_diff == -1:
if target_shape[dim_to[0]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
elif target_shape[dim_to[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
# log view, not used now
view_dict = {
@ -762,32 +772,22 @@ class TraceIndice(object):
"""
clear too far trace to speed up computation
"""
trace_range = None
for i in range(len(self.trace_range)):
if self.trace_range[i][1] == node_idx:
trace_range = (self.trace_range[i][0], self.trace_range[i][1])
break
if self.trace_range[i][1] > node_idx:
break
if trace_range is None:
return
trace_barrier = max(node_idx - 100, 0)
active_nodes = self.active_node_list[trace_barrier]
active_nodes = [self.node_mgr.find_node_idx(i) for i in active_nodes.keys()]
active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
active_nodes = set(flat_list(active_nodes))
active_nodes = [self.node_mgr.find_node_idx_by_name(i) for i in active_nodes]
for i in range(trace_range[0], trace_range[1] + 1):
trace = self.indice_trace_list[i]
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
if (dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes):
dim_compute.pop(i)
continue
# clear source
for dim_source in trace["source"]:
for k in list(dim_source.keys()):
if k < trace_range[0] and k not in active_nodes:
dim_source.pop(k)
trace = self.indice_trace_list[node_idx]
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes):
dim_compute.pop(i)
continue
# clear source
for dim_source in trace["source"]:
for k in list(dim_source.keys()):
if k < trace_barrier and k not in active_nodes:
dim_source.pop(k)
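
With trace ranges gone, _clear_trace prunes against a sliding barrier instead: anything in the current node's compute or source trace that refers to a node more than 100 positions back is dropped, unless that node is still active at the barrier. A rough sketch of the pruning rule for a single trace entry (find_idx stands in for NodeMgr.find_node_idx):

    def prune_trace(trace, node_idx, active_at_barrier, find_idx, window=100):
        barrier = max(node_idx - window, 0)
        keep = {find_idx(n) for n in active_at_barrier}    # nodes still alive at the barrier
        for dim_compute in trace["compute"]:
            dim_compute[:] = [c for c in dim_compute if c >= barrier or c in keep]
        for dim_source in trace["source"]:
            for k in list(dim_source):
                if k < barrier and k not in keep:
                    dim_source.pop(k)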
def trace_indice(self) -> None:
for idx, node in enumerate(self.node_mgr.get_node_list()):

View File

@ -11,8 +11,8 @@ logger = get_dist_logger()
class NodeMgr(object):
def __init__(self, gm) -> None:
self._node_list = list(gm.graph.nodes)
def __init__(self, nodes_list: List[Node]) -> None:
self._node_list = nodes_list
self._node_dict = {}
self._set_node_dict()
@ -76,6 +76,8 @@ def flat_list(inputs: Any) -> List:
for i in inputs:
if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
res.extend(flat_list(i))
elif isinstance(i, dict):
res.extend(flat_list(list(i.keys())))
else:
res.append(i)
return res
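
flat_list now also flattens dicts by their keys, so the new active-node dicts can flow through the existing helpers unchanged. A quick illustration (not taken from the tests):

    flat_list([1, (2, 3), {"a": 0, "b": 1}, [[4], 5]])
    # -> [1, 2, 3, "a", "b", 4, 5]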
@ -135,13 +137,6 @@ def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
return is_non_compute_node_except_placeholder(node)
def find_node_idx(name: str, nodes_list: List) -> int:
for idx, node in enumerate(nodes_list):
if node.name == name:
return idx
raise RuntimeError("name %s not found in node list" % name)
def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None:
for key, value in user_to_last_uses.items():
for n in value:

View File

@ -61,7 +61,7 @@ def _benchmark_evoformer_stack_gm(
# bench
mem = _benchmark_memory(gm, inputs)
speed = _benchmark_speed(gm, inputs)
print("evoformer stack gm, mem: %.2fMB, time: %.4fs, data_args: %s" % (mem, speed, str(data_args)))
print("evoformer stack gm, mem: %.2fMB, time: %.4fs" % (mem, speed))
def _benchmark_evoformer_stack_origin(
@ -83,14 +83,15 @@ def _benchmark_evoformer_stack_origin(
# bench
mem = _benchmark_memory(model, inputs)
speed = _benchmark_speed(model, inputs)
print("evoformer stack origin, mem: %.2fMB, time: %.4fs, data_args: %s" % (mem, speed, str(data_args)))
print("evoformer stack origin, mem: %.2fMB, time: %.4fs" % (mem, speed))
return mem
def _benchmark_memory(model, inputs):
with torch.no_grad():
torch.cuda.reset_peak_memory_stats()
now_mem = torch.cuda.memory_allocated() / 1024**2
model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
model(*inputs)
new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
return new_max_mem - now_mem
@ -108,13 +109,18 @@ def _benchmark_speed(model, inputs, loop=5):
return (time2 - time1) / loop
def benchmark_evoformer_stack():
def benchmark_evoformer_stack(data_args):
from test_autochunk_evoformer_stack import get_data, get_model
data_args = [128, 256]
print("")
_benchmark_evoformer_stack_origin(data_args, get_model, get_data)
_benchmark_evoformer_stack_gm(data_args, 600, get_model, get_data)
_benchmark_evoformer_stack_gm(data_args, 400, get_model, get_data)
print("\nmsa len: %d, pair len: %d" % (data_args[0], data_args[1]))
max_mem = _benchmark_evoformer_stack_origin(data_args, get_model, get_data)
for ratio in [0.5, 0.4, 0.3, 0.2, 0.1]:
try:
_benchmark_evoformer_stack_gm(data_args, max_mem * ratio, get_model, get_data)
except RuntimeError as e:
if e.args[0] == 'Search failed. Try a larger memory threshold.':
break
except Exception as e:
raise e
_benchmark_evoformer_stack_gm(data_args, None, get_model, get_data)
@ -128,4 +134,7 @@ if __name__ == "__main__":
port=free_port(),
backend="nccl",
)
benchmark_evoformer_stack()
benchmark_evoformer_stack((256, 256))
benchmark_evoformer_stack((256, 512))
benchmark_evoformer_stack((256, 1024))
benchmark_evoformer_stack((256, 1280))

View File

@ -55,10 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
def get_chunk_target() -> Dict:
return {
None: [(120, 123), (222, 237), (269, 289), (305, 311), (100, 105), (146, 152), (187, 193), (241, 242),
(25, 50)],
20: [(120, 123), (232, 237), (277, 282), (305, 306), (100, 101), (34, 39)],
24: [(120, 123)],
None: [(120, 126), (225, 244), (270, 289), (306, 311), (70, 106), (23, 46), (146, 152), (187, 193), (181, 184),
(140, 145), (162, 163), (203, 204)],
20: [(120, 123), (232, 237), (277, 282), (305, 306)],
24: [(122, 123)],
}

View File

@ -53,15 +53,6 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
return meta_args, concrete_args
def get_chunk_target() -> Dict:
return {
None: [(128, 131), (230, 245), (277, 297), (313, 319), (108, 113), (154, 160), (195, 201), (249, 250),
(36, 46)],
20: [(128, 131), (240, 245), (285, 290), (313, 314), (108, 109), (41, 46)],
24: [(128, 131)],
}
@pytest.mark.skipif(
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
@ -75,7 +66,6 @@ def test_extramsa_block(data_args, max_memory):
max_memory=max_memory,
get_model=get_model,
get_data=get_data,
get_chunk_target=get_chunk_target,
)
mp.spawn(run_func, nprocs=1)
@ -87,7 +77,6 @@ if __name__ == "__main__":
max_memory=None,
get_model=get_model,
get_data=get_data,
get_chunk_target=get_chunk_target,
print_code=False,
print_mem=False,
print_progress=False,

View File

@ -95,7 +95,7 @@ def _benchmark_memory(model, inputs):
with torch.no_grad():
torch.cuda.reset_peak_memory_stats()
now_mem = float(torch.cuda.memory_allocated()) / 1024**2
model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
model(*inputs)
new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2
return new_max_mem - now_mem
@ -116,8 +116,7 @@ def _benchmark_speed(model, inputs, loop=5):
def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12):
from test_autochunk_gpt import GPT2Config, GPT2Model, get_data
model = GPT2Model
config = GPT2Config(n_embd=n_embd, n_position=seq, n_layer=2, n_head=n_head)
config.max_position_embeddings = seq
config = GPT2Config(n_embd=n_embd, n_positions=seq, n_layer=2, n_head=n_head)
model = model(config=config)
shape = [batch, seq]
print("\nbatch: %d, seq: %d, n_embd: %d, n_head: %d" % (batch, seq, n_embd, n_head))

View File

@ -44,20 +44,19 @@ def test_autochunk_gpt(model, shape, max_memory):
data=get_data(shape),
max_memory=max_memory,
model=model,
config=GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4),
config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4),
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
run_test(
rank=0,
data=get_data((BATCH_SIZE, SEQ_LENGTH)),
max_memory=None,
model=GPT2Model,
config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
print_code=False,
print_est_mem=False,
print_mem=False,
print_progress=False,
)
run_test(rank=0,
data=get_data((BATCH_SIZE, SEQ_LENGTH)),
max_memory=None,
model=GPT2Model,
config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
print_code=False,
print_est_mem=False,
print_mem=False,
print_progress=False,
eval_mem=False)

View File

@ -24,6 +24,7 @@ def assert_codegen_run(
print_mem: bool = False,
print_progress: bool = False,
print_code: bool = False,
eval_mem: bool = False,
) -> List[Dict]:
meta_args, concrete_args, sequence = data
if concrete_args is None:
@ -39,12 +40,11 @@ def assert_codegen_run(
meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
interp.propagate(*meta_tensors)
codegen = AutoChunkCodeGen(
meta_graph,
max_memory=max_memory,
print_mem=print_est_mem,
print_progress=print_progress,
)
codegen = AutoChunkCodeGen(meta_graph,
max_memory=max_memory,
print_mem=print_est_mem,
print_progress=print_progress,
eval_mem=eval_mem)
chunks = codegen.chunk_infos
# trace and recompile
@ -108,6 +108,7 @@ def run_test(
print_est_mem: bool = False,
print_mem: bool = False,
print_progress: bool = False,
eval_mem: bool = False,
get_chunk_target: Any = None,
) -> None:
model = model(config=config)
@ -122,15 +123,14 @@ def run_test(
)
# build model and input
chunks = assert_codegen_run(
model,
data=data,
max_memory=max_memory,
print_code=print_code,
print_est_mem=print_est_mem,
print_mem=print_mem,
print_progress=print_progress,
)
chunks = assert_codegen_run(model,
data=data,
max_memory=max_memory,
print_code=print_code,
print_est_mem=print_est_mem,
print_mem=print_mem,
print_progress=print_progress,
eval_mem=eval_mem)
if get_chunk_target is not None:
chunk_found = [i["region"] for i in chunks]