[autochunk] support multi outputs chunk search (#2538)

Support multi-output chunk search. Previously we only supported single-output chunk search. The new strategy is more flexible and improves performance by a large margin: for transformers, it reduces memory by 40% compared with the previous search strategy.

1. rewrite the search strategy to support multi-output chunk search
2. fix many, many bugs
3. update tests
pull/2540/head
oahzxl 2023-02-01 13:18:51 +08:00 committed by GitHub
parent f477a14f4a
commit 05671fcb42
14 changed files with 428 additions and 258 deletions
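For context, what autochunk emits is a loop that materializes a region's outputs slice by slice along one dimension, which is what bounds peak activation memory; this commit lets a single loop fill several outputs at once. A minimal standalone sketch in plain PyTorch (not the generated code itself; the attention-like body is only an assumption for illustration):

import torch

def region(q_rows, x):
    # stand-in for a traced subgraph body, applied to a slice of the chunked input
    return torch.softmax(q_rows @ x.transpose(-1, -2), dim=-1) @ x

x = torch.randn(1, 512, 64)
out = torch.empty(1, 512, 64)
chunk_size = 128
for chunk_idx in range(0, x.shape[1], chunk_size):
    # only a [1, chunk_size, 512] attention matrix is alive at any time
    out[:, chunk_idx:chunk_idx + chunk_size, :] = region(x[:, chunk_idx:chunk_idx + chunk_size, :], x)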

View File

@ -25,7 +25,7 @@ if AUTOCHUNK_AVAILABLE:
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
from .search_chunk import SearchChunk
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_logger, get_node_shape
from .utils import delete_free_var_from_last_use, get_logger, get_node_name, get_node_shape
def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
@ -51,7 +51,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) ->
return new_shape
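The body of _gen_chunk_slice_dim is elided from this hunk; judging from its call sites below, it builds the subscript string appended to a chunked node's name. A reimplementation sketch under that assumption:

def gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
    # e.g. gen_chunk_slice_dim(1, "chunk_idx", [2, 512, 64])
    #   -> "[:, chunk_idx:chunk_idx + chunk_size, :]"
    dims = [":"] * len(shape)
    dims[chunk_dim] = "%s:%s + chunk_size" % (chunk_indice_name, chunk_indice_name)
    return "[" + ", ".join(dims) + "]"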
def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2) -> str:
def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_ouput_dim: int, chunk_size=2) -> str:
"""
Generate chunk loop start
@ -70,22 +70,28 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim
context (str): generated str
"""
input_node = chunk_input[0]
out_shape = get_node_shape(chunk_output)
out_str = str(list(out_shape))
context = (
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" %
(out_str, input_node.name, input_node.name, chunk_size))
context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
context = ""
for i in range(len(chunk_output)):
shape_str = str(list(get_node_shape(chunk_output[i])))
if get_node_name(chunk_output[i]) == "split":
tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
input_node.name)
tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
tensor_str = "[" + tensor_str[:-2] + "]"
context += "%s = %s; " % (chunk_output[i].name, tensor_str)
else:
context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
input_node.name, input_node.name)
out_shape = get_node_shape(chunk_output[0])
chunk_shape = out_shape[chunk_ouput_dim[0]]
context += "chunk_size = %d\nfor chunk_idx in range(0, %d, chunk_size):\n" % (chunk_size, chunk_shape)
return context
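Concretely, for a region with one ordinary output plus a two-field split output, the string assembled here would render roughly as (names and shapes hypothetical):

add_1 = torch.empty([2, 512, 64], dtype=hidden_states.dtype, device=hidden_states.device); split_1 = [torch.empty([2, 256, 64], dtype=hidden_states.dtype, device=hidden_states.device), torch.empty([2, 256, 64], dtype=hidden_states.dtype, device=hidden_states.device)]; chunk_size = 2
for chunk_idx in range(0, 512, chunk_size):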
def _gen_loop_end(
chunk_inputs: List[Node],
chunk_non_compute_inputs: List[Node],
chunk_outputs: Node,
chunk_outputs_dim: int,
node_list: List[Node],
) -> str:
def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str:
"""
Generate chunk loop end
@ -102,22 +108,13 @@ def _gen_loop_end(
Returns:
context (str): generated str
"""
chunk_outputs_name = chunk_outputs.name
chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
context = " chunk_result%s = %s; %s = None\n" % (
chunk_slice,
chunk_outputs_name,
chunk_outputs_name,
)
context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")
context = "chunk_size = None"
# determine if it's the last use of a chunk input
for chunk_input in chunk_inputs + chunk_non_compute_inputs:
if all([find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
if all([search_chunk.node_mgr.find_node_idx(user) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
context += "; %s = None" % chunk_input.name
for chunk_output_non_tensor, chunk_output_non_tensor_val in chunk_outputs_non_tensor.items():
context += "; %s = %s" % (chunk_output_non_tensor.name, chunk_output_non_tensor_val)
context += "\n"
return context
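With the per-output copy now handled inside the loop body (see _add_node_slice below), the epilogue shrinks to cleanup: it drops chunk_size, frees chunk inputs whose last use falls inside the region, and re-emits non-tensor outputs such as recorded shapes. A rendered example (names hypothetical):

chunk_size = None; hidden_states = None; size_1 = [2, 512, 64]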
@ -158,7 +155,7 @@ def _replace_ones_like(
add chunk slice for new-tensor ops such as ones_like
"""
if "ones_like" in node.name:
meta_node = search_chunk.trace_indice.node_list[node_idx]
meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0]
@ -169,21 +166,37 @@ def _replace_ones_like(
return body
def _replace_input_node(
chunk_inputs: List[Node],
def _add_node_slice(
chunk_nodes: List[Node],
region_idx: int,
chunk_inputs_dim: Dict,
chunk_nodes_dim: Dict,
node_idx: int,
body: List[str],
node: Node,
) -> List[str]:
"""
add chunk slice for input and output nodes
"""
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
if idx == node_idx:
chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(input_node))
body[-1] = _replace_name(body[-1], input_node.name, input_node.name + chunk_slice)
for chunk_node_idx, chunk_node in enumerate(chunk_nodes[region_idx]):
# inputs node
if isinstance(chunk_nodes_dim[region_idx][chunk_node_idx], dict):
for idx, dim in chunk_nodes_dim[region_idx][chunk_node_idx].items():
if idx == node_idx:
chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(chunk_node))
body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
# outputs node
else:
if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
get_node_shape(chunk_node))
if get_node_name(chunk_node) == "split":
split_chunk_slice = ""
for i in range(len(chunk_node.meta['tensor_meta'])):
split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
split_chunk_slice = split_chunk_slice[:-2]
body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
else:
body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
return body
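The effect on an emitted body line is a pure string rewrite; for an output node chunked along dim 1 (hypothetical names, slice form as sketched earlier):

# before: add_1 = torch.add(relu_1, bias)
# after:  add_1[:, chunk_idx:chunk_idx + chunk_size, :] = torch.add(relu_1, bias)

A split output instead expands into one subscript per field, i.e. split_1[0][:, chunk_idx:chunk_idx + chunk_size, :], split_1[1][:, chunk_idx:chunk_idx + chunk_size, :].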
@ -222,7 +235,8 @@ def emit_code_with_chunk(
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
# chunk outputs
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
chunk_outputs = [i["outputs"] for i in chunk_infos]
chunk_outputs_non_tensor = [i["outputs_non_tensor"] for i in chunk_infos]
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
node_list = search_chunk.reorder_graph.reorder_node_list(node_list)
@ -248,7 +262,9 @@ def emit_code_with_chunk(
if within_chunk_region:
emit_node_func(node, body)
# replace input var with chunk var
body = _replace_input_node(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body)
body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node)
# replace output var with chunk var
body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node)
# ones like
body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
# reassign reshape size
@ -263,13 +279,8 @@ def emit_code_with_chunk(
# generate chunk region end
if node_idx in chunk_ends:
body.append(
_gen_loop_end(
chunk_inputs[region_idx],
chunk_inputs_non_chunk[region_idx],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx],
node_list,
))
_gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
within_chunk_region = False
node_idx += 1

View File

@ -6,7 +6,7 @@ from torch.fx.node import Node, map_arg
from colossalai.fx.profiler import activation_size, parameter_size
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape, is_non_memory_node
from .utils import NodeMgr, delete_free_var_from_last_use, get_node_shape, is_non_memory_node
class EstimateMemory(object):
@ -14,8 +14,8 @@ class EstimateMemory(object):
Estimate memory with chunk
"""
def __init__(self) -> None:
pass
def __init__(self, node_mgr: NodeMgr) -> None:
self.node_mgr = node_mgr
def _get_meta_node_size(self, x):
x = x.meta["tensor_meta"]
@ -78,7 +78,7 @@ class EstimateMemory(object):
nodes_to_delete = []
for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
chunk_input_users = chunk_input.users.keys()
chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
chunk_input_users_idx = [self.node_mgr.find_node_idx(i) for i in chunk_input_users]
if all(i <= chunk_end_idx for i in chunk_input_users_idx):
if chunk_input not in nodes_to_delete:
nodes_to_delete.append(chunk_input)
@ -212,7 +212,7 @@ class EstimateMemory(object):
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
chunk_inputs_names = [j.name for i in chunk_inputs for j in i
] + [j.name for i in chunk_inputs_non_chunk for j in i]
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
chunk_outputs = [i["outputs"] for i in chunk_infos]
chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
@ -221,7 +221,7 @@ class EstimateMemory(object):
if use_chunk and idx in chunk_starts:
chunk_within = True
chunk_region_idx = chunk_starts.index(idx)
act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)
act_memory += sum(self._get_output_node_size(i) for i in chunk_outputs[chunk_region_idx]) / (1024**2)
# determine chunk ratio for current node
if chunk_within:

View File

@ -1,5 +1,5 @@
from .trace_indice import TraceIndice
from .utils import find_idx_by_name
from .utils import NodeMgr
class ReorderGraph(object):
@ -7,31 +7,27 @@ class ReorderGraph(object):
Reorder node list and indice trace list
"""
def __init__(self, trace_indice: TraceIndice) -> None:
def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
self.trace_indice = trace_indice
self.all_reorder_map = {
i: i for i in range(len(self.trace_indice.indice_trace_list))
}
self.node_mgr = node_mgr
self.all_reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}
def _get_reorder_map(self, chunk_info):
reorder_map = {i: i for i in range(len(self.trace_indice.node_list))}
reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}
chunk_region_start = chunk_info["region"][0]
chunk_region_end = chunk_info["region"][1]
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
chunk_prepose_nodes_idx = [
find_idx_by_name(i.name, self.trace_indice.node_list)
for i in chunk_prepose_nodes
]
chunk_prepose_nodes_idx = [self.node_mgr.find_node_idx(i) for i in chunk_prepose_nodes]
# put prepose nodes ahead
for idx, n in enumerate(chunk_prepose_nodes):
n_idx = chunk_prepose_nodes_idx[idx]
reorder_map[n_idx] = chunk_region_start + idx
# put other nodes after prepose nodes
for n in self.trace_indice.node_list[chunk_region_start : chunk_region_end + 1]:
for n in self.node_mgr.get_node_slice_by_idx(chunk_region_start, chunk_region_end + 1):
if n in chunk_prepose_nodes:
continue
n_idx = find_idx_by_name(n.name, self.trace_indice.node_list)
n_idx = self.node_mgr.find_node_idx(n)
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
reorder_map[n_idx] = n_idx + pos
@ -44,7 +40,7 @@ class ReorderGraph(object):
chunk_info["region"][1],
)
new_inputs_dim = []
for idx, input_dim in enumerate(chunk_info["inputs_dim"]):
for _, input_dim in enumerate(chunk_info["inputs_dim"]):
new_input_dim = {}
for k, v in input_dim.items():
new_input_dim[reorder_map[k]] = v
@ -57,16 +53,14 @@ class ReorderGraph(object):
self.all_reorder_map[origin_idx] = reorder_map[map_idx]
def _reorder_self_node_list(self, reorder_map):
new_node_list = [None for _ in range(len(self.trace_indice.node_list))]
new_node_list = [None for _ in range(len(self.node_mgr.get_node_list()))]
for old_idx, new_idx in reorder_map.items():
new_node_list[new_idx] = self.trace_indice.node_list[old_idx]
self.trace_indice.node_list = new_node_list
new_node_list[new_idx] = self.node_mgr.get_node_by_idx(old_idx)
self.node_mgr.update_node_list(new_node_list)
def _reorder_idx_trace(self, reorder_map):
# reorder list
new_idx_trace_list = [
None for _ in range(len(self.trace_indice.indice_trace_list))
]
new_idx_trace_list = [None for _ in range(len(self.trace_indice.indice_trace_list))]
for old_idx, new_idx in reorder_map.items():
new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
self.trace_indice.indice_trace_list = new_idx_trace_list

View File

@ -9,6 +9,7 @@ from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import (
NodeMgr,
find_chunk_compute_input_and_output_nodes,
get_logger,
get_node_shape,
@ -49,15 +50,17 @@ class SearchChunk(object):
def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
self.print_mem = print_mem
self.print_progress = print_progress
self.trace_indice = TraceIndice(list(gm.graph.nodes))
self.estimate_memory = EstimateMemory()
self.node_mgr = NodeMgr(gm)
self.trace_indice = TraceIndice(self.node_mgr)
self.estimate_memory = EstimateMemory(self.node_mgr)
self._init_trace()
self.trace_flow = TraceFlow(self.trace_indice)
self.reorder_graph = ReorderGraph(self.trace_indice)
self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr)
self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr)
self.select_chunk = SelectChunk(
self.trace_indice,
self.estimate_memory,
self.reorder_graph,
self.node_mgr,
max_memory=max_memory,
)
@ -67,7 +70,7 @@ class SearchChunk(object):
reduce the computation complexity of trace_indice
"""
# find all max ranges
active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list)
active_nodes = self.estimate_memory.get_active_nodes(self.node_mgr.get_node_list())
cur_node_idx = len(self._get_free_var_idx())
max_chunk_region_list = []
while True:
@ -100,7 +103,7 @@ class SearchChunk(object):
free_var_idx (List): all indexes of free vars
"""
free_var_idx = []
for idx, n in enumerate(self.trace_indice.node_list):
for idx, n in enumerate(self.node_mgr.get_node_list()):
if n.op == "placeholder" and get_node_shape(n) is not None:
free_var_idx.append(idx)
return free_var_idx
@ -164,6 +167,44 @@ class SearchChunk(object):
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
"""
Find chunk info for a region.
We are given the region start and region end, and need to find out all chunk info for it.
We first loop over every dim of the start node and the end node to see if we can find a dim pair that is linked in a flow and not computed.
If found, we then search the flow across the whole region to find all chunk infos.
Args:
input_trace (List): node's input trace in region
output_trace (List): node's output trace in region
start_idx (int): region start node index
end_idx (int): region end node index
Returns:
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
if len(start_traces) > 1: # TODO need to be removed
return []
end_trace = output_trace[end_idx]
end_node = self.node_mgr.get_node_by_idx(end_idx)
chunk_infos = []
for end_dim, _ in enumerate(end_trace["indice"]):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
end_idx):
continue
# flow search
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
if chunk_info is None:
continue
chunk_infos.append(chunk_info)
return chunk_infos
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
"""
Search every possible region within the max chunk region.
@ -178,7 +219,7 @@ class SearchChunk(object):
possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.trace_indice.node_list):
for _, n in enumerate(self.node_mgr.get_node_list()):
cur_trace = {}
for arg in n.args:
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
@ -188,11 +229,11 @@ class SearchChunk(object):
for start_idx in range(max_chunk_region[0], peak_node + 1):
for end_idx in range(peak_node, max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
self.trace_indice.node_list[end_idx]):
if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
self.node_mgr.get_node_by_idx(end_idx)):
continue
# select free dim
chunk_info = self.trace_flow.find_chunk_info(input_trace, output_trace, start_idx, end_idx)
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
if len(chunk_info) > 0:
possible_chunk_region.extend(chunk_info)
return possible_chunk_region
@ -254,7 +295,7 @@ class SearchChunk(object):
init_mem_peak,
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
mem_peak = init_mem_peak
while True:
@ -267,7 +308,7 @@ class SearchChunk(object):
mem_peak,
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), chunk_infos)
if self.print_progress:
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
@ -277,5 +318,7 @@ class SearchChunk(object):
break
if self.print_mem:
self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
chunk_infos,
print_mem=True)
return chunk_infos

View File

@ -1,7 +1,7 @@
from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .trace_indice import TraceIndice
from .utils import is_non_compute_node
from .utils import NodeMgr, is_non_compute_node
class SelectChunk(object):
@ -11,11 +11,13 @@ class SelectChunk(object):
trace_indice: TraceIndice,
estimate_memory: EstimateMemory,
reorder_graph: ReorderGraph,
node_mgr: NodeMgr,
max_memory=None,
):
self.trace_indice = trace_indice
self.estimate_memory = estimate_memory
self.reorder_graph = reorder_graph
self.node_mgr = node_mgr
if max_memory is not None:
self.stratge = "fit_memory"
self.max_memory = max_memory # MB
@ -68,7 +70,7 @@ class SelectChunk(object):
regions_dict = []
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
@ -134,7 +136,7 @@ class SelectChunk(object):
def _get_compute_node_num(self, start, end):
count = 0
for i in self.trace_indice.node_list[start:end + 1]:
for i in self.node_mgr.get_node_slice_by_idx(start, end + 1):
if not is_non_compute_node(i):
count += 1
return count
@ -161,7 +163,7 @@ class SelectChunk(object):
regions_dict_list = []
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]

View File

@ -4,9 +4,10 @@ from torch.fx.node import Node
from .trace_indice import TraceIndice
from .utils import (
NodeMgr,
find_chunk_all_input_nodes,
find_chunk_compute_input_and_output_nodes,
find_idx_by_name,
find_tensor_shape_node,
flat_list,
get_node_name,
get_node_shape,
@ -16,8 +17,9 @@ from .utils import (
class TraceFlow(object):
def __init__(self, trace_indice: TraceIndice) -> None:
def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
self.trace_indice = trace_indice
self.node_mgr = node_mgr
def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
"""
@ -31,7 +33,8 @@ class TraceFlow(object):
Returns:
bool: True if check pass
"""
start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
# we use start_node_idx instead of real chunk index
start_node_idx = self.node_mgr.find_node_idx(start_node)
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
end_node_trace_source = end_node_trace["source"][end_dim]
sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
@ -39,7 +42,7 @@ class TraceFlow(object):
if node_idx == start_node_idx and start_dim in node_dim:
return True
# it means we meet a node outside the loop, and the node is not input node
if node_idx < start_idx:
if node_idx < start_node_idx:
return False
return False
@ -61,29 +64,12 @@ class TraceFlow(object):
return False
return True
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
node_from_source = self.trace_indice._find_source_trace_from_node(node_from)
dim_source = node_from_source[node_from_dim]
node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list)
for k, v in dim_source.items():
if k == node_to_idx:
return v
return None
def _find_inherit_dim(self, input_node, input_dim, node):
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
for node_dim in range(len(get_node_shape(node))):
if (input_node_idx in node_trace_source[node_dim]
and input_dim[0] in node_trace_source[node_dim][input_node_idx]):
return node_dim
return None
def _assgin_single_node_flow(
self,
arg_node: Node,
start_idx: int,
end_idx: int,
cur_node: Node,
cur_node_dim: int,
cur_node_compute: Dict,
cur_node_source: Dict,
@ -109,7 +95,7 @@ class TraceFlow(object):
Returns:
bool: True if this node can be added to the flow, False otherwise.
"""
arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
arg_idx = self.node_mgr.find_node_idx(arg_node)
# arg in chunk range or be inputs
if not (start_idx <= arg_idx < end_idx):
return True
@ -126,6 +112,11 @@ class TraceFlow(object):
# chunk dim should be None if shape size is 1
if get_node_shape(arg_node)[arg_dim] == 1:
arg_dim = None
# chunk shape should equal cur node
elif get_node_shape(arg_node)[arg_dim] != 1:
if cur_node_dim is not None and get_node_shape(cur_node)[cur_node_dim] != 1:
if get_node_shape(arg_node)[arg_dim] != get_node_shape(cur_node)[cur_node_dim]:
return False
else:
arg_dim = None
@ -150,7 +141,7 @@ class TraceFlow(object):
return True
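The new size-equality guard above rejects flows whose arg cannot track the current node's chunk dim; with hypothetical shapes:

# cur_node  [2, 512, 64], chunked at dim 1 (size 512)
# arg_node  [2, 512, 64], arg_dim 1 (size 512) -> kept
# arg_node  [2, 256, 64], arg_dim 1 (size 256) -> flow rejected (256 != 512)
# arg_node  [2,   1, 64], arg_dim 1 (size 1)   -> arg_dim set to None (broadcast)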
def _get_all_node_info(self, end_dim, start_idx, end_idx):
cur_node_list = [self.trace_indice.node_list[end_idx]] # start from the last node
cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
while len(cur_node_list) > 0:
@ -178,6 +169,7 @@ class TraceFlow(object):
arg,
start_idx,
end_idx,
cur_node,
cur_node_chunk_dim,
cur_node_compute,
cur_node_source,
@ -194,7 +186,7 @@ class TraceFlow(object):
for arg in arg_list:
if get_node_shape(arg) is None:
continue
if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
if not (start_idx <= self.node_mgr.find_node_idx(arg) < end_idx):
continue
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
arg_fix_dim = all_node_info[arg]["fix_dim"]
@ -232,7 +224,7 @@ class TraceFlow(object):
remove_inputs = []
for input_node in inputs:
input_dict = {}
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
input_node_idx = self.node_mgr.find_node_idx(input_node)
for user in input_node.users.keys():
# skip non compute
if is_non_compute_node(user):
@ -240,7 +232,7 @@ class TraceFlow(object):
# untraced node, mostly non compute
if user not in all_node_info:
continue
user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
user_idx = self.node_mgr.find_node_idx(user)
if start_idx <= user_idx <= end_idx:
chunk_dim = all_node_info[user]["chunk_dim"]
if chunk_dim is not None:
@ -262,7 +254,7 @@ class TraceFlow(object):
inputs.remove(i)
return inputs, inputs_dim
def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int, chunk_info) -> List[Node]:
"""
get all useless nodes in chunk region and prepose them
@ -279,8 +271,11 @@ class TraceFlow(object):
for node, node_info in all_node_info.items():
if node_info["chunk_dim"] is None:
maybe_prepose_nodes.append(node)
for node in self.node_mgr.get_node_slice_by_idx(start_idx, end_idx):
if node not in all_node_info and node not in chunk_info["outputs"]:
maybe_prepose_nodes.append(node)
maybe_prepose_nodes.sort(
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
key=lambda x: self.node_mgr.find_node_idx(x),
reverse=True,
) # from last node to first node
prepose_nodes = []
@ -303,8 +298,7 @@ class TraceFlow(object):
if type(cur_prepose_node_arg) != type(cur_prepose_node):
continue
# out of loop
if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) <
end_idx):
if not (start_idx <= self.node_mgr.find_node_idx(cur_prepose_node_arg) < end_idx):
continue
# compute op in loop
elif cur_prepose_node_arg in all_node_info:
@ -328,13 +322,12 @@ class TraceFlow(object):
if n in maybe_prepose_nodes:
maybe_prepose_nodes.remove(n)
# sort by index
prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list))
return prepose_nodes
prepose_nodes.sort(key=lambda x: self.node_mgr.find_node_idx(x))
chunk_info["args"]["prepose_nodes"] = prepose_nodes
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
# we need to log input nodes to avoid deleting them in the loop
chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1]
chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
# also need to get some prepose node's arg out of non_chunk_inputs
for n in chunk_info["args"]["prepose_nodes"]:
chunk_node_list.remove(n)
@ -345,34 +338,41 @@ class TraceFlow(object):
return chunk_info
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1])
# only single ouput
if len(outputs) > 1:
return None
inputs, outputs = find_chunk_compute_input_and_output_nodes(
self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))
# get every node's chunk dim and fix dim
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
if all_node_info is None:
return None
# get input nodes' chunk dim
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
if inputs is None:
return None
chunk_info = {
"region": (start_idx, end_idx),
"inputs": inputs,
"inputs": [],
"inputs_non_chunk": [],
"inputs_dim": inputs_dim,
"outputs": outputs,
"outputs_dim": end_dim,
"inputs_dim": [],
"outputs": [self.node_mgr.get_node_by_idx(end_idx)],
"outputs_non_tensor": {},
"outputs_dim": [end_dim],
"node_chunk_dim": all_node_info,
"args": {},
}
# find chunk info for other outputs
if len(find_tensor_shape_node(outputs)) > 1:
chunk_info = self._get_other_output_info(outputs, start_idx, start_dim, end_idx, end_dim, chunk_info)
if chunk_info is None:
return None
# get input nodes' chunk dim
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
if inputs is None:
return None
chunk_info["inputs"] = inputs
chunk_info["inputs_dim"] = inputs_dim
# move useless nodes ahead of loop
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx)
self._get_prepose_nodes(all_node_info, start_idx, end_idx, chunk_info)
# find non chunk inputs
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
@ -382,6 +382,63 @@ class TraceFlow(object):
return chunk_info
def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
chunk_info: Dict):
start_node = self.node_mgr.get_node_by_idx(start_idx)
# loop all outputs
for output in outputs:
output_legal = False
output_idx = self.node_mgr.find_node_idx(output)
# skip the origin output
if output_idx == end_idx:
continue
# skip non tensor
if get_node_shape(output) is None:
# log shape tensor
if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
continue
# loop every dim of outputs, try to find a legal one
for output_dim in range(len(get_node_shape(output))):
if not self.check_region_start_end(start_node, start_dim, start_idx, output, output_dim, output_idx):
continue
new_all_node_info = self._get_all_node_info(output_dim, start_idx, output_idx)
if new_all_node_info is None:
continue
# check node info legal
if self._update_chunk_info(chunk_info, new_all_node_info, output, output_dim) == True:
output_legal = True
break
# not legal
if output_legal == False:
return None
return chunk_info
def _update_chunk_info(self, chunk_info: Dict, new_all_node_info: Dict, output: Node, output_dim: int) -> bool:
"""
check if there is conflict between new node info and old chunk info. If not, update old chunk info
"""
# check if conflict
overlap_flag = False
for k, v in new_all_node_info.items():
if k in chunk_info["node_chunk_dim"]:
overlap_flag = True
if chunk_info["node_chunk_dim"][k]["chunk_dim"] != v["chunk_dim"]:
return False
# if no overlap, we just consider them as prepose nodes, instead of new output
if overlap_flag == False:
return True
# update chunk info
for k, v in new_all_node_info.items():
if k in chunk_info["node_chunk_dim"]:
chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
else:
chunk_info["node_chunk_dim"][k] = v
chunk_info["outputs"].append(output)
chunk_info["outputs_dim"].append(output_dim)
return True
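A worked example of the merge rule, with hypothetical nodes:

# existing: chunk_info["node_chunk_dim"] == {view_1: {"chunk_dim": 1, "fix_dim": [0]}}
# new:      new_all_node_info            == {view_1: {"chunk_dim": 1, "fix_dim": [2]},
#                                            mul_2:  {"chunk_dim": 1, "fix_dim": []}}
# view_1 overlaps with a matching chunk_dim, so the infos merge: view_1 keeps
# chunk_dim 1 with fix_dim [0, 2], mul_2 is added, and the output and its dim
# are appended. A mismatched chunk_dim on view_1 would return False; no overlap
# at all leaves the new nodes to be handled as prepose nodes instead.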
def _reassgin_reshape_size(self, chunk_info):
"""
Some shape args in reshape may have changed due to chunk
@ -389,10 +446,17 @@ class TraceFlow(object):
"""
chunk_region = chunk_info["region"]
reshape_size = {}
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"][0]]
for node in self.node_mgr.get_node_slice_by_idx(chunk_region[0], chunk_region[1] + 1):
if any(i == get_node_name(node) for i in ["reshape", "view"]):
if node in chunk_info["args"]["prepose_nodes"]:
continue
if node.args[0] in chunk_info["inputs_non_chunk"]:
continue
reshape_args = flat_list(node.args[1:])
if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
reshape_args[0].meta['fwd_out']) > 1:
continue
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
new_shape = ""
for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
@ -409,45 +473,8 @@ class TraceFlow(object):
chunk_info["reshape_size"] = reshape_size
return chunk_info
def find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
"""
Find chunk info for a region.
We are given the region start and region end, and need to find out all chunk info for it.
We first loop every dim of start node and end node, to see if we can find dim pair,
which is linked in a flow and not computed.
If found, we then search flow in the whole region to find out all chunk infos.
Args:
input_trace (List): node's input trace in region
output_trace (List): node's output trace in region
start_idx (int): region start node index
end_idx (int): region end node index
Returns:
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
if len(start_traces) > 1: # TODO need to be removed
return []
end_trace = output_trace[end_idx]
end_node = self.trace_indice.node_list[end_idx]
chunk_infos = []
for end_dim, _ in enumerate(end_trace["indice"]):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
if not self._check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, end_idx):
continue
# flow search
chunk_info = self.flow_search(start_idx, start_dim, end_idx, end_dim)
if chunk_info is None:
continue
chunk_infos.append(chunk_info)
return chunk_infos
def _check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
end_idx: int) -> bool:
def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
end_idx: int) -> bool:
"""
check if region start and end is legal
"""

View File

@ -3,14 +3,7 @@ from typing import Dict, List, Tuple
from torch.fx.node import Node
from .utils import (
find_first_tensor_arg,
find_idx_by_name,
flat_list,
get_module_node_name,
get_node_name,
get_node_shape,
)
from .utils import NodeMgr, find_first_tensor_arg, flat_list, get_module_node_name, get_node_name, get_node_shape
class TraceIndice(object):
@ -35,8 +28,8 @@ class TraceIndice(object):
node_list (List)
"""
def __init__(self, node_list: List[Node]) -> None:
self.node_list = node_list
def __init__(self, node_mgr: NodeMgr) -> None:
self.node_mgr = node_mgr
self.indice_trace_list = self._init_indice_trace_list()
self.indice_view_list = {}
self.indice_count = -1
@ -45,7 +38,7 @@ class TraceIndice(object):
def _init_indice_trace_list(self) -> List:
indice_trace_list = []
for n in self.node_list:
for n in self.node_mgr.get_node_list():
if get_node_shape(n) != None:
cur_trace = {
"indice": [None for _ in range(len(get_node_shape(n)))],
@ -99,7 +92,7 @@ class TraceIndice(object):
node_from_trace_source = self._find_source_trace_from_node(node_from)
node_to_dim = self._transform_indice(node_to, node_to_dim)
node_to_trace_source = self._find_source_trace_from_node(node_to)
node_from_idx = find_idx_by_name(node_from.name, self.node_list)
node_from_idx = self.node_mgr.find_node_idx(node_from)
if init:
node_to_trace_source[node_to_dim] = {}
# add dim to cur new source
@ -200,7 +193,7 @@ class TraceIndice(object):
idx (list): idx of the node
compute (list): computed idx of the node.
"""
node_idx = find_idx_by_name(node.name, self.node_list)
node_idx = self.node_mgr.find_node_idx(node)
node_dict = self.indice_trace_list[node_idx]
return node_dict
@ -214,7 +207,7 @@ class TraceIndice(object):
idx (list): idx of the node
compute (list): computed idx of the node.
"""
node_idx = find_idx_by_name(node.name, self.node_list)
node_idx = self.node_mgr.find_node_idx(node)
node_dict = self.indice_trace_list[node_idx]
return node_dict["source"]
@ -227,7 +220,7 @@ class TraceIndice(object):
Returns:
idx (list): idx of the node
"""
node_idx = find_idx_by_name(node.name, self.node_list)
node_idx = self.node_mgr.find_node_idx(node)
return self.indice_trace_list[node_idx]["indice"]
def _find_compute_trace_from_node(self, node: Node) -> List:
@ -239,7 +232,7 @@ class TraceIndice(object):
Returns:
compute (list): computed idx of the node.
"""
node_idx = find_idx_by_name(node.name, self.node_list)
node_idx = self.node_mgr.find_node_idx(node)
return self.indice_trace_list[node_idx]["compute"]
def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None:
@ -454,8 +447,6 @@ class TraceIndice(object):
node (node)
node_idx (int)
"""
for _ in range(len(get_node_shape(node.args[0]))):
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx)
dim_idx = node.kwargs["dim"]
self._del_dim(node_idx, dim_idx)
@ -702,21 +693,20 @@ class TraceIndice(object):
if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
and view_dict["dim_from"] == dim_to):
# inherit indice from current node
for dim_to_i in dim_to:
for dim_from_i in dim_from:
self._inherit_indice(origin_node, dim_from_i, node, dim_to_i, init=False)
if len_diff == 1:
if origin_shape[dim_from[0]] == 1:
self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
elif origin_shape[dim_from[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
elif len_diff == -1:
if target_shape[dim_to[0]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
elif target_shape[dim_to[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
# inherit indice from input node of last view
for dim_to_i in dim_to:
self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
# inherit computation
compute_log = self._find_compute_trace_from_node(origin_node)
for i in dim_from:
if origin_trace[i] in compute_log:
for j in dim_to:
self._mark_computation(node, node_idx, [j])
break
# log view, not used now
view_dict = {
"idx_from": [origin_trace[i] for i in dim_from],
@ -742,7 +732,7 @@ class TraceIndice(object):
active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
active_nodes = set(flat_list(active_nodes))
active_nodes = [find_idx_by_name(i, self.node_list) for i in active_nodes]
active_nodes = [self.node_mgr.find_node_idx_by_name(i) for i in active_nodes]
for i in range(trace_range[0], trace_range[1] + 1):
trace = self.indice_trace_list[i]
# clear compute
@ -758,7 +748,7 @@ class TraceIndice(object):
dim_source.pop(k)
def trace_indice(self) -> None:
for idx, node in enumerate(self.node_list):
for idx, node in enumerate(self.node_mgr.get_node_list()):
node_name = get_node_name(node)
if node.op == "placeholder":
self._assign_all_indice(node, idx)

View File

@ -9,6 +9,59 @@ NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "siz
logger = get_dist_logger()
class NodeMgr(object):
def __init__(self, gm) -> None:
self._node_list = list(gm.graph.nodes)
self._node_dict = {}
self._set_node_dict()
def _set_node_dict(self) -> None:
"""
create a dict {node_name: node_idx}
"""
self._node_dict.clear()
for idx, node in enumerate(self._node_list):
self._node_dict[node.name] = idx
def find_node_idx(self, node: Node) -> int:
"""
find node's index
"""
return self._node_dict[node.name]
def find_node_idx_by_name(self, node_name: str) -> int:
"""
find node's index
"""
return self._node_dict[node_name]
def get_node_by_idx(self, idx: int) -> Node:
"""
get a node by index
"""
return self._node_list[idx]
def get_node_slice_by_idx(self, start: int, end: int) -> List[Node]:
"""
get a slice of node by index
"""
return self._node_list[start:end]
def get_node_list(self) -> List:
"""
get full node list
"""
return self._node_list
def update_node_list(self, node_list: List) -> None:
"""
update node list, reset node dict
"""
self._node_list = node_list
self._set_node_dict()
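A quick usage sketch of the manager (assuming NodeMgr is importable from colossalai.autochunk.utils; any traced module works):

import torch
from torch.fx import symbolic_trace

gm = symbolic_trace(torch.nn.Linear(4, 4))
mgr = NodeMgr(gm)
first = mgr.get_node_by_idx(0)
assert mgr.find_node_idx(first) == 0
assert mgr.find_node_idx_by_name(first.name) == 0
# after a reorder, update_node_list rebuilds the cached name -> index dict
mgr.update_node_list(list(reversed(mgr.get_node_list())))
assert mgr.find_node_idx(first) == len(mgr.get_node_list()) - 1

The dict lookup replaces the linear scan of the old find_idx_by_name helper at every call site touched in this diff.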
def get_logger() -> Any:
return logger
@ -42,6 +95,8 @@ def is_non_compute_node(node: Node) -> bool:
if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME):
return True
if "getitem" in node.name:
if get_node_shape(node) is not None:
return False
node_args = flat_list(node.args[1:])
for node_arg in node_args:
if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
@ -53,6 +108,8 @@ def is_non_compute_node(node: Node) -> bool:
def get_node_shape(node: Node) -> List:
if get_node_name(node) == "split":
return node.meta["tensor_meta"][0].shape
if hasattr(node.meta["tensor_meta"], "shape"):
return node.meta["tensor_meta"].shape
return None
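For split nodes tensor_meta is a sequence of per-field metadata, so the first field's shape stands in for the node; e.g. (hypothetical):

# split_1 = x.split(256, dim=1) on a [2, 512, 64] tensor
# split_1.meta["tensor_meta"][0].shape == torch.Size([2, 256, 64])
# get_node_shape(split_1)              -> torch.Size([2, 256, 64])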
@ -78,7 +135,7 @@ def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
return is_non_compute_node_except_placeholder(node)
def find_idx_by_name(name: str, nodes_list: List) -> int:
def find_node_idx(name: str, nodes_list: List) -> int:
for idx, node in enumerate(nodes_list):
if node.name == name:
return idx
@ -162,3 +219,28 @@ def get_node_name(node: Node) -> str:
else:
break
return node_name
def find_tensor_node(node_list: List[Node]) -> List[Node]:
"""
find tensor nodes from a node list
"""
out = []
for node in node_list:
if get_node_shape(node) is not None:
out.append(node)
return out
def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
"""
find tensor and shape nodes from a node list
"""
out = []
for node in node_list:
if get_node_shape(node) is not None:
out.append(node)
elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
node.meta['fwd_out'][0], int):
out.append(node)
return out
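The distinction matters for shape-producing ops: e.g. (hypothetical) a size node with meta['fwd_out'] == [2, 512] has no tensor shape, so it is dropped by find_tensor_node but kept by find_tensor_shape_node, which lets flow_search accept shape values as extra region outputs:

# size_1.meta['fwd_out'] == [2, 512]   # plain ints, no tensor_meta shape
# find_tensor_node([size_1])        -> []
# find_tensor_shape_node([size_1])  -> [size_1]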

View File

@ -23,6 +23,7 @@ def assert_codegen_run(
concrete_args: List = None,
max_memory: int = None,
print_mem: bool = False,
print_est_mem: bool = False,
print_progress: bool = False,
print_code: bool = False,
) -> List[Dict]:
@ -41,7 +42,7 @@ def assert_codegen_run(
codegen = AutoChunkCodeGen(
meta_graph,
max_memory=max_memory,
print_mem=print_mem,
print_mem=print_est_mem,
print_progress=print_progress,
)
chunks = codegen.chunk_infos
@ -61,13 +62,20 @@ def assert_codegen_run(
code = graph.python_code("self").src
if print_code:
print(code)
assert "chunk_result = None; chunk_size = None;" in code
assert "chunk_size = None; " in code
# assert result
inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
model.cuda()
with torch.no_grad():
out_gm = gm(*inputs)
if print_mem:
torch.cuda.reset_peak_memory_stats()
now_mem = torch.cuda.memory_allocated() / 1024**2
out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
if print_mem:
new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
print("mem: %.2fMB" % (new_max_mem - now_mem))
out_model = model(*inputs)
out_gm = flat_list(out_gm)
out_model = flat_list(out_model)
@ -85,9 +93,10 @@ def run_test(
max_memory: int,
get_model: Any,
get_data: Any,
print_code: bool,
print_mem: bool,
print_progress: bool,
print_code: bool = False,
print_mem: bool = False,
print_est_mem: bool = False,
print_progress: bool = False,
get_chunk_target: Any = None,
) -> None:
# launch colossalai
@ -110,6 +119,7 @@ def run_test(
max_memory=max_memory,
print_code=print_code,
print_mem=print_mem,
print_est_mem=print_est_mem,
print_progress=print_progress,
)

View File

@ -55,9 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
def get_chunk_target() -> Dict:
return {
None: [(118, 123), (219, 237), (264, 289), (302, 309), (97, 104), (144, 152), (185, 193), (241, 242), (21, 46)],
20: [(118, 123), (230, 237), (275, 282), (305, 306), (100, 101), (32, 39), (73, 79)],
24: [(118, 123)],
None: [(120, 123), (222, 237), (269, 289), (305, 311), (100, 105), (146, 152), (187, 193), (241, 242),
(25, 50)],
20: [(120, 123), (232, 237), (277, 282), (305, 306), (100, 101), (34, 39)],
24: [(120, 123)],
}
@ -75,9 +76,6 @@ def test_evoformer_block(data_args, max_memory):
get_model=get_model,
get_data=get_data,
get_chunk_target=get_chunk_target,
print_code=False,
print_mem=False,
print_progress=False,
)
mp.spawn(run_func, nprocs=1)
@ -86,10 +84,12 @@ if __name__ == "__main__":
run_test(
rank=0,
data_args=(32, 64),
max_memory=20,
max_memory=24,
get_model=get_model,
get_data=get_data,
get_chunk_target=get_chunk_target,
print_code=False,
print_mem=False,
print_est_mem=False,
print_progress=False,
)

View File

@ -70,9 +70,6 @@ def test_evoformer_stack(data_args, max_memory):
max_memory=max_memory,
get_model=get_model,
get_data=get_data,
print_code=False,
print_mem=False,
print_progress=False,
)
mp.spawn(run_func, nprocs=1)
@ -81,7 +78,7 @@ if __name__ == "__main__":
run_test(
rank=0,
data_args=(32, 64),
max_memory=20,
max_memory=None,
get_model=get_model,
get_data=get_data,
print_code=False,

View File

@ -55,10 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
def get_chunk_target() -> Dict:
return {
None: [(126, 131), (227, 245), (272, 297), (310, 317), (105, 112), (152, 160), (193, 201), (249, 250),
(33, 46)],
20: [(126, 131), (238, 245), (283, 290), (313, 314), (108, 109), (35, 46)],
24: [(126, 131)],
None: [(128, 131), (230, 245), (277, 297), (313, 319), (108, 113), (154, 160), (195, 201), (249, 250),
(36, 46)],
20: [(128, 131), (240, 245), (285, 290), (313, 314), (108, 109), (41, 46)],
24: [(128, 131)],
}
@ -75,9 +75,7 @@ def test_extramsa_block(data_args, max_memory):
max_memory=max_memory,
get_model=get_model,
get_data=get_data,
print_code=False,
print_mem=False,
print_progress=False,
get_chunk_target=get_chunk_target,
)
mp.spawn(run_func, nprocs=1)
@ -86,7 +84,7 @@ if __name__ == "__main__":
run_test(
rank=0,
data_args=(32, 64),
max_memory=20,
max_memory=None,
get_model=get_model,
get_data=get_data,
get_chunk_target=get_chunk_target,

View File

@ -17,8 +17,8 @@ from test_transformer_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
BATCH_SIZE = 2
SEQ_LENGTH = 256
BATCH_SIZE = 1
SEQ_LENGTH = 512
def get_data(shape: tuple) -> Tuple[List, List]:
@ -37,17 +37,14 @@ def get_data(shape: tuple) -> Tuple[List, List]:
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
@pytest.mark.parametrize("max_memory", [None, 4.5, 5])
def test_gpt(model, shape, max_memory):
@pytest.mark.parametrize("max_memory", [None, 6, 8])
def test_autochunk_gpt(model, shape, max_memory):
run_func = partial(
run_test,
data=get_data(shape),
max_memory=max_memory,
model=model,
config=GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4),
print_code=False,
print_mem=False,
print_progress=False,
)
mp.spawn(run_func, nprocs=1)
@ -59,7 +56,8 @@ if __name__ == "__main__":
max_memory=None,
model=GPT2Model,
config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
print_code=True,
print_mem=True,
print_code=False,
print_est_mem=False,
print_mem=False,
print_progress=False,
)

View File

@ -20,6 +20,7 @@ def assert_codegen_run(
model: Any,
data: tuple,
max_memory: int = None,
print_est_mem: bool = False,
print_mem: bool = False,
print_progress: bool = False,
print_code: bool = False,
@ -41,7 +42,7 @@ def assert_codegen_run(
codegen = AutoChunkCodeGen(
meta_graph,
max_memory=max_memory,
print_mem=print_mem,
print_mem=print_est_mem,
print_progress=print_progress,
)
chunks = codegen.chunk_infos
@ -61,7 +62,7 @@ def assert_codegen_run(
code = graph.python_code("self").src
if print_code:
print(code)
assert "chunk_result = None; chunk_size = None;" in code
assert "chunk_size = None; " in code
# assert result
inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
@ -69,26 +70,44 @@ def assert_codegen_run(
model.cuda().eval()
gm.eval()
with torch.no_grad():
out_gm = gm(*inputs)
if print_mem:
torch.cuda.reset_peak_memory_stats()
now_mem = torch.cuda.memory_allocated() / 1024**2
out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
if print_mem:
new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
print("mem: %.2fMB" % (new_max_mem - now_mem))
out_model = model(*inputs)
for k in out_model.keys():
if torch.is_tensor(out_gm[k]):
assert torch.equal(
out_model[k], out_gm[k]
), f'{model.__class__.__name__} has incorrect output {k}, expect {out_model[k]}, but got {out_gm[k]}'
assert_allclose(out_model, out_gm)
return chunks
def assert_allclose(out_model: Any, out_gm: Any) -> None:
"""
assert allclose for out
"""
if isinstance(out_model, torch.Tensor):
assert torch.allclose(out_model, out_gm,
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(out_model - out_gm))
elif isinstance(out_model, dict):
for k in out_model.keys():
assert_allclose(out_model[k], out_gm[k])
elif isinstance(out_model, tuple) or isinstance(out_model, list) or isinstance(out_model, set):
for i, j in zip(out_model, out_gm):
assert_allclose(i, j)
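assert_allclose recurses through dicts, tuples, lists and sets, so HuggingFace-style output objects compare field by field; a usage sketch with hypothetical tensors:

a, b = torch.randn(2, 4), torch.randn(2, 4)
assert_allclose({"last_hidden_state": a, "past": (b, b)},
                {"last_hidden_state": a.clone(), "past": (b.clone(), b.clone())})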
def run_test(
rank: int,
model: Any,
config: Any,
data: tuple,
max_memory: int,
print_code: bool,
print_mem: bool,
print_progress: bool,
print_code: bool = False,
print_est_mem: bool = False,
print_mem: bool = False,
print_progress: bool = False,
get_chunk_target: Any = None,
) -> None:
model = model(config=config)
@ -108,6 +127,7 @@ def run_test(
data=data,
max_memory=max_memory,
print_code=print_code,
print_est_mem=print_est_mem,
print_mem=print_mem,
print_progress=print_progress,
)
@ -119,5 +139,3 @@ def run_test(
str(chunk_found),
str(chunk_target),
)
gpc.destroy()