[autochunk] support transformer (#2526)

pull/2540/head
oahzxl 2023-01-31 16:00:06 +08:00 committed by GitHub
parent 6e0faa70e0
commit 63199c6687
20 changed files with 1214 additions and 1084 deletions

View File

@ -3,9 +3,12 @@ from typing import Any, Dict, Iterable, List, Tuple
import torch
import colossalai
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
if CODEGEN_AVAILABLE:
AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
if AUTOCHUNK_AVAILABLE:
from torch.fx.graph import (
CodeGen,
PythonCode,
@ -272,7 +275,7 @@ def emit_code_with_chunk(
node_idx += 1
if CODEGEN_AVAILABLE:
if AUTOCHUNK_AVAILABLE:
class AutoChunkCodeGen(CodeGen):
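The renamed AUTOCHUNK_AVAILABLE flag above now requires meta-tensor compatibility in addition to the torch.fx codegen; downstream code is expected to guard on it before importing the codegen, as the new tests later in this diff do. A minimal illustration of that guard:

from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE

if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen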

View File

@ -8,7 +8,13 @@ from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
from .utils import (
find_chunk_compute_input_and_output_nodes,
get_logger,
get_node_shape,
is_non_compute_node,
is_non_compute_node_except_placeholder,
)
class SearchChunk(object):
@ -114,6 +120,12 @@ class SearchChunk(object):
chunk_region_start (int)
chunk_region_end (int)
"""
# check if the peak node is already inside an existing chunk region
if chunk_regions is not None:
for i in chunk_regions:
if i["region"][0] < peak_node_idx <= i["region"][1]:
return None
free_vars = self._get_free_var_idx()
free_var_num = len(free_vars)
active_node_num = [len(i) for i in active_node]
@ -152,55 +164,6 @@ class SearchChunk(object):
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end
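A hedged sketch of the guard added at the top of this method: a peak node that already falls inside a previously found chunk region triggers an early return, so the same region is not searched twice. The helper name below is hypothetical and only restates the check for clarity:

def _peak_already_chunked(peak_node_idx, chunk_regions):
    # True if the peak node falls inside one of the already-found regions
    if chunk_regions is None:
        return False
    for info in chunk_regions:
        start, end = info["region"]
        if start < peak_node_idx <= end:
            return True
    return False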
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
"""
Find chunk info for a region.
We are given the region start and region end, and need to find out all chunk info for it.
We first loop every dim of start node and end node, to see if we can find dim pair,
which is linked in a flow and not computed.
If found, we then search flow in the whole region to find out all chunk infos.
Args:
input_trace (List): node's input trace in region
output_trace (List): node's output trace in region
start_idx (int): region start node index
end_idx (int): region end node index
Returns:
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
end_trace = output_trace[end_idx]
end_node = self.trace_indice.node_list[end_idx]
chunk_infos = []
for end_dim, _ in enumerate(end_trace["indice"]):
if len(start_traces) > 1:
continue
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
# dim size cannot be 1
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
continue
# must have users
if len(end_node.users) == 0:
continue
# check index source align
if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
continue
# check index compute
if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
continue
# flow search
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
if chunk_info is None:
continue
# check index duplicate
if not self.trace_flow.check_index_duplicate(chunk_info):
continue
chunk_infos.append(chunk_info)
return chunk_infos
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
"""
Search every possible region within the max chunk region.
@ -228,9 +191,8 @@ class SearchChunk(object):
if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
self.trace_indice.node_list[end_idx]):
continue
# select free dim
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
chunk_info = self.trace_flow.find_chunk_info(input_trace, output_trace, start_idx, end_idx)
if len(chunk_info) > 0:
possible_chunk_region.extend(chunk_info)
return possible_chunk_region

View File

@ -5,6 +5,7 @@ from .utils import is_non_compute_node
class SelectChunk(object):
def __init__(
self,
trace_indice: TraceIndice,
@ -17,13 +18,11 @@ class SelectChunk(object):
self.reorder_graph = reorder_graph
if max_memory is not None:
self.stratge = "fit_memory"
self.max_memory = max_memory # MB
self.max_memory = max_memory # MB
else:
self.stratge = "min_memory"
def _select_best_chunk_region(
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
):
def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak):
if self.stratge == "min_memory":
best_region = self._select_min_memory_chunk_region(
possible_chunk_regions,
@ -44,9 +43,8 @@ class SelectChunk(object):
raise RuntimeError()
return best_region
def _select_fit_memory_chunk_region(
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
):
def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
mem_peak):
# stop chunk if max memory satisfy memory limit
if max(mem_peak) < self.max_memory:
return None
@ -63,33 +61,26 @@ class SelectChunk(object):
if len(possible_chunk_regions) == 0:
return None
max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
max([i["region"][1] for i in possible_chunk_regions]))
# get mem for chunk region
regions_dict = []
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
self.trace_indice.node_list, cur_region
)
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
cur_node_list, cur_chunk_infos
)[0]
cur_chunk_region_peak = cur_mem_peak[
max_chunk_region[0] : max_chunk_region[1] + 1
]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
if cur_chunk_region_max_peak < self.max_memory:
regions_dict.append(
{
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(
region["region"][0], region["region"][1]
),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
}
)
regions_dict.append({
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
})
# no region found
if len(regions_dict) == 0:
raise RuntimeError("Search failed. Try a larger memory threshold.")
@ -113,20 +104,13 @@ class SelectChunk(object):
chunk_size *= 2
reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0]
cur_chunk_max_mem = max(
cur_mem_peak[
reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1]
+ 1
]
)
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
cur_chunk_infos)[0]
cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])
# search exact size
chunk_info = chunk_region_dict["chunk_info"]
chunk_info["chunk_size"] = self._chunk_size_binary_search(
chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
)
chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
chunk_infos)
return chunk_info
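The fit-memory path above grows chunk_size by doubling until the estimated peak memory exceeds the budget, then lets _chunk_size_binary_search (below) narrow the result between chunk_size // 2 and chunk_size. A minimal self-contained sketch of the same idea, with a hypothetical peak_mem_of callable standing in for the memory estimator (assumed monotonic in chunk size):

def fit_chunk_size(peak_mem_of, max_memory):
    # grow until the estimate no longer fits under the budget
    size = 1
    while peak_mem_of(size) < max_memory:
        size *= 2
    # binary search the largest size in [size // 2, size) that still fits
    left, right = size // 2, size
    while right - left > 1:
        mid = (left + right) // 2
        if peak_mem_of(mid) >= max_memory:
            right = mid
        else:
            left = mid
    return max(left, 1)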
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
@ -139,12 +123,9 @@ class SelectChunk(object):
mid = int((left + right) / 2 + 0.5)
chunk_info["chunk_size"] = mid
cur_chunk_infos = chunk_infos + [chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0]
cur_chunk_max_mem = max(
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
)
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
cur_chunk_infos)[0]
cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
if cur_chunk_max_mem >= self.max_memory:
right = mid - gap
else:
@ -153,14 +134,13 @@ class SelectChunk(object):
def _get_compute_node_num(self, start, end):
count = 0
for i in self.trace_indice.node_list[start : end + 1]:
for i in self.trace_indice.node_list[start:end + 1]:
if not is_non_compute_node(i):
count += 1
return count
def _select_min_memory_chunk_region(
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
):
def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
mem_peak):
# remove illegal regions
illegal_regions = []
for i in possible_chunk_regions:
@ -173,37 +153,31 @@ class SelectChunk(object):
if len(possible_chunk_regions) == 0:
return None
# get max possible chunk region
max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
max([i["region"][1] for i in possible_chunk_regions]))
# get mem for chunk region
regions_dict = []
regions_dict_list = []
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
self.trace_indice.node_list, cur_region
)
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
cur_node_list, cur_chunk_infos
)[0]
cur_chunk_region_peak = cur_mem_peak[
max_chunk_region[0] : max_chunk_region[1] + 1
]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
regions_dict.append(
{
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(
region["region"][0], region["region"][1]
),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
}
)
regions_dict_list.append({
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
})
# select the min mem
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict]
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
best_region = regions_dict[best_region_idx]["chunk_info"]
best_region = regions_dict_list[best_region_idx]["chunk_info"]
if best_region is not None:
best_region["chunk_size"] = 1
return best_region
@ -216,9 +190,7 @@ class SelectChunk(object):
return False
for i in chunk_infos:
region = i["region"]
if not (
(chunk_region_start > region[1] and chunk_region_end > region[1])
or (chunk_region_start < region[0] and chunk_region_end < region[0])
):
if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or
(chunk_region_start < region[0] and chunk_region_end < region[0])):
return False
return True
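A hedged restatement of the legality check above: a candidate chunk region is kept only if it lies entirely before or entirely after every already-selected region, i.e. regions never overlap. The helper name below is hypothetical:

def _regions_disjoint(start, end, existing_regions):
    # existing_regions: iterable of (r_start, r_end) pairs
    for r_start, r_end in existing_regions:
        entirely_after = start > r_end and end > r_end
        entirely_before = start < r_start and end < r_start
        if not (entirely_after or entirely_before):
            return False
    return True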

View File

@ -8,9 +8,9 @@ from .utils import (
find_chunk_compute_input_and_output_nodes,
find_idx_by_name,
flat_list,
get_node_name,
get_node_shape,
is_non_compute_node,
is_non_compute_node_except_placeholder,
)
@ -79,43 +79,6 @@ class TraceFlow(object):
return node_dim
return None
def check_index_duplicate(self, chunk_infos, return_dim=False):
input_dim_after_node = {}
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k])
if inherit_dim:
input_dim_after_node[k] = inherit_dim
for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]:
if is_non_compute_node_except_placeholder(node):
continue
count = 0
duplicate_dims = []
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
for node_dim in range(len(get_node_shape(node))):
duplicate_dim = []
duplicate_flag = False
dim_source = node_trace_source[node_dim]
for k, v in dim_source.items():
if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
if k in input_dim_after_node and input_dim_after_node[k] in v:
duplicate_flag = True
duplicate_dim.append((k, v))
duplicate_dims.append(duplicate_dim)
if duplicate_flag:
count += 1
if count > 1:
if return_dim:
return False, duplicate_dims
else:
return False
if return_dim:
return True, None
else:
return True
def _assgin_single_node_flow(
self,
arg_node: Node,
@ -225,9 +188,12 @@ class TraceFlow(object):
if flow_flag == False:
return None
if len(arg_list) == 2:
if any(i in cur_node.name for i in ["add", "mul", "truediv"]):
if len(arg_list) >= 2:
# need to mark fix dim
if any(i == get_node_name(cur_node) for i in ["add", "mul", "truediv", "sub", "where"]):
for arg in arg_list:
if get_node_shape(arg) is None:
continue
if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
continue
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
@ -240,9 +206,8 @@ class TraceFlow(object):
return None
if i not in arg_fix_dim:
arg_fix_dim.append(i)
elif "einsum" in cur_node.name:
pass
elif "matmul" in cur_node.name:
elif any(i == get_node_name(cur_node)
for i in ["einsum", "matmul", "view", "to", "getitem", "tensor", "type"]):
pass
else:
raise NotImplementedError()
@ -426,7 +391,7 @@ class TraceFlow(object):
reshape_size = {}
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
if any(i in node.name for i in ["reshape", "view"]):
if any(i == get_node_name(node) for i in ["reshape", "view"]):
reshape_args = flat_list(node.args[1:])
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
new_shape = ""
@ -443,3 +408,62 @@ class TraceFlow(object):
reshape_size[node.name] = [origin_shape, new_shape]
chunk_info["reshape_size"] = reshape_size
return chunk_info
def find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
"""
Find chunk info for a region.
We are given the region start and end, and need to find all chunk infos for it.
We first loop over every dim of the start node and the end node to see whether we can find a dim pair
that is linked in a flow and not computed.
If found, we then search the flow in the whole region to find all chunk infos.
Args:
input_trace (List): node's input trace in region
output_trace (List): node's output trace in region
start_idx (int): region start node index
end_idx (int): region end node index
Returns:
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
if len(start_traces) > 1: # TODO need to be removed
return []
end_trace = output_trace[end_idx]
end_node = self.trace_indice.node_list[end_idx]
chunk_infos = []
for end_dim, _ in enumerate(end_trace["indice"]):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
if not self._check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, end_idx):
continue
# flow search
chunk_info = self.flow_search(start_idx, start_dim, end_idx, end_dim)
if chunk_info is None:
continue
chunk_infos.append(chunk_info)
return chunk_infos
def _check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
end_idx: int) -> bool:
"""
check if region start and end are legal
"""
# dim cannot be None
if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
return False
# dim size cannot be 1
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
return False
# must have users
if len(end_node.users) == 0:
return False
# check index source align
if not self.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
return False
# check index compute
if not self.check_index_compute(start_idx, end_dim, end_node, end_idx):
return False
return True
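With the region search relocated into TraceFlow, a caller such as SearchChunk._search_possible_chunk_regions (earlier in this diff) uses it roughly as below; the indices are made up and the dict keys shown are the ones referenced elsewhere in this file:

# trace_flow is assumed to be a TraceFlow built over the traced node list
chunk_infos = trace_flow.find_chunk_info(input_trace, output_trace, start_idx=10, end_idx=42)
for info in chunk_infos:
    print(info["region"], info["inputs"], info["outputs"])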

View File

@ -3,7 +3,14 @@ from typing import Dict, List, Tuple
from torch.fx.node import Node
from .utils import find_first_tensor_arg, find_idx_by_name, flat_list, get_node_shape
from .utils import (
find_first_tensor_arg,
find_idx_by_name,
flat_list,
get_module_node_name,
get_node_name,
get_node_shape,
)
class TraceIndice(object):
@ -36,7 +43,7 @@ class TraceIndice(object):
self.trace_range = []
self.active_node_list = []
def _init_indice_trace_list(self):
def _init_indice_trace_list(self) -> List:
indice_trace_list = []
for n in self.node_list:
if get_node_shape(n) != None:
@ -54,7 +61,7 @@ class TraceIndice(object):
self.trace_range = trace_range
self.active_node_list = active_node_list
def _add_indice(self):
def _add_indice(self) -> int:
"""
Update the count and return it, to record the index number.
@ -64,39 +71,30 @@ class TraceIndice(object):
self.indice_count += 1
return self.indice_count
def _del_dim(self, idx, dim_idx):
def _del_dim(self, idx: int, dim_idx: int) -> None:
"""
delete a dim for indice, compute and source
"""
self.indice_trace_list[idx]["indice"].pop(dim_idx)
self.indice_trace_list[idx]["compute"].pop(dim_idx)
self.indice_trace_list[idx]["source"].pop(dim_idx)
def _add_dim(self, node_idx, dim_idx):
def _add_dim(self, node_idx: int, dim_idx: int) -> None:
"""
add a dim for indice, compute and source
"""
self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice())
self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
def _transform_indice(self, node, node_dim):
node_idx = self._find_indice_trace_from_node(node)
dims = list(range(len(node_idx)))
return dims[node_dim]
def _inherit_indice(self, node_from, node_from_dim, node_to, node_to_dim):
node_from_dim = self._transform_indice(node_from, node_from_dim)
node_to_dim = self._transform_indice(node_to, node_to_dim)
node_from_trace = self._find_trace_from_node(node_from)
node_to_trace = self._find_trace_from_node(node_to)
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim])
self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)
def _inherit_all_computation(self, node_from, node_to):
node_from_compute = self._find_compute_trace_from_node(node_from)
node_to_compute = self._find_compute_trace_from_node(node_to)
assert len(node_from_compute) == len(node_to_compute)
for i in range(len(node_from_compute)):
self._add_source(node_from, i, node_to, i)
node_to_compute[i] = copy.deepcopy(node_from_compute[i])
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
def _add_source(
self,
node_from: Node,
node_from_dim: int,
node_to: Node,
node_to_dim: int,
init=False,
) -> None:
node_from_dim = self._transform_indice(node_from, node_from_dim)
node_from_trace_source = self._find_source_trace_from_node(node_from)
node_to_dim = self._transform_indice(node_to, node_to_dim)
@ -119,7 +117,50 @@ class TraceIndice(object):
if d not in node_to_trace_source[node_to_dim][node_idx]:
node_to_trace_source[node_to_dim][node_idx].append(d)
def _mark_computation_from_node(self, node_from, node_to, exclude=None):
def _transform_indice(self, node: Node, node_dim: int) -> int:
node_idx = self._find_indice_trace_from_node(node)
dims = list(range(len(node_idx)))
return dims[node_dim]
def _inherit_indice(
self,
node_from: Node,
node_from_dim: int,
node_to: Node,
node_to_dim: int,
init: bool = True,
) -> None:
"""
node_to's node_to_dim inherits node_from's node_from_dim's indice, compute and source
"""
node_from_dim = self._transform_indice(node_from, node_from_dim)
node_to_dim = self._transform_indice(node_to, node_to_dim)
node_from_trace = self._find_trace_from_node(node_from)
node_to_trace = self._find_trace_from_node(node_to)
if init:
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim])
else:
for j in node_from_trace["compute"][node_from_dim]:
if j not in node_to_trace["compute"][node_to_dim]:
node_to_trace["compute"][node_to_dim].append(j)
self._add_source(node_from, node_from_dim, node_to, node_to_dim, init)
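A short note on the init flag introduced above, with toy calls (tracer is assumed to be a TraceIndice instance and the nodes are hypothetical): init=True overwrites the destination dim's indice and compute records with the source dim's, while init=False only merges the source dim's compute entries into the destination.

# overwrite: node_to's dim 0 takes node_from's dim 0 indice and compute
tracer._inherit_indice(node_from, 0, node_to, 0, init=True)
# merge only: node_to's last dim keeps its indice; extra compute entries are appended
tracer._inherit_indice(node_from, -1, node_to, -1, init=False)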
def _inherit_all_indice(self, node_from: Node, node_to: Node) -> None:
"""
inherit all dims with init
"""
# find indice just for assert length
node_from_indice = self._find_indice_trace_from_node(node_from)
node_to_indice = self._find_indice_trace_from_node(node_to)
assert len(node_from_indice) == len(node_to_indice)
for i in range(len(node_from_indice)):
self._inherit_indice(node_from, i, node_to, i, init=True)
def _inherit_more_indice_from_node(self, node_from: Node, node_to: Node, exclude: List = None) -> None:
"""
inherit indice from node without init
"""
if exclude == None:
exclude = []
else:
@ -130,12 +171,9 @@ class TraceIndice(object):
for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1):
if self._transform_indice(node_to, i) in exclude:
continue
self._add_source(node_from, i, node_to, i)
for j in node_from_compute[i]:
if j not in node_to_compute[i]:
node_to_compute[i].append(j)
self._inherit_indice(node_from, i, node_to, i, init=False)
def _mark_computation(self, node, idx, dim):
def _mark_computation(self, node: Node, idx: int, dim: int) -> None:
"""
Mark some dims of node as computed.
@ -152,7 +190,7 @@ class TraceIndice(object):
if idx not in self.indice_trace_list[idx]["compute"][cur_dim]:
self.indice_trace_list[idx]["compute"][cur_dim].append(idx)
def _find_trace_from_node(self, node):
def _find_trace_from_node(self, node: Node) -> Dict:
"""
Find node idx and compute trace by the node.
@ -166,7 +204,7 @@ class TraceIndice(object):
node_dict = self.indice_trace_list[node_idx]
return node_dict
def _find_source_trace_from_node(self, node):
def _find_source_trace_from_node(self, node: Node) -> List:
"""
Find node source trace by the node.
@ -180,7 +218,7 @@ class TraceIndice(object):
node_dict = self.indice_trace_list[node_idx]
return node_dict["source"]
def _find_indice_trace_from_node(self, node):
def _find_indice_trace_from_node(self, node) -> List:
"""
Find node idx trace by the node.
@ -192,7 +230,7 @@ class TraceIndice(object):
node_idx = find_idx_by_name(node.name, self.node_list)
return self.indice_trace_list[node_idx]["indice"]
def _find_compute_trace_from_node(self, node):
def _find_compute_trace_from_node(self, node: Node) -> List:
"""
Find node compute trace by the node.
@ -204,7 +242,7 @@ class TraceIndice(object):
node_idx = find_idx_by_name(node.name, self.node_list)
return self.indice_trace_list[node_idx]["compute"]
def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None):
def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None:
"""
Assign node's trace as its input node.
@ -214,15 +252,9 @@ class TraceIndice(object):
"""
if input_node == None:
input_node = find_first_tensor_arg(node)
input_node_idx = find_idx_by_name(input_node.name, self.node_list)
input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"]
self._inherit_all_indice(input_node, node)
new_idx_trace = copy.deepcopy(input_node_idx_trace)
self.indice_trace_list[node_idx]["indice"] = new_idx_trace
self._inherit_all_computation(input_node, node)
def _assign_all_indice(self, node: Node, node_idx: int):
def _assign_all_indice(self, node: Node, node_idx: int) -> None:
"""
Add new indice for all node's dims.
@ -238,7 +270,7 @@ class TraceIndice(object):
new_trace.append(self._add_indice())
self.indice_trace_list[node_idx]["indice"] = new_trace
def _assign_transpose_indice(self, node: Node, node_idx: int):
def _assign_transpose_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for transpose op.
1. swap input's dim according to transpose args
@ -255,7 +287,7 @@ class TraceIndice(object):
self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
def _assign_permute_indice(self, node: Node, node_idx: int):
def _assign_permute_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for permute op.
1. swap input's dim according to permute args
@ -272,7 +304,7 @@ class TraceIndice(object):
for idx, d in enumerate(permute_dim):
self._inherit_indice(input_node, d, node, idx)
def _assign_linear_indice(self, node: Node, node_idx: int):
def _assign_linear_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for linear op.
1. copy trace from input node and change last indice according to weight
@ -293,7 +325,23 @@ class TraceIndice(object):
self._mark_computation(node, node_idx, [-1])
def _assign_matmul_indice(self, node: Node, node_idx: int):
def _assign_addmm_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for addmm op.
Args:
node (node)
node_idx (int)
"""
bias, input_node, weight = node.args
self._assign_indice_as_input(node, node_idx, input_node)
self._inherit_indice(weight, 1, node, -1)
self._inherit_indice(bias, -1, node, -1)
self._mark_computation(node, node_idx, [-1])
def _assign_matmul_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for matmul op.
1. copy trace from matmul_left and change last indice according to matmul_right. (assert they have same length)
@ -310,7 +358,7 @@ class TraceIndice(object):
self._assign_indice_as_input(node, node_idx, matmul_left)
self._inherit_indice(matmul_right, -1, node, -1)
self._mark_computation_from_node(matmul_right, node, [-1, -2])
self._inherit_more_indice_from_node(matmul_right, node, [-1, -2])
self._mark_computation(node, node_idx, [-1])
def _assign_layernorm_indice(self, node, idx):
@ -341,14 +389,13 @@ class TraceIndice(object):
for node_in in node.args:
if type(node_in) == type(node):
nodes_in.append(node_in)
self._mark_computation_from_node(node_in, node)
assert len(nodes_in) <= 2
self._inherit_more_indice_from_node(node_in, node)
def _assgin_no_change_indice(self, node, idx):
self._assign_indice_as_input(node, idx)
for node_in in node.args:
if type(node_in) == type(node):
self._mark_computation_from_node(node_in, node)
self._inherit_more_indice_from_node(node_in, node)
def _assign_einsum_indice(self, node, idx):
"""
@ -365,7 +412,7 @@ class TraceIndice(object):
left, right = patterns.split("->")
left = left.split(",")
if '...' in right:
if "..." in right:
replace_list = "!@#$%^&*"
target_len = len(get_node_shape(node))
add_len = target_len - len(right) + 3
@ -399,7 +446,22 @@ class TraceIndice(object):
self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [node.kwargs["dim"]])
def _assign_unsqueeze_indice(self, node: Node, node_idx: int):
def _assign_split_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for split op.
Args:
node (node)
node_idx (int)
"""
for _ in range(len(get_node_shape(node.args[0]))):
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx)
dim_idx = node.kwargs["dim"]
self._del_dim(node_idx, dim_idx)
self._add_dim(node_idx, dim_idx)
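An illustrative note on why split needs its own handler here and a matching branch in _assign_getitem_indice further below: torch.fx traces tensor.split into one node plus getitem nodes that pick out each piece, and on both the split node and those getitem nodes the split dim is given a fresh indice. Shapes and names below are made up:

# traced pattern (sketch):
pieces = x.split(16, dim=1)    # one fx node; dims re-derived from x, split dim reset
a = pieces[0]                  # getitem node: split dim gets a fresh indice
b = pieces[1]                  # getitem node: split dim gets a fresh indice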
def _assign_unsqueeze_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for unsqueeze op.
1. assign new indice for unsqueeze dim
@ -416,18 +478,7 @@ class TraceIndice(object):
dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
self._add_dim(node_idx, dim_idx)
def _assign_dropout_indice(self, node: Node, node_idx: int):
"""
Assign indice for unsqueeze op.
1. assign new indice for unsqueeze dim
Args:
node (node)
node_idx (int)
"""
self._assign_indice_as_input(node, node_idx)
def _assign_ones_like_indice(self, node: Node, node_idx: int):
def _assign_ones_like_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for oneslike op.
1. assign new indice for all dim
@ -438,7 +489,7 @@ class TraceIndice(object):
"""
self._assign_all_indice(node, node_idx)
def _assign_cat_indice(self, node: Node, node_idx: int):
def _assign_cat_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for cat op.
@ -449,12 +500,12 @@ class TraceIndice(object):
nodes_in = flat_list(node.args[0])
self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
for n in nodes_in[1:]:
self._mark_computation_from_node(n, node)
self._inherit_more_indice_from_node(n, node)
cat_dim = node.kwargs["dim"]
self._del_dim(node_idx, cat_dim)
self._add_dim(node_idx, cat_dim)
def _assign_sum_indice(self, node: Node, node_idx: int):
def _assign_sum_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for sum op.
@ -466,11 +517,46 @@ class TraceIndice(object):
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
for n in nodes_in[1:]:
self._mark_computation_from_node(n, node)
self._inherit_more_indice_from_node(n, node)
cat_dim = node.kwargs["dim"]
self._del_dim(node_idx, cat_dim)
def _assign_getitem_indice(self, node: Node, node_idx: int):
def _assign_arange_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for arange op.
Args:
node (node)
node_idx (int)
"""
self._assign_all_indice(node, node_idx)
def _assign_tensor_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for tensor op.
Args:
node (node)
node_idx (int)
"""
if len(get_node_shape(node)) == 0:
return
else:
raise NotImplementedError()
def _assign_embedding_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for embedding op.
Args:
node (node)
node_idx (int)
"""
self._del_dim(node_idx, -1)
self._assign_indice_as_input(node, node_idx)
self._add_dim(node_idx, -1)
def _assign_getitem_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for getitem.
getitem can act like slice sometimes
@ -480,6 +566,19 @@ class TraceIndice(object):
node_idx (int)
"""
node_args = flat_list(node.args[1:])
# deal with split
if get_node_name(node.args[0]) == "split":
self._assign_indice_as_input(node, node_idx)
self._del_dim(node_idx, node.args[0].kwargs["dim"])
self._add_dim(node_idx, node.args[0].kwargs["dim"])
return
# skip non tensor
if get_node_shape(node) is None:
return
# find if slice
flag = False
for node_arg in node_args:
node_arg_str = str(node_arg)
@ -528,7 +627,7 @@ class TraceIndice(object):
else:
raise NotImplementedError()
def _assign_view_reshape_indice(self, node: Node, node_idx: int):
def _assign_view_reshape_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for view and reshape op.
1. get origin shape and target shape by meta info.
@ -536,7 +635,7 @@ class TraceIndice(object):
3. determine changed dim, and assign indice for generated dim.
4. log changed dim and generated dim for restore
5. inherit computation.
6. TODO: look into view list to see whether the view is associated with other,
6. look into view list to see whether the view is associated with others,
if so assign equal dims according to the previous view.
Args:
@ -552,7 +651,7 @@ class TraceIndice(object):
if isinstance(unflated_args[i], int):
target_shape.append(unflated_args[i])
else:
target_shape.append(unflated_args[i].meta["fwd_out"][0])
target_shape.extend(unflated_args[i].meta["fwd_out"])
# compute the value of -1
if -1 in target_shape:
@ -579,17 +678,36 @@ class TraceIndice(object):
dim_from = [dim_equal.index(False)]
dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
self._del_dim(node_idx, -1)
elif len_diff == 0:
# dim equal
dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
dim_from = []
dim_to = []
else:
raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented")
# get new indice
origin_trace = self._find_indice_trace_from_node(origin_node)
self._assign_indice_as_input(node, node_idx, origin_node)
idx_from = [origin_trace[i] for i in dim_from]
dim_from.reverse()
for i in dim_from:
self._del_dim(node_idx, i)
for i in dim_to:
self._add_dim(node_idx, i)
dim_from.reverse()
# search view list
for view_node, view_dict in self.indice_view_list.items():
if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
and view_dict["dim_from"] == dim_to):
# inherit indice from current node
for dim_to_i in dim_to:
for dim_from_i in dim_from:
self._inherit_indice(origin_node, dim_from_i, node, dim_to_i, init=False)
# inherit indice from input node of last view
for dim_to_i in dim_to:
self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
# inherit computation
compute_log = self._find_compute_trace_from_node(origin_node)
@ -630,7 +748,7 @@ class TraceIndice(object):
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
if dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes:
if (dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes):
dim_compute.pop(i)
continue
# clear source
@ -639,59 +757,82 @@ class TraceIndice(object):
if k < trace_range[0] and k not in active_nodes:
dim_source.pop(k)
def trace_indice(self):
def trace_indice(self) -> None:
for idx, node in enumerate(self.node_list):
node_name = get_node_name(node)
if node.op == "placeholder":
self._assign_all_indice(node, idx)
elif node.op == "call_method":
if "transpose" in node.name:
if "transpose" == node_name:
self._assign_transpose_indice(node, idx)
elif "permute" in node.name:
elif "permute" == node_name:
self._assign_permute_indice(node, idx)
elif "view" in node.name or "reshape" in node.name:
elif "view" == node_name or "reshape" == node_name:
self._assign_view_reshape_indice(node, idx)
elif "unsqueeze" in node.name:
elif "unsqueeze" == node_name:
self._assign_unsqueeze_indice(node, idx)
elif any(i in node.name for i in ["to", "contiguous", "clone"]):
elif "split" == node_name:
self._assign_split_indice(node, idx)
elif any(i == node_name for i in ["to", "contiguous", "clone", "type"]):
self._assgin_no_change_indice(node, idx)
elif "new_ones" in node.name:
elif "new_ones" == node_name:
self._assign_ones_like_indice(node, idx)
else:
raise NotImplementedError(node.name, "method not implemented yet!")
elif node.op == "call_function":
if "linear" in node.name:
self._assign_linear_indice(node, idx)
elif "cat" in node.name:
self._assign_cat_indice(node, idx)
elif "matmul" in node.name:
self._assign_matmul_indice(node, idx)
elif "softmax" in node.name:
self._assign_softmax_indice(node, idx)
elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu", "sub", "truediv"]):
self._assign_elementwise_indice(node, idx)
elif "ones_like" in node.name:
self._assign_ones_like_indice(node, idx)
elif "dropout" in node.name:
self._assign_dropout_indice(node, idx)
elif "einsum" in node.name:
self._assign_einsum_indice(node, idx)
elif "sum" in node.name:
self._assign_sum_indice(node, idx)
elif "layer_norm" in node.name:
self._assign_layernorm_indice(node, idx)
elif "getitem" in node.name:
self._assign_getitem_indice(node, idx)
elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
elif any(i == node_name for i in ["size"]):
continue
else:
raise NotImplementedError(node.name, "function not implemented yet!")
elif node.op == "call_module":
if any(n in node.name for n in ["layernorm", "norm"]):
raise NotImplementedError(node_name, "method not implemented yet!")
elif node.op == "call_function":
if "linear" == node_name:
self._assign_linear_indice(node, idx)
elif "cat" == node_name:
self._assign_cat_indice(node, idx)
elif "matmul" == node_name:
self._assign_matmul_indice(node, idx)
elif "softmax" == node_name:
self._assign_softmax_indice(node, idx)
elif any(n == node_name for n in [
"mul",
"add",
"sigmoid",
"relu",
"sub",
"truediv",
"pow",
"dropout",
"where",
"tanh",
]):
self._assign_elementwise_indice(node, idx)
elif "ones_like" == node_name:
self._assign_ones_like_indice(node, idx)
elif "einsum" == node_name:
self._assign_einsum_indice(node, idx)
elif "sum" == node_name:
self._assign_sum_indice(node, idx)
elif "layer_norm" == node_name:
self._assign_layernorm_indice(node, idx)
elif any(n in node.name for n in ["sigmoid", "dropout", "relu"]):
elif "getitem" == node_name:
self._assign_getitem_indice(node, idx)
elif "addmm" == node_name:
self._assign_addmm_indice(node, idx)
elif "arange" == node_name:
self._assign_arange_indice(node, idx)
elif "tensor" == node_name:
self._assign_arange_indice(node, idx)
elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]):
continue
else:
raise NotImplementedError(node_name, "function not implemented yet!")
elif node.op == "call_module":
node_name = get_module_node_name(node)
if "layernorm" == node_name:
self._assign_layernorm_indice(node, idx)
elif "embedding" == node_name:
self._assign_embedding_indice(node, idx)
elif any(n == node_name for n in ["sigmoid", "dropout", "relu"]):
self._assign_elementwise_indice(node, idx)
else:
raise NotImplementedError(node.name, "module not implemented yet!")
raise NotImplementedError(node_name, "module not implemented yet!")
elif node.op == "get_attr":
self._assign_all_indice(node, idx) # get param
elif node.op == "output":

View File

@ -1,13 +1,15 @@
from typing import Any, Callable, Dict, Iterable, List, Tuple
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
from torch.fx.node import Node
from colossalai.logging import get_dist_logger
NON_COMPUTE_OP = ["placeholder", "get_attr", "output"]
NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "size"]
logger = get_dist_logger()
def get_logger():
def get_logger() -> Any:
return logger
@ -37,7 +39,7 @@ def find_first_tensor_arg(node: Node) -> Node:
def is_non_compute_node(node: Node) -> bool:
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME):
return True
if "getitem" in node.name:
node_args = flat_list(node.args[1:])
@ -64,33 +66,33 @@ def is_non_memory_node(node: Node) -> bool:
return is_non_compute_node(node)
def is_non_compute_node_except_placeholder(node):
def is_non_compute_node_except_placeholder(node: Node) -> bool:
if "placeholder" in node.op:
return False
return is_non_compute_node(node)
def is_non_compute_node_except_placeholder_output(node):
def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
if "output" in node.op:
return False
return is_non_compute_node_except_placeholder(node)
def find_idx_by_name(name, nodes_list):
def find_idx_by_name(name: str, nodes_list: List) -> int:
for idx, node in enumerate(nodes_list):
if node.name == name:
return idx
raise RuntimeError("name %s not found in node list" % name)
def delete_free_var_from_last_use(user_to_last_uses):
def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None:
for key, value in user_to_last_uses.items():
for n in value:
if n.op == "placeholder":
user_to_last_uses[key].remove(n)
def find_chunk_all_input_nodes(nodes: List[Node]):
def find_chunk_all_input_nodes(nodes: List[Node]) -> List:
"""
Find non-compute input and output node names.
input nodes are nodes used in the list
@ -104,7 +106,7 @@ def find_chunk_all_input_nodes(nodes: List[Node]):
return input_nodes
def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Tuple[List, List]:
"""
Find non-compute input and output node names.
input nodes are nodes used in the list
@ -130,3 +132,33 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
output_nodes.append(node)
return input_nodes, output_nodes
def get_module_node_name(node: Node) -> str:
"""
get module class name
"""
node_targets = node.target.split(".")
module = node.graph.owning_module
for i in node_targets:
module = getattr(module, i)
module_name = str(module.__class__).split(".")[-1][:-2]
module_name = module_name.lower()
return module_name
def get_node_name(node: Node) -> str:
"""
get node name
"""
node_name = node.name
if "_" in node_name:
for i in range(len(node_name) - 1, -1, -1):
if node_name[i] == "_":
node_name = node_name[:i]
break
elif node_name[i] in ["1", "2", "3", "4", "5", "6", "7", "8", "9", "0"]:
continue
else:
break
return node_name
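A hedged usage note for the two helpers above (example names are illustrative, not from a real trace): get_node_name strips the numeric suffix torch.fx appends to repeated ops, and get_module_node_name resolves a call_module target to its lowercased class name.

# get_node_name: "add_3" -> "add", "layer_norm_1" -> "layer_norm", "matmul" -> "matmul"
# get_module_node_name: a target resolving to torch.nn.LayerNorm -> "layernorm",
#                       to torch.nn.Dropout -> "dropout"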

View File

@ -1,94 +0,0 @@
import time
import torch
import torch.fx
from simple_evoformer import base_evoformer, openfold_evoformer
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
torch.cuda.reset_peak_memory_stats()
now_mem = torch.cuda.memory_allocated() / 1024**2
loop = 3
with torch.no_grad():
for _ in range(loop // 2 + 1):
if chunk_size:
model(node, pair, chunk_size)
else:
model(node, pair)
torch.cuda.synchronize()
time1 = time.time()
for _ in range(loop):
if chunk_size:
model(node, pair, chunk_size)
else:
model(node, pair)
torch.cuda.synchronize()
time2 = time.time()
new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
print("%s: time %.4fs, mem %dMB" % (title, (time2 - time1) / loop, new_max_mem - now_mem))
def _build_autochunk(model, max_memory, node, pair):
# trace the module and replace codegen
graph = ColoTracer().trace(
model,
meta_args={
"node": node.to(torch.device("meta")),
"pair": pair.to(torch.device("meta")),
},
)
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
interp = MetaInfoProp(gm_prop)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
# now run it twice to get meta info in graph module, not necessary
gm = torch.fx.GraphModule(model, graph)
interp = MetaInfoProp(gm)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
# set code_gen
codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph)
gm.recompile()
# print
# code = graph.python_code("self").src
# print(code)
return gm
def benchmark_evoformer():
# init data and model
msa_len = 128
pair_len = 256
node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
model = base_evoformer().cuda()
# build autochunk model
# max_memory = 1000 # MB, fit memory mode
max_memory = None # min memory mode
autochunk = _build_autochunk(base_evoformer().cuda(), max_memory, node, pair)
# build openfold
chunk_size = 64
openfold = openfold_evoformer().cuda()
# benchmark
_benchmark_evoformer(model, node, pair, "base")
_benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
_benchmark_evoformer(autochunk, node, pair, "autochunk")
if __name__ == "__main__":
benchmark_evoformer()

View File

@ -0,0 +1,122 @@
from typing import Any, Dict, List
import torch
import torch.fx
import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.autochunk.utils import flat_list
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def assert_codegen_run(
model: Any,
meta_args: List,
concrete_args: List = None,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False,
print_code: bool = False,
) -> List[Dict]:
if concrete_args is None:
concrete_args = []
# trace the meta graph and setup codegen
meta_graph = symbolic_trace(
model,
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
concrete_args={k: v for k, v in concrete_args},
)
interp = MetaInfoProp(meta_graph)
meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
interp.propagate(*meta_tensors)
codegen = AutoChunkCodeGen(
meta_graph,
max_memory=max_memory,
print_mem=print_mem,
print_progress=print_progress,
)
chunks = codegen.chunk_infos
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph = ColoTracer().trace(
model,
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
concrete_args={k: v for k, v in concrete_args},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
# assert chunk in code
code = graph.python_code("self").src
if print_code:
print(code)
assert "chunk_result = None; chunk_size = None;" in code
# assert result
inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
model.cuda()
with torch.no_grad():
out_gm = gm(*inputs)
out_model = model(*inputs)
out_gm = flat_list(out_gm)
out_model = flat_list(out_model)
for out_gm_i, out_model_i in zip(out_gm, out_model):
assert torch.allclose(out_gm_i, out_model_i,
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(out_gm_i - out_model_i))
return chunks
def run_test(
rank: int,
data_args: tuple,
max_memory: int,
get_model: Any,
get_data: Any,
print_code: bool,
print_mem: bool,
print_progress: bool,
get_chunk_target: Any = None,
) -> None:
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
model = get_model()
meta_args, concrete_args = get_data(*data_args)
chunks = assert_codegen_run(
model,
meta_args=meta_args,
concrete_args=concrete_args,
max_memory=max_memory,
print_code=print_code,
print_mem=print_mem,
print_progress=print_progress,
)
if get_chunk_target is not None:
chunk_found = [i["region"] for i in chunks]
chunk_target = get_chunk_target()[max_memory]
assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % (
str(chunk_found),
str(chunk_target),
)

View File

@ -0,0 +1,95 @@
from functools import partial
from typing import Dict, List, Tuple
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import EvoformerBlock
HAS_REPO = True
except:
HAS_REPO = False
from test_alphafold_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
def get_model():
model = EvoformerBlock(
c_m=256,
c_z=128,
c_hidden_msa_att=32,
c_hidden_opm=32,
c_hidden_mul=128,
c_hidden_pair_att=32,
no_heads_msa=8,
no_heads_pair=4,
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.15,
inf=1e4,
eps=1e-4,
is_multimer=False,
).eval().cuda()
return model
def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
node = torch.randn(1, msa_len, pair_len, 256).cuda()
node_mask = torch.randn(1, msa_len, pair_len).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
meta_args = [
("m", node),
("z", pair),
("msa_mask", node_mask),
("pair_mask", pair_mask),
]
concrete_args = [("chunk_size", None), ("_mask_trans", True)]
return meta_args, concrete_args
def get_chunk_target() -> Dict:
return {
None: [(118, 123), (219, 237), (264, 289), (302, 309), (97, 104), (144, 152), (185, 193), (241, 242), (21, 46)],
20: [(118, 123), (230, 237), (275, 282), (305, 306), (100, 101), (32, 39), (73, 79)],
24: [(118, 123)],
}
@pytest.mark.skipif(
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
def test_evoformer_block(data_args, max_memory):
run_func = partial(
run_test,
data_args=data_args,
max_memory=max_memory,
get_model=get_model,
get_data=get_data,
get_chunk_target=get_chunk_target,
print_code=False,
print_mem=False,
print_progress=False,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
run_test(
rank=0,
data_args=(32, 64),
max_memory=20,
get_model=get_model,
get_data=get_data,
print_code=False,
print_mem=False,
print_progress=False,
)

View File

@ -0,0 +1,90 @@
from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import EvoformerStack
HAS_REPO = True
except:
HAS_REPO = False
from test_alphafold_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
def get_model():
model = EvoformerStack(
c_m=256,
c_z=128,
c_hidden_msa_att=32,
c_hidden_opm=32,
c_hidden_mul=128,
c_hidden_pair_att=32,
c_s=384,
no_heads_msa=8,
no_heads_pair=4,
no_blocks=2, # 48
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.25,
blocks_per_ckpt=None,
inf=1000000000.0,
eps=1e-08,
clear_cache_between_blocks=False,
is_multimer=False,
).eval().cuda()
return model
def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
node = torch.randn(1, msa_len, pair_len, 256).cuda()
node_mask = torch.randn(1, msa_len, pair_len).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
meta_args = [
("m", node),
("z", pair),
("msa_mask", node_mask),
("pair_mask", pair_mask),
]
concrete_args = [("chunk_size", None), ("_mask_trans", True)]
return meta_args, concrete_args
@pytest.mark.skipif(
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
def test_evoformer_stack(data_args, max_memory):
run_func = partial(
run_test,
data_args=data_args,
max_memory=max_memory,
get_model=get_model,
get_data=get_data,
print_code=False,
print_mem=False,
print_progress=False,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
run_test(
rank=0,
data_args=(32, 64),
max_memory=20,
get_model=get_model,
get_data=get_data,
print_code=False,
print_mem=False,
print_progress=False,
)

View File

@ -0,0 +1,96 @@
from functools import partial
from typing import Dict, List, Tuple
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import ExtraMSABlock
HAS_REPO = True
except:
HAS_REPO = False
from test_alphafold_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
def get_model():
model = ExtraMSABlock(
c_m=256,
c_z=128,
c_hidden_msa_att=32,
c_hidden_opm=32,
c_hidden_mul=128,
c_hidden_pair_att=32,
no_heads_msa=8,
no_heads_pair=4,
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.15,
inf=1e4,
eps=1e-4,
ckpt=False,
is_multimer=False,
).eval().cuda()
return model
def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
node = torch.randn(1, msa_len, pair_len, 256).cuda()
node_mask = torch.randn(1, msa_len, pair_len).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
meta_args = [
("m", node),
("z", pair),
("msa_mask", node_mask),
("pair_mask", pair_mask),
]
concrete_args = [("chunk_size", None), ("_chunk_logits", 1024)]
return meta_args, concrete_args
def get_chunk_target() -> Dict:
return {
None: [(126, 131), (227, 245), (272, 297), (310, 317), (105, 112), (152, 160), (193, 201), (249, 250),
(33, 46)],
20: [(126, 131), (238, 245), (283, 290), (313, 314), (108, 109), (35, 46)],
24: [(126, 131)],
}
@pytest.mark.skipif(
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
def test_extramsa_block(data_args, max_memory):
run_func = partial(
run_test,
data_args=data_args,
max_memory=max_memory,
get_model=get_model,
get_data=get_data,
print_code=False,
print_mem=False,
print_progress=False,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
run_test(
rank=0,
data_args=(32, 64),
max_memory=20,
get_model=get_model,
get_data=get_data,
get_chunk_target=get_chunk_target,
print_code=False,
print_mem=False,
print_progress=False,
)

View File

@ -0,0 +1,120 @@
from typing import Any, Dict, List
import torch
import torch.fx
import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def assert_codegen_run(
model: Any,
meta_args: List,
concrete_args: List = None,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False,
print_code: bool = False,
) -> List[Dict]:
if concrete_args is None:
concrete_args = []
model = model()
# trace the meta graph and setup codegen
meta_graph = symbolic_trace(
model,
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
concrete_args={k: v for k, v in concrete_args},
)
interp = MetaInfoProp(meta_graph)
meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
interp.propagate(*meta_tensors)
codegen = AutoChunkCodeGen(
meta_graph,
max_memory=max_memory,
print_mem=print_mem,
print_progress=print_progress,
)
chunks = codegen.chunk_infos
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph = ColoTracer().trace(
model.cuda(),
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
concrete_args={k: v for k, v in concrete_args},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
# assert chunk in code
code = graph.python_code("self").src
if print_code:
print(code)
assert "chunk_result = None; chunk_size = None;" in code
# assert result
inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
model.cuda().eval()
gm.eval()
with torch.no_grad():
out_gm = gm(*inputs)
out_model = model(*inputs)
assert torch.allclose(out_gm["sample"], out_model["sample"],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(out_gm["sample"] - out_model["sample"]))
return chunks
def run_test(
rank: int,
model: Any,
data: tuple,
max_memory: int,
print_code: bool,
print_mem: bool,
print_progress: bool,
get_chunk_target: Any = None,
) -> None:
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
meta_args, concrete_args = data
chunks = assert_codegen_run(
model,
meta_args=meta_args,
concrete_args=concrete_args,
max_memory=max_memory,
print_code=print_code,
print_mem=print_mem,
print_progress=print_progress,
)
if get_chunk_target is not None:
chunk_found = [i["region"] for i in chunks]
chunk_target = get_chunk_target()[max_memory]
assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % (
str(chunk_found),
str(chunk_target),
)
gpc.destroy()

View File

@ -0,0 +1,70 @@
from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.multiprocessing as mp
try:
from diffusers import UNet2DModel
MODELS = [UNet2DModel]
HAS_REPO = True
except:
MODELS = []
HAS_REPO = False
from test_diffuser_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
BATCH_SIZE = 2
SEQ_LENGTH = 5
HEIGHT = 224
WIDTH = 224
IN_CHANNELS = 3
LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7)
def get_data(shape: tuple) -> Tuple[List, List]:
sample = torch.randn(shape)
meta_args = [
("sample", sample),
]
concrete_args = [("timestep", 50)]
return meta_args, concrete_args
@pytest.mark.skipif(
True,
reason="not implemented",
)
@pytest.mark.skipif(
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
@pytest.mark.parametrize("max_memory", [64])
def test_evoformer_block(model, shape, max_memory):
run_func = partial(
run_test,
max_memory=max_memory,
model=model,
data=get_data(shape),
print_code=False,
print_mem=False,
print_progress=False,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
run_test(
rank=0,
data=get_data(LATENTS_SHAPE),
max_memory=64,
model=UNet2DModel,
print_code=False,
print_mem=False,
print_progress=False,
)

View File

@ -1,163 +0,0 @@
from functools import partial
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import EvoformerBlock
HAS_REPO = True
except:
HAS_REPO = False
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if CODEGEN_AVAILABLE and is_compatible_with_meta():
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
# for memory test
# model = model.cuda()
# torch.cuda.reset_peak_memory_stats()
# now_mem = torch.cuda.memory_allocated() / 1024**2
# with torch.no_grad():
# node1 = node.clone()
# pair1 = pair.clone()
# node_mask1 = node_mask.clone()
# pair_mask1 = pair_mask.clone()
# gm(node1, pair1, node_mask1, pair_mask1)
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
# test forward
model = model.cuda()
with torch.no_grad():
non_fx_out = model(node, pair, node_mask, pair_mask)
fx_out = gm(node, pair, node_mask, pair_mask)
assert torch.allclose(non_fx_out[0], fx_out[0],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[0] - fx_out[0]))
assert torch.allclose(non_fx_out[1], fx_out[1],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[1] - fx_out[1]))
def _build_openfold():
model = EvoformerBlock(
c_m=256,
c_z=128,
c_hidden_msa_att=32,
c_hidden_opm=32,
c_hidden_mul=128,
c_hidden_pair_att=32,
no_heads_msa=8,
no_heads_pair=4,
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.15,
inf=1e4,
eps=1e-4,
is_multimer=False,
).eval().cuda()
return model
def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
model = _build_openfold()
node = torch.randn(1, msa_len, pair_len, 256).cuda()
node_mask = torch.randn(1, msa_len, pair_len).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
# trace the meta graph and setup codegen
meta_graph = symbolic_trace(
model,
meta_args={
"m": node.to(torch.device("meta")),
"z": pair.to(torch.device("meta")),
"msa_mask": node_mask.to(torch.device("meta")),
"pair_mask": pair_mask.to(torch.device("meta")),
},
concrete_args={
"chunk_size": None,
"_mask_trans": True,
},
)
interp = MetaInfoProp(meta_graph)
interp.propagate(
MetaTensor(node, fake_device="cuda:0"),
MetaTensor(pair, fake_device="cuda:0"),
MetaTensor(node_mask, fake_device="cuda:0"),
MetaTensor(pair_mask, fake_device="cuda:0"),
)
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph = ColoTracer().trace(
model,
meta_args={
"m": node.to(torch.device("meta")),
"z": pair.to(torch.device("meta")),
"msa_mask": node_mask.to(torch.device("meta")),
"pair_mask": pair_mask.to(torch.device("meta")),
},
concrete_args={
"chunk_size": None,
"_mask_trans": True,
},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
# assert we have inserted chunk
code = graph.python_code("self").src
# print(code)
assert "chunk_result = None; chunk_size = None;" in code
_test_fwd(model, gm, node, pair, node_mask, pair_mask)
gpc.destroy()
@pytest.mark.skipif(
not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_codegen(msa_len, pair_len, max_memory):
run_func = partial(
_test_evoformer_codegen,
msa_len=msa_len,
pair_len=pair_len,
max_memory=max_memory,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
_test_evoformer_codegen(0, 32, 64, 24)

View File

@ -1,163 +0,0 @@
from functools import partial
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import EvoformerStack
HAS_REPO = True
except:
HAS_REPO = False
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if CODEGEN_AVAILABLE and is_compatible_with_meta():
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
# for memory test
# model = model.cuda()
# torch.cuda.reset_peak_memory_stats()
# now_mem = torch.cuda.memory_allocated() / 1024**2
# with torch.no_grad():
# node1 = node.clone()
# pair1 = pair.clone()
# node_mask1 = node_mask.clone()
# pair_mask1 = pair_mask.clone()
# gm(node1, pair1, node_mask1, pair_mask1, None)
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
# test forward
model = model.cuda()
with torch.no_grad():
non_fx_out = model(node, pair, node_mask, pair_mask, None)
fx_out = gm(node, pair, node_mask, pair_mask, None)
assert torch.allclose(non_fx_out[0], fx_out[0],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[0] - fx_out[0]))
assert torch.allclose(non_fx_out[1], fx_out[1],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[1] - fx_out[1]))
def _build_openfold():
model = EvoformerStack(
c_m=256,
c_z=128,
c_hidden_msa_att=32,
c_hidden_opm=32,
c_hidden_mul=128,
c_hidden_pair_att=32,
c_s=384,
no_heads_msa=8,
no_heads_pair=4,
no_blocks=2, # 48
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.25,
blocks_per_ckpt=None,
inf=1000000000.0,
eps=1e-08,
clear_cache_between_blocks=False,
is_multimer=False,
).eval().cuda()
return model
def _test_evoformer_stack_codegen(rank, msa_len, pair_len, max_memory):
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
model = _build_openfold()
node = torch.randn(1, msa_len, pair_len, 256).cuda()
node_mask = torch.randn(1, msa_len, pair_len).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
# trace the meta graph and setup codegen
meta_graph = symbolic_trace(
model,
meta_args={
"m": node.to(torch.device("meta")),
"z": pair.to(torch.device("meta")),
"msa_mask": node_mask.to(torch.device("meta")),
"pair_mask": pair_mask.to(torch.device("meta")),
},
concrete_args={
"chunk_size": None,
"_mask_trans": True,
},
)
interp = MetaInfoProp(meta_graph)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"),
MetaTensor(node_mask, fake_device="cuda:0"), MetaTensor(pair_mask, fake_device="cuda:0"), None)
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False, print_progress=False)
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph = ColoTracer().trace(
model,
meta_args={
"m": node.to(torch.device("meta")),
"z": pair.to(torch.device("meta")),
"msa_mask": node_mask.to(torch.device("meta")),
"pair_mask": pair_mask.to(torch.device("meta")),
},
concrete_args={
"chunk_size": None,
"_mask_trans": True,
},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
# assert we have inserted chunk
code = graph.python_code("self").src
# print(code)
assert "chunk_result = None; chunk_size = None;" in code
_test_fwd(model, gm, node, pair, node_mask, pair_mask)
gpc.destroy()
@pytest.mark.skipif(
not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_stack_codegen(msa_len, pair_len, max_memory):
run_func = partial(
_test_evoformer_stack_codegen,
msa_len=msa_len,
pair_len=pair_len,
max_memory=max_memory,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
_test_evoformer_stack_codegen(0, 32, 64, None)

View File

@ -1,164 +0,0 @@
from functools import partial
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import ExtraMSABlock
HAS_REPO = True
except:
HAS_REPO = False
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if CODEGEN_AVAILABLE and is_compatible_with_meta():
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
# for memory test
# model = model.cuda()
# torch.cuda.reset_peak_memory_stats()
# now_mem = torch.cuda.memory_allocated() / 1024**2
# with torch.no_grad():
# node1 = node.clone()
# pair1 = pair.clone()
# node_mask1 = node_mask.clone()
# pair_mask1 = pair_mask.clone()
# gm(node1, pair1, node_mask1, pair_mask1)
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
# test forward
model = model.cuda()
with torch.no_grad():
non_fx_out = model(node, pair, node_mask, pair_mask)
fx_out = gm(node, pair, node_mask, pair_mask)
assert torch.allclose(non_fx_out[0], fx_out[0],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[0] - fx_out[0]))
assert torch.allclose(non_fx_out[1], fx_out[1],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[1] - fx_out[1]))
def _build_openfold():
model = ExtraMSABlock(
c_m=256,
c_z=128,
c_hidden_msa_att=32,
c_hidden_opm=32,
c_hidden_mul=128,
c_hidden_pair_att=32,
no_heads_msa=8,
no_heads_pair=4,
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.15,
inf=1e4,
eps=1e-4,
ckpt=False,
is_multimer=False,
).eval().cuda()
return model
def _test_extramsa_codegen(rank, msa_len, pair_len, max_memory):
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
model = _build_openfold()
node = torch.randn(1, msa_len, pair_len, 256).cuda()
node_mask = torch.randn(1, msa_len, pair_len).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
# trace the meta graph and setup codegen
meta_graph = symbolic_trace(
model,
meta_args={
"m": node.to(torch.device("meta")),
"z": pair.to(torch.device("meta")),
"msa_mask": node_mask.to(torch.device("meta")),
"pair_mask": pair_mask.to(torch.device("meta")),
},
concrete_args={
"chunk_size": None,
"_chunk_logits": 1024,
},
)
interp = MetaInfoProp(meta_graph)
interp.propagate(
MetaTensor(node, fake_device="cuda:0"),
MetaTensor(pair, fake_device="cuda:0"),
MetaTensor(node_mask, fake_device="cuda:0"),
MetaTensor(pair_mask, fake_device="cuda:0"),
)
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph = ColoTracer().trace(
model,
meta_args={
"m": node.to(torch.device("meta")),
"z": pair.to(torch.device("meta")),
"msa_mask": node_mask.to(torch.device("meta")),
"pair_mask": pair_mask.to(torch.device("meta")),
},
concrete_args={
"chunk_size": None,
"_chunk_logits": 1024,
},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
# assert we have inserted chunk
code = graph.python_code("self").src
# print(code)
assert "chunk_result = None; chunk_size = None;" in code
_test_fwd(model, gm, node, pair, node_mask, pair_mask)
gpc.destroy()
@pytest.mark.skipif(
not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_extramsa_codegen(msa_len, pair_len, max_memory):
run_func = partial(
_test_extramsa_codegen,
msa_len=msa_len,
pair_len=pair_len,
max_memory=max_memory,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
_test_extramsa_codegen(0, 32, 64, None)

View File

@ -1,104 +0,0 @@
from functools import partial
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from simple_evoformer import base_evoformer
HAS_REPO = True
except:
HAS_REPO = False
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer, symbolic_trace
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if CODEGEN_AVAILABLE and is_compatible_with_meta():
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
with torch.no_grad():
non_fx_out = model(node, pair)
fx_out = gm(node, pair)
assert torch.allclose(non_fx_out[0], fx_out[0],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[0] - fx_out[0]))
assert torch.allclose(non_fx_out[1], fx_out[1],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(non_fx_out[1] - fx_out[1]))
def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
model = base_evoformer().cuda()
node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
# meta info prop
meta_graph = symbolic_trace(model,
meta_args={
"node": node.to(torch.device("meta")),
"pair": pair.to(torch.device("meta")),
}) # must use symbolic_trace
interp = MetaInfoProp(meta_graph)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
# trace the module and replace codegen
graph = ColoTracer().trace(
model,
meta_args={
"node": node.to(torch.device("meta")),
"pair": pair.to(torch.device("meta")),
},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
# assert we have inserted chunk
code = graph.python_code("self").src
# print(code)
assert "chunk_result = None; chunk_size = None;" in code
_test_fwd(model, gm, node, pair)
gpc.destroy()
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
reason='torch version is lower than 1.12.0')
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_simple_evoformer_codegen(msa_len, pair_len, max_memory):
run_func = partial(
_test_simple_evoformer_codegen,
msa_len=msa_len,
pair_len=pair_len,
max_memory=max_memory,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
_test_simple_evoformer_codegen(0, 32, 64, 25)

View File

@ -1,97 +0,0 @@
from functools import partial
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from simple_evoformer import base_evoformer
HAS_REPO = True
except:
HAS_REPO = False
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx import symbolic_trace
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if CODEGEN_AVAILABLE and is_compatible_with_meta():
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
found_regions = [i["region"] for i in chunk_infos]
if msa_len == 32 and pair_len == 64:
if max_memory is None:
target_regions = [(142, 154), (366, 373), (234, 283), (302, 351), (127, 134), (211, 228), (174, 191),
(161, 166), (198, 203), (7, 57)]
elif max_memory == 20:
target_regions = [(142, 154), (369, 373), (235, 269), (303, 351), (130, 131)]
elif max_memory == 25:
target_regions = [(144, 154), (369, 370)]
elif max_memory == 30:
target_regions = [(144, 154)]
else:
raise NotImplementedError()
else:
raise NotImplementedError()
assert found_regions == target_regions, "found regions %s doesn't equal target regions %s" % (
str(found_regions),
str(target_regions),
)
def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory):
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
model = base_evoformer().cuda()
node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
meta_graph = symbolic_trace(model,
meta_args={
"node": node.to(torch.device("meta")),
"pair": pair.to(torch.device("meta")),
}) # must use symbolic_trace
interp = MetaInfoProp(meta_graph)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
chunk_infos = codegen.chunk_infos
assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len)
gpc.destroy()
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
reason="torch version is lower than 1.12.0")
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_simple_evoformer_search(msa_len, pair_len, max_memory):
run_func = partial(
_test_simple_evoformer_search,
msa_len=msa_len,
pair_len=pair_len,
max_memory=max_memory,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
_test_simple_evoformer_search(0, 32, 64, 20)

View File

@ -0,0 +1,65 @@
from functools import partial
from typing import Dict, List, Tuple
import pytest
import torch
import torch.multiprocessing as mp
try:
from transformers import GPT2Config, GPT2Model
MODELS = [GPT2Model]
HAS_REPO = True
except ImportError:
MODELS = []
HAS_REPO = False
from test_transformer_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
BATCH_SIZE = 2
SEQ_LENGTH = 256
def get_data(shape: tuple) -> Tuple[Dict, Dict, List]:
input_ids = torch.zeros(shape, dtype=torch.int64)
token_type_ids = torch.zeros(shape, dtype=torch.int64)
attention_mask = torch.ones(shape, dtype=torch.int64)
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
concrete_args = {"past_key_values": None}
sequence = ["input_ids", "past_key_values", "attention_mask", "token_type_ids"]
return meta_args, concrete_args, sequence
@pytest.mark.skipif(
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
@pytest.mark.parametrize("max_memory", [None, 4.5, 5])
def test_gpt(model, shape, max_memory):
run_func = partial(
run_test,
data=get_data(shape),
max_memory=max_memory,
model=model,
        config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4),
print_code=False,
print_mem=False,
print_progress=False,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
run_test(
rank=0,
data=get_data((BATCH_SIZE, SEQ_LENGTH)),
max_memory=None,
model=GPT2Model,
        config=GPT2Config(n_embd=96, n_positions=SEQ_LENGTH, n_layer=2, n_head=4),
print_code=True,
print_mem=True,
print_progress=False,
)

View File

@ -0,0 +1,123 @@
from typing import Any, Dict, List
import torch
import torch.fx
import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def assert_codegen_run(
model: Any,
data: tuple,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False,
print_code: bool = False,
) -> List[Dict]:
meta_args, concrete_args, sequence = data
if concrete_args is None:
concrete_args = {}
# trace the meta graph and setup codegen
meta_graph = symbolic_trace(
model,
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
concrete_args={k: v for k, v in concrete_args.items()},
)
interp = MetaInfoProp(meta_graph)
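    # build the propagation inputs in forward-argument order; MetaInfoProp fills in the shape and
    # memory metadata that the autochunk search relies on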
meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
interp.propagate(*meta_tensors)
codegen = AutoChunkCodeGen(
meta_graph,
max_memory=max_memory,
print_mem=print_mem,
print_progress=print_progress,
)
chunks = codegen.chunk_infos
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph = ColoTracer().trace(
model.cuda(),
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
concrete_args={k: v for k, v in concrete_args.items()},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
# assert chunk in code
code = graph.python_code("self").src
if print_code:
print(code)
assert "chunk_result = None; chunk_size = None;" in code
# assert result
inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
model.cuda().eval()
gm.eval()
with torch.no_grad():
out_gm = gm(*inputs)
out_model = model(*inputs)
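        # GPT2Model returns a ModelOutput mapping (e.g. last_hidden_state); every tensor entry of
        # the original output is compared against the chunked graph module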
for k in out_model.keys():
if torch.is_tensor(out_gm[k]):
assert torch.equal(
out_model[k], out_gm[k]
), f'{model.__class__.__name__} has incorrect output {k}, expect {out_model[k]}, but got {out_gm[k]}'
return chunks
def run_test(
rank: int,
model: Any,
config: Any,
data: tuple,
max_memory: int,
print_code: bool,
print_mem: bool,
print_progress: bool,
get_chunk_target: Any = None,
) -> None:
model = model(config=config)
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
chunks = assert_codegen_run(
model,
data=data,
max_memory=max_memory,
print_code=print_code,
print_mem=print_mem,
print_progress=print_progress,
)
if get_chunk_target is not None:
chunk_found = [i["region"] for i in chunks]
chunk_target = get_chunk_target()[max_memory]
assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % (
str(chunk_found),
str(chunk_target),
)
gpc.destroy()