mirror of https://github.com/hpcaitech/ColossalAI
[autochunk] refactor chunk memory estimation (#2762)
* refactor memory code
* don't log free var memory
* add memory align
* update chunk target
* update setting for new memory
* finish test
* update tracer
* update typo
* update test

pull/3057/head
parent b51bfec357
commit 2ca9728cbb
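The commit threads a new eval_mem flag through the autochunk codegen (so the emitted forward code can log real per-node peak memory) and rebuilds EstimateMemory around an active-node dictionary. A minimal sketch of how the changed entry points fit together, assuming ColossalAI at this commit, import paths taken from the module names visible in the diff, and an fx meta graph traced elsewhere:

    # Sketch only: assumes colossalai.autochunk from this commit and a traced fx meta_graph.
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.autochunk.estimate_memory import EstimateMemory

    def build_codegen(meta_graph, max_memory_mb=None):
        # eval_mem=True makes the generated code print measured peak memory per emitted node
        return AutoChunkCodeGen(meta_graph, max_memory=max_memory_mb, print_mem=False,
                                print_progress=False, eval_mem=True)

    def estimated_peak_mb(node_list, chunk_infos=None):
        # EstimateMemory no longer takes a NodeMgr; it builds one from node_list internally
        peak_log, after_node_log, active_nodes_log = EstimateMemory().estimate_chunk_inference_mem(
            node_list, chunk_infos)
        return max(peak_log)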
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Iterable, List, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Tuple

 import torch

@@ -216,14 +216,13 @@ def _add_node_slice(
     return body


-def emit_code_with_chunk(
-    body: List[str],
-    nodes: Iterable[Node],
-    emit_node_func,
-    delete_unused_value_func,
-    search_chunk: SearchChunk,
-    chunk_infos: List,
-):
+def emit_code_with_chunk(body: List[str],
+                         nodes: Iterable[Node],
+                         emit_node_func: Callable,
+                         delete_unused_value_func: Callable,
+                         search_chunk: SearchChunk,
+                         chunk_infos: List,
+                         eval_mem: bool = False):
     """
     Emit code with chunk according to chunk_infos.

@@ -260,6 +259,9 @@ def emit_code_with_chunk(
     region_idx = 0
     within_chunk_region = False

+    if eval_mem:
+        body.append("init_memory = torch.cuda.memory_allocated() / 1024**2\n")
+
     while node_idx < len(node_list):
         node = node_list[node_idx]

@@ -289,10 +291,18 @@ def emit_code_with_chunk(
                 body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
             body[-1] = " " + body[-1]
             delete_unused_value_func(node, body, chunk_inputs_names)
+            if eval_mem:
+                body.append(
+                    " if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
+                    % (node.name))
         else:
             emit_node_func(node, body)
             if node_idx not in chunk_inputs:
                 delete_unused_value_func(node, body, chunk_inputs_names)
+            if eval_mem:
+                body.append(
+                    "print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
+                    % (node.name))

         # generate chunk region end
         if node_idx in chunk_ends:

@@ -312,8 +322,10 @@ if AUTOCHUNK_AVAILABLE:
                  meta_graph,
                  max_memory: int = None,
                  print_mem: bool = False,
-                 print_progress: bool = False) -> None:
+                 print_progress: bool = False,
+                 eval_mem: bool = False) -> None:
         super().__init__()
+        self.eval_mem = eval_mem
         # find the chunk regions
         self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress)
         self.chunk_infos = self.search_chunk.search_region()

@@ -511,14 +523,8 @@ if AUTOCHUNK_AVAILABLE:

        # if any node has a list of labels for activation_checkpoint, we
        # will use nested type of activation checkpoint codegen
-       emit_code_with_chunk(
-           body,
-           nodes,
-           emit_node,
-           delete_unused_values,
-           self.search_chunk,
-           self.chunk_infos,
-       )
+       emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos,
+                            self.eval_mem)

        if len(body) == 0:
            # If the Graph has no non-placeholder nodes, no lines for the body
@@ -2,11 +2,11 @@ import copy
 from typing import Any, Callable, Dict, Iterable, List, Tuple

 import torch
-from torch.fx.node import Node, map_arg
+from torch.fx.node import Node

 from colossalai.fx.profiler import activation_size, parameter_size

-from .utils import NodeMgr, delete_free_var_from_last_use, get_node_shape, is_non_memory_node
+from .utils import NodeMgr, get_node_shape, is_non_memory_node


 class EstimateMemory(object):

@@ -14,102 +14,85 @@ class EstimateMemory(object):
    Estimate memory with chunk
    """

-   def __init__(self, node_mgr: NodeMgr) -> None:
-       self.node_mgr = node_mgr
+   def __init__(self) -> None:
+       pass

-   def _get_meta_node_size(self, x):
+   def _get_node_size(self, x: Node) -> float:
+       """
+       return node size in MB
+       """
        x = x.meta["tensor_meta"]
-       x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
-       return x
+       if not hasattr(x, "numel"):
+           out = sum([i.numel * torch.tensor([], dtype=i.dtype).element_size() for i in x])
+       else:
+           out = x.numel * torch.tensor([], dtype=x.dtype).element_size()
+       out = float(out) / 1024**2
+       return out

-   def _get_output_node(self, n):
-       out_size = activation_size(n.meta["fwd_out"])
-       out_node = [n.name] if out_size > 0 else []
-       return out_size, out_node
+   def _add_active_node(self, n: Node, active_nodes: Dict, chunk_ratio: float) -> None:
+       """
+       add an active node and its shape to active node dict
+       """
+       if get_node_shape(n) is None:
+           return
+       if n.op == "placeholder":
+           return
+       if n not in active_nodes:
+           node_size = self._get_node_size(n) * chunk_ratio
+           active_nodes[n] = node_size

-   def _get_output_node_size(self, n):
-       return self._get_output_node(n)[0]
+   def _build_delete_node_dict(self, node_mgr: NodeMgr) -> Dict:
+       """
+       build delete node dict, means node should be deleted at what time
+       """
+       delete_node_dict = {}
+       for idx, node in enumerate(node_mgr.get_node_list()):
+           # skip non shape node
+           if get_node_shape(node) is None:
+               continue
+           # dont remove free nodes
+           elif node.op == "placeholder":
+               delete_node_dict[node] = len(node_mgr.get_node_list())
+           # node no user
+           elif len(node.users) == 0:
+               delete_node_dict[node] = idx
+           # log max use
+           else:
+               node_user_idx = [node_mgr.find_node_idx(i) for i in node.users.keys()]
+               delete_node_dict[node] = max(node_user_idx)
+       return delete_node_dict

-   def _add_active_node(self, n, active_list):
-       new_active = self._get_output_node(n)[1]
-       if n.op == "placeholder" and get_node_shape(n) is not None:
-           new_active.append(n.name)
-       for i in new_active:
-           if i not in active_list and get_node_shape(n) is not None:
-               active_list.append(i)
+   def _remove_deactive_node(self,
+                             user_idx: int,
+                             user: Node,
+                             active_nodes: List,
+                             delete_node_dict: List,
+                             kept_nodes: List = None) -> None:
+       """
+       remove deactivate nodes from active nodes
+       """
+       if kept_nodes is None:
+           kept_nodes = []
+       if user.op in ("output",):
+           return

-   def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
-       delete_size = 0
-       delete_node = []
-       if user.op not in ("output",):
-           nodes_to_delete = user_to_last_uses.get(user, [])
-           if len(user.users) == 0:
-               nodes_to_delete.append(user)
-           if to_keep is not None:
-               keep_list = []
-               for n in nodes_to_delete:
-                   if n.name in to_keep:
-                       keep_list.append(n)
-               for n in keep_list:
-                   if n in nodes_to_delete:
-                       nodes_to_delete.remove(n)
-           if len(nodes_to_delete):
-               out_node = [self._get_output_node(i) for i in nodes_to_delete]
-               delete_size = sum([i[0] for i in out_node])
-               for i in range(len(out_node)):
-                   if out_node[i][0] > 0:
-                       delete_node.append(out_node[i][1][0])
-                   elif nodes_to_delete[i].op == "placeholder":
-                       delete_node.append(nodes_to_delete[i].name)
-                   # elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']):
-                   #     delete_node.append(nodes_to_delete[i].name)
-       return delete_size, delete_node
+       for node in list(active_nodes.keys()):
+           # dont delete kept nodes
+           if node in kept_nodes:
+               continue
+           # should be deleted
+           if delete_node_dict[node] <= user_idx:
+               active_nodes.pop(node)

-   def _get_delete_node_size(self, user, user_to_last_uses, to_keep):
-       return self._get_delete_node(user, user_to_last_uses, to_keep)[0]
-
-   def _remove_deactive_node(self, user, user_to_last_uses, active_list):
-       delete_node = self._get_delete_node(user, user_to_last_uses)[1]
-       for i in delete_node:
-           if i in active_list:
-               active_list.remove(i)
-
-   def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
-       nodes_to_delete = []
-       for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
-           chunk_input_users = chunk_input.users.keys()
-           chunk_input_users_idx = [self.node_mgr.find_node_idx(i) for i in chunk_input_users]
-           if all(i <= chunk_end_idx for i in chunk_input_users_idx):
-               if chunk_input not in nodes_to_delete:
-                   nodes_to_delete.append(chunk_input)
-       out_node = [self._get_output_node(i) for i in nodes_to_delete]
-       delete_size = sum([i[0] for i in out_node])
-       return delete_size
-
-   def _get_last_usr(self, nodes):
-       node_to_last_use: Dict[Node, Node] = {}
-       user_to_last_uses: Dict[Node, List[Node]] = {}
-
-       def register_last_uses(n: Node, user: Node):
-           if n not in node_to_last_use:
-               node_to_last_use[n] = user
-               user_to_last_uses.setdefault(user, []).append(n)
-
-       for node in reversed(nodes):
-           map_arg(node.args, lambda n: register_last_uses(n, node))
-           map_arg(node.kwargs, lambda n: register_last_uses(n, node))
-       return user_to_last_uses
-
-   def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
+   def _get_tmp_memory(self, node, not_contiguous_list, delete=False):
        mem = 0
        not_contiguous_ops = ["permute"]
        inherit_contiguous_ops = ["transpose", "view"]

        if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]):
            for n in node.args:
                if n in not_contiguous_list:
                    # matmul won't change origin tensor, but create a tmp copy
-                   mem += self._get_output_node_size(n)
+                   mem += self._get_node_size(n)
        elif node.op == "call_module":
            for n in node.args:
                if n in not_contiguous_list:

@@ -129,31 +112,7 @@ class EstimateMemory(object):
        if chunk_dim is None:
            return 1.0
        else:
-           return float(chunk_size) / node_shape[chunk_dim]
-
-   def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names):
-       # if any(j in user.name for j in ['transpose', 'permute', 'view']):
-       #     return 0
-       if user.op in ("placeholder", "output"):
-           return 0
-       nodes_to_delete = user_to_last_uses.get(user, [])
-       if len(user.users) == 0:
-           nodes_to_delete.append(user)
-       delete_size = 0
-       for n in nodes_to_delete:
-           if n.name in chunk_inputs_names:
-               continue
-           delete_size += self._get_output_node_size(n) * chunk_ratio
-       return delete_size
-
-   def _print_mem_log(self, log, nodes, title=None):
-       if title:
-           print(title)
-       for idx, (l, n) in enumerate(zip(log, nodes)):
-           print("%s:%.2f \t" % (n.name, l), end="")
-           if (idx + 1) % 3 == 0:
-               print("")
-       print("\n")
+           return chunk_size / float(node_shape[chunk_dim])

    def _print_compute_op_mem_log(self, log, nodes, title=None):
        if title:

@@ -168,12 +127,22 @@ class EstimateMemory(object):
                print("")
        print("\n")

-   def estimate_chunk_inference_mem(
-       self,
-       node_list: List,
-       chunk_infos=None,
-       print_mem=False,
-   ):
+   def _add_active_nodes_from_list(self, active_nodes: List, nodes: List) -> List:
+       """
+       add active nodes from nodes
+       """
+       for n in nodes:
+           self._add_active_node(n, active_nodes, 1)
+
+   def _get_memory_from_active_nodes(self, active_nodes: Dict) -> float:
+       """
+       sum all memory of active nodes
+       """
+       out = [i for i in active_nodes.values()]
+       out = sum(out)
+       return out
+
+   def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None, print_mem: bool = False):
        """
        Estimate inference memory with chunk

@@ -191,18 +160,17 @@ class EstimateMemory(object):
        act_memory = 0.0
        act_memory_peak_log = []
        act_memory_after_node_log = []
-       active_node_list = []
-       active_node_list_log = []
+       active_nodes = {}
+       active_nodes_log = []
        not_contiguous_list = []
-       user_to_last_uses = self._get_last_usr(node_list)
-       user_to_last_uses_no_free_var = self._get_last_usr(node_list)
-       delete_free_var_from_last_use(user_to_last_uses_no_free_var)
+       node_mgr = NodeMgr(node_list)
+       delete_node_dict = self._build_delete_node_dict(node_mgr)

        use_chunk = True if chunk_infos is not None else False
        chunk_within = False
        chunk_region_idx = None
        chunk_ratio = 1    # use it to estimate chunk mem
        chunk_inputs_names = []
+       chunk_inputs_all = []

        if use_chunk:
            chunk_regions = [i["region"] for i in chunk_infos]

@@ -210,30 +178,30 @@ class EstimateMemory(object):
            chunk_ends = [i[1] for i in chunk_regions]
            chunk_inputs = [i["inputs"] for i in chunk_infos]
            chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
            chunk_inputs_names = [j.name for i in chunk_inputs for j in i
                                  ] + [j.name for i in chunk_inputs_non_chunk for j in i]
+           chunk_inputs_all = [j for i in chunk_inputs for j in i] + [j for i in chunk_inputs_non_chunk for j in i]
            chunk_outputs = [i["outputs"] for i in chunk_infos]
            chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
            chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]

-       for idx, node in enumerate(node_list):
+       for idx, node in enumerate(node_mgr.get_node_list()):

            # if node in chunk start nodes, change chunk ratio and add chunk_tensor
            if use_chunk and idx in chunk_starts:
                chunk_within = True
                chunk_region_idx = chunk_starts.index(idx)
-               act_memory += sum(self._get_output_node_size(i) for i in chunk_outputs[chunk_region_idx]) / (1024**2)
+               self._add_active_nodes_from_list(active_nodes, chunk_outputs[chunk_region_idx])

            # determine chunk ratio for current node
            if chunk_within:
-               chunk_ratio = self._get_chunk_ratio(
-                   node,
-                   chunk_node_dim[chunk_region_idx],
-                   chunk_sizes[chunk_region_idx],
-               )
+               chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx],
+                                                   chunk_sizes[chunk_region_idx])

+           # add current node as active node
+           self._add_active_node(node, active_nodes, chunk_ratio)
+           act_memory = self._get_memory_from_active_nodes(active_nodes)

            # if node is placeholder, just add the size of the node
            if node.op == "placeholder":
-               act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2)
                act_memory_peak_log.append(act_memory)
            # skip output
            elif node.op == "output":

@@ -241,83 +209,32 @@ class EstimateMemory(object):
            # no change for non compute node
            elif is_non_memory_node(node):
                act_memory_peak_log.append(act_memory)
-           # node is a compute op
-           # calculate tmp, output node and delete node memory
+           # node is a compute op, calculate tmp
            else:
-               # forward memory
-               # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
-               act_memory += (self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2))
-               act_memory += (self._get_output_node_size(node) * chunk_ratio / (1024**2))
+               tmp_memory = self._get_tmp_memory(node, not_contiguous_list, delete=True) * chunk_ratio
                # record max act memory
-               act_memory_peak_log.append(act_memory)
-               # delete useless memory
-               act_memory -= (self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio /
-                              (1024**2))
-               # delete unused vars not in chunk_input_list
-               # we can't delete input nodes until chunk ends
-               if chunk_within:
-                   act_memory -= self._get_chunk_delete_node_size(
-                       node,
-                       user_to_last_uses_no_free_var,
-                       chunk_ratio,
-                       chunk_inputs_names,
-                   ) / (1024**2)
-               else:
-                   act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var,
-                                                            chunk_inputs_names) / (1024**2)
+               act_memory_peak_log.append(act_memory + tmp_memory)

-           # log active node, only effective without chunk
-           self._add_active_node(node, active_node_list)
-           self._remove_deactive_node(node, user_to_last_uses, active_node_list)
+           # remove_deactive_node
+           self._remove_deactive_node(idx, node, active_nodes, delete_node_dict, kept_nodes=chunk_inputs_all)

            # if node in chunk end nodes, restore chunk settings
            if use_chunk and idx in chunk_ends:
-               act_memory -= (self._get_output_node_size(node) * chunk_ratio / (1024**2))
-               act_memory -= self._get_chunk_inputs_size(
-                   chunk_inputs[chunk_region_idx],
-                   chunk_inputs_non_chunk[chunk_region_idx],
-                   node_list,
-                   chunk_regions[chunk_region_idx][1],
-               ) / (1024**2)
+               self._remove_deactive_node(idx, node, active_nodes, delete_node_dict)    # dont provide kept nodes now
                chunk_within = False
                chunk_ratio = 1
                chunk_region_idx = None

+           act_memory = self._get_memory_from_active_nodes(active_nodes)
            act_memory_after_node_log.append(act_memory)
-           active_node_list_log.append(copy.deepcopy(active_node_list))
+           active_nodes_log.append(active_nodes.copy())

        if print_mem:
            print("with chunk" if use_chunk else "without chunk")
-           # self._print_mem_log(act_memory_peak_log, node_list, "peak")
-           # self._print_mem_log(act_memory_after_node_log, node_list, "after")
-           self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
-           # self._print_compute_op_mem_log(
-           #     act_memory_after_node_log, node_list, "after"
-           # )
+           self._print_compute_op_mem_log(act_memory_peak_log, node_mgr.get_node_list(), "peak")

        # param_memory = parameter_size(gm)
        # all_memory = act_memory + param_memory
-       return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
-
-   def get_active_nodes(self, node_list: List) -> List:
-       """
-       Get active nodes for every node
-
-       Args:
-           node_list (List): _description_
-
-       Returns:
-           active_node_list_log (List): active nodes of every node. active nodes refer to
-               nodes generated but not deleted.
-       """
-       active_node_list = []
-       active_node_list_log = []
-       user_to_last_uses = self._get_last_usr(node_list)
-       user_to_last_uses_no_free_var = self._get_last_usr(node_list)
-       delete_free_var_from_last_use(user_to_last_uses_no_free_var)
-       for _, node in enumerate(node_list):
-           # log active node, only effective without chunk
-           self._add_active_node(node, active_node_list)
-           self._remove_deactive_node(node, user_to_last_uses, active_node_list)
-           active_node_list_log.append(copy.deepcopy(active_node_list))
-       return active_node_list_log
+       return act_memory_peak_log, act_memory_after_node_log, active_nodes_log
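The rewritten estimator above replaces the old last-use bookkeeping with two structures: a dict mapping each currently active node to its size in MB, and a precomputed map from each node to the index of its last user, after which it may be dropped. A toy illustration of that accounting, using plain strings and made-up sizes in place of fx nodes:

    # Toy model of the new accounting: node -> size (MB), node -> index of its last user.
    sizes = {"a": 4.0, "b": 2.0, "c": 8.0}    # hypothetical node sizes
    last_user = {"a": 1, "b": 2, "c": 2}      # node may be freed once this index is passed
    order = ["a", "b", "c"]

    active, peak = {}, []
    for idx, name in enumerate(order):
        active[name] = sizes[name]            # _add_active_node
        peak.append(sum(active.values()))     # act_memory from active nodes
        for n in list(active):                # _remove_deactive_node
            if last_user[n] <= idx:
                active.pop(n)
    print(peak)    # [4.0, 6.0, 10.0]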
@@ -42,10 +42,11 @@ class SearchChunk(object):

    def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
        self.print_mem = print_mem
        self.max_memory = max_memory
        self.print_progress = print_progress
-       self.node_mgr = NodeMgr(gm)
+       self.node_mgr = NodeMgr(list(gm.graph.nodes))
        self.trace_indice = TraceIndice(self.node_mgr)
-       self.estimate_memory = EstimateMemory(self.node_mgr)
+       self.estimate_memory = EstimateMemory()
        self._init_trace()
        self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr)
        self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr)

@@ -63,45 +64,46 @@ class SearchChunk(object):
        reduce the computation complexity of trace_indice
        """
-       # find all max ranges
-       active_nodes = self.estimate_memory.get_active_nodes(self.node_mgr.get_node_list())
-       cur_node_idx = len(self._get_free_var_idx())
-       max_chunk_region_list = []
-       while True:
-           max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
-           cur_node_idx = max_chunk_region[1] + 1
-           if cur_node_idx >= len(active_nodes) - 1:
-               break
-           max_chunk_region_list.append(max_chunk_region)
-
-       # nothing to limit for the first range
-       max_chunk_region_list = max_chunk_region_list[1:]
-       max_chunk_region_list[0] = (0, max_chunk_region_list[0][1])
-
+       active_nodes = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())[2]
        # set trace range and do the trace
        if self.print_progress:
            get_logger().info("AutoChunk start tracing indice")
-       self.trace_indice.set_trace_range(max_chunk_region_list, active_nodes)
+       self.trace_indice.set_active_nodes(active_nodes)
        self.trace_indice.trace_indice()

-   def _find_peak_node(self, mem_peak: List) -> int:
+   def _find_peak_region(self, mem_peak: List) -> int:
        """
        find peak node, along with its neighbour nodes exceeds max mem
        """
        max_value = max(mem_peak)
        max_idx = mem_peak.index(max_value)
-       return max_idx
+       peak_region = [max_idx, max_idx]
+       if self.max_memory is None:
+           return peak_region

-   def _get_free_var_idx(self) -> List:
-       """
-       Get free var index
+       # to left
+       count = 0
+       for i in range(max_idx - 1, -1, -1):
+           if mem_peak[i] > self.max_memory:
+               peak_region[0] = i
+           else:
+               count += 1
+           if count >= 3:
+               break
+       # to right
+       count = 0
+       for i in range(max_idx + 1, len(mem_peak) - 1):
+           if mem_peak[i] > self.max_memory:
+               peak_region[1] = i
+               count = 0
+           else:
+               count += 1
+           if count >= 3:
+               break

-       Returns:
-           free_var_idx (List): all indexs of free vars
-       """
-       free_var_idx = []
-       for idx, n in enumerate(self.node_mgr.get_node_list()):
-           if n.op == "placeholder" and get_node_shape(n) is not None:
-               free_var_idx.append(idx)
-       return free_var_idx
+       return peak_region

-   def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_regions: List = None) -> Tuple:
+   def _search_max_chunk_region(self, active_node: List, peak_region: int, chunk_regions: List = None) -> Tuple:
        """
        Search max chunk region according to peak memory node

@@ -119,50 +121,24 @@ class SearchChunk(object):
        # check if peak node already in chunkinfo
        if chunk_regions is not None:
            for i in chunk_regions:
-               if i["region"][0] < peak_node_idx <= i["region"][1]:
+               if i["region"][0] < peak_region[0] <= i["region"][1] or \
+                       i["region"][0] < peak_region[1] <= i["region"][1]:
                    return None

-       free_vars = self._get_free_var_idx()
-       free_var_num = len(free_vars)
        active_node_num = [len(i) for i in active_node]
-       min_active_node_num = min(active_node_num[free_var_num:])
-       threshold = max(free_var_num, min_active_node_num)
-
-       # normal search
-       # from peak_node to free_var
-       inside_flag = False
-       chunk_region_start = free_var_num
-       for i in range(peak_node_idx, -1, -1):
-           if active_node_num[i] <= threshold:
-               inside_flag = True
-           if inside_flag and active_node_num[i] > threshold:
-               chunk_region_start = i + 1
-               break
-       # from peak_node to len-2
-       inside_flag = False
-       chunk_region_end = len(active_node) - 1
-       for i in range(peak_node_idx, len(active_node)):
-           if active_node_num[i] <= threshold:
-               inside_flag = True
-           if inside_flag and active_node_num[i] > threshold:
-               chunk_region_end = i
-               break
-
-       # if normal search fails, use approximate search
-       if (chunk_region_end - chunk_region_start) > 250:
-           window_size = 100
-           # search min for start
-           min_num = 1e3
-           for i in range(max(peak_node_idx - window_size, 0), peak_node_idx + 1):
-               if active_node_num[i] < min_num:
-                   min_num = active_node_num[i]
-                   chunk_region_start = i
-           # search min for end
-           min_num = 1e3
-           for i in range(min(peak_node_idx + window_size, len(active_node_num) - 1), peak_node_idx - 1, -1):
-               if active_node_num[i] < min_num:
-                   min_num = active_node_num[i]
-                   chunk_region_end = i
+       window_size = 100
+       # search min for start
+       min_num = 1e4
+       for i in range(peak_region[0], max(peak_region[0] - window_size, -1), -1):
+           if active_node_num[i] < min_num:
+               min_num = active_node_num[i]
+               chunk_region_start = i
+       # search min for end
+       min_num = 1e4
+       for i in range(peak_region[1], min(peak_region[1] + window_size, len(active_node_num))):
+           if active_node_num[i] < min_num:
+               min_num = active_node_num[i]
+               chunk_region_end = i

        # avoid chunk regions overlap
        if chunk_regions is not None:

@@ -214,7 +190,7 @@ class SearchChunk(object):
        chunk_infos.append(chunk_info)
        return chunk_infos

-   def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
+   def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: Node) -> List:
        """
        Search every possible region within the max chunk region.

@@ -235,8 +211,8 @@ class SearchChunk(object):
                cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
            input_trace.append(cur_trace)

-       for start_idx in range(max_chunk_region[0], peak_node + 1):
-           for end_idx in range(peak_node, max_chunk_region[1] + 1):
+       for start_idx in range(max_chunk_region[0], peak_region[0] + 1):
+           for end_idx in range(peak_region[1], max_chunk_region[1] + 1):
                # skip non compute nodes
                if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
                        self.node_mgr.get_node_by_idx(end_idx)):

@@ -270,13 +246,12 @@ class SearchChunk(object):
        Returns:
            best_chunk_region (Dict)
        """
-       peak_node = self._find_peak_node(mem_peak)
-       max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
+       peak_region = self._find_peak_region(mem_peak)
+       max_chunk_region = self._search_max_chunk_region(active_node, peak_region, chunk_infos)
        if max_chunk_region == None:
            return None
-       possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
-       best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node,
-                                                                       max_chunk_region, mem_peak)
+       possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_region)
+       best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, mem_peak)
        best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
        return best_chunk_region
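With this change the search no longer pivots on a single peak node: _find_peak_region widens a window around the argmax of the estimated peaks, extending left and right while entries still exceed max_memory and giving up after three consecutive entries below the limit. A standalone restatement of that widening rule, for illustration only (the input values are made up):

    def find_peak_region(mem_peak, max_memory, patience=3):
        top = mem_peak.index(max(mem_peak))
        region = [top, top]
        count = 0
        for i in range(top - 1, -1, -1):              # extend to the left
            if mem_peak[i] > max_memory:
                region[0] = i
            else:
                count += 1
            if count >= patience:
                break
        count = 0
        for i in range(top + 1, len(mem_peak) - 1):   # extend to the right
            if mem_peak[i] > max_memory:
                region[1] = i
                count = 0
            else:
                count += 1
            if count >= patience:
                break
        return region

    print(find_peak_region([3, 5, 12, 15, 11, 4, 3, 2, 1], max_memory=10))    # [2, 4]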
@@ -24,29 +24,16 @@ class SelectChunk(object):
        else:
            self.stratge = "min_memory"

-   def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak):
+   def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak):
        if self.stratge == "min_memory":
-           best_region = self._select_min_memory_chunk_region(
-               possible_chunk_regions,
-               chunk_infos,
-               peak_node,
-               max_chunk_region,
-               mem_peak,
-           )
+           best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos)
        elif self.stratge == "fit_memory":
-           best_region = self._select_fit_memory_chunk_region(
-               possible_chunk_regions,
-               chunk_infos,
-               peak_node,
-               max_chunk_region,
-               mem_peak,
-           )
+           best_region = self._select_fit_memory_chunk_region(possible_chunk_regions, chunk_infos, mem_peak)
        else:
            raise RuntimeError()
        return best_region

-   def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
-                                       mem_peak):
+   def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak):
        # stop chunk if max memory satisfy memory limit
        if max(mem_peak) < self.max_memory:
            return None

@@ -63,17 +50,14 @@
        if len(possible_chunk_regions) == 0:
            return None

-       max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
-                                    max([i["region"][1] for i in possible_chunk_regions]))
-
        # get mem for chunk region
        regions_dict = []
        for region in possible_chunk_regions:
            cur_region = region.copy()
            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
            cur_chunk_infos = chunk_infos + [cur_region]
-           cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
-           cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
+           cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
+           cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1]
            cur_chunk_region_max_peak = max(cur_chunk_region_peak)
            if cur_chunk_region_max_peak < self.max_memory:
                regions_dict.append({

@@ -141,8 +125,7 @@ class SelectChunk(object):
                count += 1
        return count

-   def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
-                                       mem_peak):
+   def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
        # remove illegal regions
        illegal_regions = []
        for i in possible_chunk_regions:
@@ -33,7 +33,6 @@ class TraceIndice(object):
        self.indice_trace_list = self._init_indice_trace_list()
        self.indice_view_list = {}
        self.indice_count = -1
-       self.trace_range = []
        self.active_node_list = []

    def _init_indice_trace_list(self) -> List:

@@ -50,8 +49,7 @@ class TraceIndice(object):
            indice_trace_list.append(cur_trace)
        return indice_trace_list

-   def set_trace_range(self, trace_range: List, active_node_list: List) -> None:
-       self.trace_range = trace_range
+   def set_active_nodes(self, active_node_list: List) -> None:
        self.active_node_list = active_node_list

    def _add_indice(self) -> int:

@@ -731,23 +729,35 @@ class TraceIndice(object):
        dim_from.reverse()

-       # search view list
-       for view_node, view_dict in self.indice_view_list.items():
-           if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
-                   and view_dict["dim_from"] == dim_to):
-               # inheirt indice from current node
-               if len_diff == 1:
-                   if origin_shape[dim_from[0]] == 1:
-                       self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
-                   elif origin_shape[dim_from[1]] == 1:
-                       self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
-               elif len_diff == -1:
-                   if target_shape[dim_to[0]] == 1:
-                       self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
-                   elif target_shape[dim_to[1]] == 1:
-                       self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
-               # inherid indice from input node of last view
-               for dim_to_i in dim_to:
-                   self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
+       # for view_node, view_dict in self.indice_view_list.items():
+       #     if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
+       #             and view_dict["dim_from"] == dim_to):
+       #         # inheirt indice from current node
+       #         if len_diff == 1:
+       #             if origin_shape[dim_from[0]] == 1:
+       #                 self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
+       #             elif origin_shape[dim_from[1]] == 1:
+       #                 self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
+       #         elif len_diff == -1:
+       #             if target_shape[dim_to[0]] == 1:
+       #                 self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
+       #             elif target_shape[dim_to[1]] == 1:
+       #                 self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
+       #         # inherid indice from input node of last view
+       #         for dim_to_i in dim_to:
+       #             self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
+
+       # inheirt indice from current node
+       if len_diff == 1:
+           if origin_shape[dim_from[0]] == 1:
+               self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
+           elif origin_shape[dim_from[1]] == 1:
+               self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
+       elif len_diff == -1:
+           if target_shape[dim_to[0]] == 1:
+               self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
+           elif target_shape[dim_to[1]] == 1:
+               self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)

        # log view, not used now
        view_dict = {

@@ -762,32 +772,22 @@ class TraceIndice(object):
        """
        clear too far trace to speed up computation
        """
-       trace_range = None
-       for i in range(len(self.trace_range)):
-           if self.trace_range[i][1] == node_idx:
-               trace_range = (self.trace_range[i][0], self.trace_range[i][1])
-               break
-           if self.trace_range[i][1] > node_idx:
-               break
-       if trace_range is None:
-           return
-
-       active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
-       active_nodes = set(flat_list(active_nodes))
-       active_nodes = [self.node_mgr.find_node_idx_by_name(i) for i in active_nodes]
-       for i in range(trace_range[0], trace_range[1] + 1):
-           trace = self.indice_trace_list[i]
-           # clear compute
-           for dim_compute in trace["compute"]:
-               for i in range(len(dim_compute) - 1, -1, -1):
-                   if (dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes):
-                       dim_compute.pop(i)
-                       continue
-           # clear source
-           for dim_source in trace["source"]:
-               for k in list(dim_source.keys()):
-                   if k < trace_range[0] and k not in active_nodes:
-                       dim_source.pop(k)
+       trace_barrier = max(node_idx - 100, 0)
+       active_nodes = self.active_node_list[trace_barrier]
+       active_nodes = [self.node_mgr.find_node_idx(i) for i in active_nodes.keys()]
+
+       trace = self.indice_trace_list[node_idx]
+       # clear compute
+       for dim_compute in trace["compute"]:
+           for i in range(len(dim_compute) - 1, -1, -1):
+               if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes):
+                   dim_compute.pop(i)
+                   continue
+       # clear source
+       for dim_source in trace["source"]:
+           for k in list(dim_source.keys()):
+               if k < trace_barrier and k not in active_nodes:
+                   dim_source.pop(k)

    def trace_indice(self) -> None:
        for idx, node in enumerate(self.node_mgr.get_node_list()):
@@ -11,8 +11,8 @@ logger = get_dist_logger()

class NodeMgr(object):

-   def __init__(self, gm) -> None:
-       self._node_list = list(gm.graph.nodes)
+   def __init__(self, nodes_list: List[Node]) -> None:
+       self._node_list = nodes_list
        self._node_dict = {}
        self._set_node_dict()

@@ -76,6 +76,8 @@ def flat_list(inputs: Any) -> List:
    for i in inputs:
        if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
            res.extend(flat_list(i))
+       elif isinstance(i, dict):
+           res.extend(flat_list(list(i.keys())))
        else:
            res.append(i)
    return res

@@ -135,13 +137,6 @@ def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
    return is_non_compute_node_except_placeholder(node)


-def find_node_idx(name: str, nodes_list: List) -> int:
-    for idx, node in enumerate(nodes_list):
-        if node.name == name:
-            return idx
-    raise RuntimeError("name %s not found in node list" % name)
-
-
def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None:
    for key, value in user_to_last_uses.items():
        for n in value:
@@ -61,7 +61,7 @@ def _benchmark_evoformer_stack_gm(
    # bench
    mem = _benchmark_memory(gm, inputs)
    speed = _benchmark_speed(gm, inputs)
-   print("evoformer stack gm, mem: %.2fMB, time: %.4fs, data_args: %s" % (mem, speed, str(data_args)))
+   print("evoformer stack gm, mem: %.2fMB, time: %.4fs" % (mem, speed))


def _benchmark_evoformer_stack_origin(

@@ -83,14 +83,15 @@ def _benchmark_evoformer_stack_origin(
    # bench
    mem = _benchmark_memory(model, inputs)
    speed = _benchmark_speed(model, inputs)
-   print("evoformer stack origin, mem: %.2fMB, time: %.4fs, data_args: %s" % (mem, speed, str(data_args)))
+   print("evoformer stack origin, mem: %.2fMB, time: %.4fs" % (mem, speed))
+   return mem


def _benchmark_memory(model, inputs):
    with torch.no_grad():
        torch.cuda.reset_peak_memory_stats()
        now_mem = torch.cuda.memory_allocated() / 1024**2
-       model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+       model(*inputs)
        new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
        return new_max_mem - now_mem

@@ -108,13 +109,18 @@ def _benchmark_speed(model, inputs, loop=5):
    return (time2 - time1) / loop


-def benchmark_evoformer_stack():
+def benchmark_evoformer_stack(data_args):
    from test_autochunk_evoformer_stack import get_data, get_model
-   data_args = [128, 256]
-   print("")
-   _benchmark_evoformer_stack_origin(data_args, get_model, get_data)
-   _benchmark_evoformer_stack_gm(data_args, 600, get_model, get_data)
-   _benchmark_evoformer_stack_gm(data_args, 400, get_model, get_data)
+   print("\nmsa len: %d, pair len: %d" % (data_args[0], data_args[1]))
+   max_mem = _benchmark_evoformer_stack_origin(data_args, get_model, get_data)
+   for ratio in [0.5, 0.4, 0.3, 0.2, 0.1]:
+       try:
+           _benchmark_evoformer_stack_gm(data_args, max_mem * ratio, get_model, get_data)
+       except RuntimeError as e:
+           if e.args[0] == 'Search failed. Try a larger memory threshold.':
+               break
+       except Exception as e:
+           raise e
+   _benchmark_evoformer_stack_gm(data_args, None, get_model, get_data)

@@ -128,4 +134,7 @@ if __name__ == "__main__":
        port=free_port(),
        backend="nccl",
    )
-   benchmark_evoformer_stack()
+   benchmark_evoformer_stack((256, 256))
+   benchmark_evoformer_stack((256, 512))
+   benchmark_evoformer_stack((256, 1024))
+   benchmark_evoformer_stack((256, 1280))
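The benchmark's _benchmark_memory helper measures the activation footprint by resetting CUDA's peak-memory counter, noting the memory already allocated, running one forward pass under no_grad, and reporting the rise of the peak over that baseline. A minimal standalone version of the same pattern (model and inputs are placeholders supplied by the caller; a CUDA device is required):

    import torch

    def peak_activation_mb(model, *inputs):
        # Mirrors the measurement pattern used by the benchmark above.
        with torch.no_grad():
            torch.cuda.reset_peak_memory_stats()
            baseline = torch.cuda.memory_allocated() / 1024**2
            model(*inputs)
            peak = torch.cuda.max_memory_allocated() / 1024**2
        return peak - baseline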
@@ -55,10 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:

def get_chunk_target() -> Dict:
    return {
-       None: [(120, 123), (222, 237), (269, 289), (305, 311), (100, 105), (146, 152), (187, 193), (241, 242),
-              (25, 50)],
-       20: [(120, 123), (232, 237), (277, 282), (305, 306), (100, 101), (34, 39)],
-       24: [(120, 123)],
+       None: [(120, 126), (225, 244), (270, 289), (306, 311), (70, 106), (23, 46), (146, 152), (187, 193), (181, 184),
+              (140, 145), (162, 163), (203, 204)],
+       20: [(120, 123), (232, 237), (277, 282), (305, 306)],
+       24: [(122, 123)],
    }
@@ -53,15 +53,6 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
    return meta_args, concrete_args


-def get_chunk_target() -> Dict:
-    return {
-        None: [(128, 131), (230, 245), (277, 297), (313, 319), (108, 113), (154, 160), (195, 201), (249, 250),
-               (36, 46)],
-        20: [(128, 131), (240, 245), (285, 290), (313, 314), (108, 109), (41, 46)],
-        24: [(128, 131)],
-    }
-
-
@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",

@@ -75,7 +66,6 @@ def test_extramsa_block(data_args, max_memory):
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
-       get_chunk_target=get_chunk_target,
    )
    mp.spawn(run_func, nprocs=1)

@@ -87,7 +77,6 @@ if __name__ == "__main__":
        max_memory=None,
        get_model=get_model,
        get_data=get_data,
-       get_chunk_target=get_chunk_target,
        print_code=False,
        print_mem=False,
        print_progress=False,
@@ -95,7 +95,7 @@ def _benchmark_memory(model, inputs):
    with torch.no_grad():
        torch.cuda.reset_peak_memory_stats()
        now_mem = float(torch.cuda.memory_allocated()) / 1024**2
-       model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+       model(*inputs)
        new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2
        return new_max_mem - now_mem

@@ -116,8 +116,7 @@ def _benchmark_speed(model, inputs, loop=5):
def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12):
    from test_autochunk_gpt import GPT2Config, GPT2Model, get_data
    model = GPT2Model
-   config = GPT2Config(n_embd=n_embd, n_position=seq, n_layer=2, n_head=n_head)
-   config.max_position_embeddings = seq
+   config = GPT2Config(n_embd=n_embd, n_positions=seq, n_layer=2, n_head=n_head)
    model = model(config=config)
    shape = [batch, seq]
    print("\nbatch: %d, seq: %d, n_embd: %d, n_head: %d" % (batch, seq, n_embd, n_head))
@@ -44,20 +44,19 @@ def test_autochunk_gpt(model, shape, max_memory):
        data=get_data(shape),
        max_memory=max_memory,
        model=model,
-       config=GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4),
+       config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4),
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
-   run_test(
-       rank=0,
-       data=get_data((BATCH_SIZE, SEQ_LENGTH)),
-       max_memory=None,
-       model=GPT2Model,
-       config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
-       print_code=False,
-       print_est_mem=False,
-       print_mem=False,
-       print_progress=False,
-   )
+   run_test(rank=0,
+            data=get_data((BATCH_SIZE, SEQ_LENGTH)),
+            max_memory=None,
+            model=GPT2Model,
+            config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
+            print_code=False,
+            print_est_mem=False,
+            print_mem=False,
+            print_progress=False,
+            eval_mem=False)
@@ -24,6 +24,7 @@ def assert_codegen_run(
    print_mem: bool = False,
    print_progress: bool = False,
    print_code: bool = False,
+   eval_mem: bool = False,
) -> List[Dict]:
    meta_args, concrete_args, sequence = data
    if concrete_args is None:

@@ -39,12 +40,11 @@ def assert_codegen_run(
    meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
    meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
    interp.propagate(*meta_tensors)
-   codegen = AutoChunkCodeGen(
-       meta_graph,
-       max_memory=max_memory,
-       print_mem=print_est_mem,
-       print_progress=print_progress,
-   )
+   codegen = AutoChunkCodeGen(meta_graph,
+                              max_memory=max_memory,
+                              print_mem=print_est_mem,
+                              print_progress=print_progress,
+                              eval_mem=eval_mem)
    chunks = codegen.chunk_infos

    # trace and recompile

@@ -108,6 +108,7 @@ def run_test(
    print_est_mem: bool = False,
    print_mem: bool = False,
    print_progress: bool = False,
+   eval_mem: bool = False,
    get_chunk_target: Any = None,
) -> None:
    model = model(config=config)

@@ -122,15 +123,14 @@ def run_test(
    )

    # build model and input
-   chunks = assert_codegen_run(
-       model,
-       data=data,
-       max_memory=max_memory,
-       print_code=print_code,
-       print_est_mem=print_est_mem,
-       print_mem=print_mem,
-       print_progress=print_progress,
-   )
+   chunks = assert_codegen_run(model,
+                               data=data,
+                               max_memory=max_memory,
+                               print_code=print_code,
+                               print_est_mem=print_est_mem,
+                               print_mem=print_mem,
+                               print_progress=print_progress,
+                               eval_mem=eval_mem)

    if get_chunk_target is not None:
        chunk_found = [i["region"] for i in chunks]