mirror of https://github.com/hpcaitech/ColossalAI
[autochunk] support diffusion for autochunk (#2621)
* add alphafold benchmark
* rename alphafold test
* rename tests
* rename diffuser
* rename
* rename
* update transformer
* update benchmark
* update benchmark
* update bench memory
* update transformer benchmark
* rename
* support diffuser
* support unet metainfo prop
* fix bug and simplify code
* update linear and support some op
* optimize max region search, support conv
* update unet test
* support some op
* support groupnorm and interpolate
* update flow search
* add fix dim in node flow
* fix utils
* rename
* support diffusion
* update diffuser
* update chunk search
* optimize imports
* import
* finish autochunk
parent 291b051171
commit 6ba8364881
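This commit extends AutoChunk to diffusion models: the code generator (AutoChunkCodeGen) learns to rewrite new-tensor ops, the region search (SearchChunk) gets an approximate fallback, the index tracing (TraceFlow / TraceIndice) gains fix-dim handling plus conv2d / groupnorm / interpolate / baddbmm support, and the UNet2DModel diffuser test is enabled. As context for the hunks below, a minimal sketch of what a chunked region conceptually computes; the shapes, the chunked dimension and the placeholder op are assumptions for illustration, not the exact source AutoChunkCodeGen emits:

# Illustrative only: hypothetical shapes, chunk dim and op, not the exact source
# AutoChunkCodeGen emits. It shows the general form of a chunked region: the
# region input is sliced along the chunk dim, a new-tensor op uses
# min(chunk_size, N - chunk_idx) for that dim, and results are written back.
import torch

def chunked_region(hidden: torch.Tensor, chunk_size: int = 64) -> torch.Tensor:
    n = hidden.shape[1]                       # chunked dimension (assumed dim 1)
    out = torch.empty_like(hidden)
    for chunk_idx in range(0, n, chunk_size):
        end = chunk_idx + min(chunk_size, n - chunk_idx)
        part = hidden[:, chunk_idx:end]       # sliced region input
        ones = torch.ones(part.shape[0], min(chunk_size, n - chunk_idx), part.shape[2])
        out[:, chunk_idx:end] = part * ones   # stand-in for the region's real ops
    return out

print(chunked_region(torch.randn(1, 4096, 320)).shape)    # torch.Size([1, 4096, 320])

The real generated code carries the same chunk_idx / chunk_size bookkeeping, which is why the updated test below only asserts that "chunk_size = None; " appears in the emitted source.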
@@ -9,18 +9,7 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
 
 if AUTOCHUNK_AVAILABLE:
-    from torch.fx.graph import (
-        CodeGen,
-        PythonCode,
-        _custom_builtins,
-        _CustomBuiltin,
-        _format_target,
-        _is_from_torch,
-        _Namespace,
-        _origin_type_map,
-        inplace_methods,
-        magic_methods,
-    )
+    from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods
 
 from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg

@@ -143,7 +132,7 @@ def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict)
     return context


-def _replace_ones_like(
+def _replace_new_tensor_like_shape(
     search_chunk: SearchChunk,
     chunk_infos: List[Dict],
     region_idx: int,

@@ -154,7 +143,7 @@ def _replace_ones_like(
     """
    add chunk slice for new tensor op such as ones like
    """
-    if "ones_like" in node.name:
+    if get_node_name(node) in ["ones_like", "zeros_like", "empty_like"]:
        meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
        chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
        if get_node_shape(meta_node)[chunk_dim] != 1:

@@ -166,6 +155,33 @@ def _replace_ones_like(
     return body


+def _replace_new_tensor_shape(
+    search_chunk: SearchChunk,
+    chunk_infos: List[Dict],
+    region_idx: int,
+    node_idx: int,
+    node: Node,
+    body: List[str],
+) -> List[str]:
+    """
+    add chunk slice for new tensor op such as ones
+    """
+    if get_node_name(node) in ["ones", "zeros", "empty"]:
+        meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
+        chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
+        if chunk_dim is None:
+            return
+        if get_node_shape(meta_node)[chunk_dim] == 1:
+            return
+        origin_shape = str(node.args)
+        new_shape = list(node.args)
+        new_shape[chunk_dim] = "min(chunk_size, %d - chunk_idx)" % get_node_shape(meta_node)[chunk_dim]
+        new_shape = str(new_shape)
+        new_shape = new_shape.replace("'", "")
+        body[-1] = _replace_name(body[-1], origin_shape[1:-1], new_shape[1:-1])
+    return body
+
+
 def _add_node_slice(
     chunk_nodes: List[Node],
     region_idx: int,

@@ -265,8 +281,10 @@ def emit_code_with_chunk(
             body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node)
             # replace output var with chunk var
             body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node)
-            # ones like
-            body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
+            # new tensor like
+            body = _replace_new_tensor_like_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
+            # new tensor
+            body = _replace_new_tensor_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
             # reassgin reshape size
             body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
             body[-1] = " " + body[-1]
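The new _replace_new_tensor_shape works purely on the emitted source line: it swaps the constant size of the chunked dimension for a min(chunk_size, N - chunk_idx) expression. A standalone sketch of that substitution with a hypothetical node whose args are (2, 4096, 320) and whose chunk dim is 1 (_replace_name is approximated here by str.replace):

# Sketch of the source-level substitution done by _replace_new_tensor_shape
# (hypothetical emitted line and shapes; the real pass uses _replace_name and
# values taken from chunk_infos).
origin_shape = str((2, 4096, 320))                                # "(2, 4096, 320)"
new_shape = str((2, "min(chunk_size, 4096 - chunk_idx)", 320)).replace("'", "")
line = "ones = torch.ones(2, 4096, 320, device=device)"
print(line.replace(origin_shape[1:-1], new_shape[1:-1]))
# ones = torch.ones(2, min(chunk_size, 4096 - chunk_idx), 320, device=device)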
@@ -8,14 +8,7 @@ from .reorder_graph import ReorderGraph
 from .select_chunk import SelectChunk
 from .trace_flow import TraceFlow
 from .trace_indice import TraceIndice
-from .utils import (
-    NodeMgr,
-    find_chunk_compute_input_and_output_nodes,
-    get_logger,
-    get_node_shape,
-    is_non_compute_node,
-    is_non_compute_node_except_placeholder,
-)
+from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder


 class SearchChunk(object):

@@ -75,8 +68,8 @@ class SearchChunk(object):
         max_chunk_region_list = []
         while True:
             max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
-            cur_node_idx = max_chunk_region[1]
-            if cur_node_idx == len(active_nodes) - 1:
+            cur_node_idx = max_chunk_region[1] + 1
+            if cur_node_idx >= len(active_nodes) - 1:
                 break
             max_chunk_region_list.append(max_chunk_region)

@@ -135,6 +128,7 @@ class SearchChunk(object):
         min_active_node_num = min(active_node_num[free_var_num:])
         threshold = max(free_var_num, min_active_node_num)

+        # normal search
         # from peak_node to free_var
         inside_flag = False
         chunk_region_start = free_var_num

@@ -144,7 +138,6 @@ class SearchChunk(object):
             if inside_flag and active_node_num[i] > threshold:
                 chunk_region_start = i + 1
                 break
-
         # from peak_node to len-2
         inside_flag = False
         chunk_region_end = len(active_node) - 1

@@ -155,6 +148,22 @@ class SearchChunk(object):
                 chunk_region_end = i
                 break

+        # if normal search fails, use approximate search
+        if (chunk_region_end - chunk_region_start) > 250:
+            window_size = 100
+            # search min for start
+            min_num = 1e3
+            for i in range(max(peak_node_idx - window_size, 0), peak_node_idx + 1):
+                if active_node_num[i] < min_num:
+                    min_num = active_node_num[i]
+                    chunk_region_start = i
+            # search min for end
+            min_num = 1e3
+            for i in range(min(peak_node_idx + window_size, len(active_node_num) - 1), peak_node_idx - 1, -1):
+                if active_node_num[i] < min_num:
+                    min_num = active_node_num[i]
+                    chunk_region_end = i
+
         # avoid chunk regions overlap
         if chunk_regions is not None:
             for i in chunk_regions:
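The added fallback above bounds very large regions: when the threshold scan returns a span of more than 250 nodes, the region is shrunk to the minima of the active-node count inside a 100-node window on each side of the peak. The same windowed-minimum search on a synthetic activity profile (assumed data, illustration only):

# Windowed-minimum fallback mirrored on a synthetic active-node profile.
active_node_num = [5, 4, 3, 8, 20, 9, 2, 6]
peak_node_idx = active_node_num.index(max(active_node_num))    # 4
window_size = 2

chunk_region_start, min_num = 0, 1e3
for i in range(max(peak_node_idx - window_size, 0), peak_node_idx + 1):
    if active_node_num[i] < min_num:
        min_num, chunk_region_start = active_node_num[i], i

chunk_region_end, min_num = len(active_node_num) - 1, 1e3
for i in range(min(peak_node_idx + window_size, len(active_node_num) - 1), peak_node_idx - 1, -1):
    if active_node_num[i] < min_num:
        min_num, chunk_region_end = active_node_num[i], i

print(chunk_region_start, chunk_region_end)    # 2 6: local minima around the peak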
@@ -271,12 +280,6 @@ class SearchChunk(object):
         best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
         return best_chunk_region

-    def _stop_search(self, init_mem_peak, mem_peak):
-        sorted_init_mem_peak = sorted(init_mem_peak)
-        if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
-            return True
-        return False
-
     def search_region(self) -> Dict:
         """
         Search all chunk regions:

@@ -291,11 +294,7 @@ class SearchChunk(object):
         get_logger().info("AutoChunk start searching chunk regions")

         chunk_infos = []
-        (
-            init_mem_peak,
-            _,
-            active_node,
-        ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
+        init_mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
         mem_peak = init_mem_peak

         while True:

@@ -304,18 +303,13 @@ class SearchChunk(object):
                 break
             chunk_infos.append(chunk_info)

-            (
-                mem_peak,
-                _,
-                active_node,
-            ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), chunk_infos)
+            mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
+                self.node_mgr.get_node_list(), chunk_infos)

             if self.print_progress:
                 get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
                                   (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))

-            if self._stop_search(init_mem_peak, mem_peak):
-                break
         if self.print_mem:
             self.print_mem = False
             self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
@@ -100,6 +100,16 @@ class TraceFlow(object):
             if not (start_idx <= arg_idx < end_idx):
                 return True

+        # get fix dim
+        arg_fix_dim = []
+        if cur_node_dim is not None:
+            for i in cur_node_fix_dim:
+                fix_dim_source = cur_node_source[i]
+                if arg_idx in fix_dim_source:
+                    arg_fix_dim.append(fix_dim_source[arg_idx][0])
+        if arg_node in all_node_info:
+            arg_fix_dim = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
+
         # find arg dim
         if cur_node_dim is not None:
             # dim is computed

@@ -109,6 +119,9 @@ class TraceFlow(object):
                 arg_dim = None
             else:
                 arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
+                # chunk dim cannot be in fix dims
+                if arg_dim in arg_fix_dim:
+                    return False
                 # chunk dim should be None if shape size is 1
                 if get_node_shape(arg_node)[arg_dim] == 1:
                     arg_dim = None

@@ -120,19 +133,16 @@ class TraceFlow(object):
         else:
             arg_dim = None

-        # get fix dim
-        arg_fix_dim = []
-        if cur_node_dim is not None:
-            for i in cur_node_fix_dim:
-                fix_dim_source = cur_node_source[i]
-                if arg_idx in fix_dim_source:
-                    arg_fix_dim.append(fix_dim_source[arg_idx][0])
+        # add arg rest dim as fix dim
+        arg_fix_dim = list(range(len(get_node_shape(arg_node))))
+        if arg_dim is not None:
+            arg_fix_dim.remove(arg_dim)

         # if already in node_info, arg dim must be same
         if arg_node in all_node_info:
             if all_node_info[arg_node]["chunk_dim"] != arg_dim:
                 return False
-            all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
+            all_node_info[arg_node]["fix_dim"] = arg_fix_dim
         # else add it to list
         else:
             all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}

@@ -164,6 +174,8 @@ class TraceFlow(object):
                     continue
                 if is_non_compute_node(arg):
                     continue
+                if get_node_shape(arg) is None:
+                    continue
                 arg_list.append(arg)
                 flow_flag = self._assgin_single_node_flow(
                     arg,

@@ -180,29 +192,6 @@ class TraceFlow(object):
                 if flow_flag == False:
                     return None

-            if len(arg_list) >= 2:
-                # need to mark fix dim
-                if any(i == get_node_name(cur_node) for i in ["add", "mul", "truediv", "sub", "where"]):
-                    for arg in arg_list:
-                        if get_node_shape(arg) is None:
-                            continue
-                        if not (start_idx <= self.node_mgr.find_node_idx(arg) < end_idx):
-                            continue
-                        arg_chunk_dim = all_node_info[arg]["chunk_dim"]
-                        arg_fix_dim = all_node_info[arg]["fix_dim"]
-                        arg_shape = get_node_shape(arg)
-                        # add all dim as fix dim except chunk dim
-                        for i, shape in enumerate(arg_shape):
-                            if shape != 1 and i != cur_node_chunk_dim:
-                                if i == arg_chunk_dim:
-                                    return None
-                                if i not in arg_fix_dim:
-                                    arg_fix_dim.append(i)
-                elif any(i == get_node_name(cur_node)
-                         for i in ["einsum", "matmul", "view", "to", "getitem", "tensor", "type"]):
-                    pass
-                else:
-                    raise NotImplementedError()
             cur_node_list = next_node_list
         return all_node_info
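In the flow trace, every visited node carries a chunk_dim (the dimension being split) plus fix_dim (dimensions that must stay whole); after the change above, fix_dim starts as every dimension except the chosen chunk dim. For a hypothetical 3-D activation of shape (B, N, C) chunked along N, the per-node record looks like this (the real dictionary key is the fx Node; a string stands in here):

# Hypothetical all_node_info entry for a (B, N, C) tensor chunked along N (dim 1):
# every other dimension is pinned as a fix dim.
arg_dim = 1
arg_fix_dim = list(range(3))    # [0, 1, 2]
if arg_dim is not None:
    arg_fix_dim.remove(arg_dim)
all_node_info = {"arg_node": {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}}
print(all_node_info)            # {'arg_node': {'chunk_dim': 1, 'fix_dim': [0, 2]}}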
@@ -150,7 +150,7 @@ class TraceIndice(object):
         for i in range(len(node_from_indice)):
             self._inherit_indice(node_from, i, node_to, i, init=True)

-    def _inherit_more_indice_from_node(self, node_from: Node, node_to: Node, exclude: List = None) -> None:
+    def _inherit_more_indice_from_node_with_exclude(self, node_from: Node, node_to: Node, exclude: List = None) -> None:
         """
         inheirt indice from node without init
         """

@@ -308,14 +308,14 @@ class TraceIndice(object):
             node (node)
             node_idx (int)
         """
-        if len(node.args) == 2:
-            _, weight = node.args
-        else:
-            _, weight, _ = node.args
-
         self._assign_indice_as_input(node, node_idx)
-        self._inherit_indice(weight, 1, node, -1)
+        if len(node.args) >= 2:
+            weight = node.args[1]
+            self._inherit_indice(weight, 1, node, -1)
+        else:
+            self._del_dim(node_idx, -1)
+            self._add_dim(node_idx, -1)

         self._mark_computation(node, node_idx, [-1])

     def _assign_addmm_indice(self, node: Node, node_idx: int) -> None:
@@ -327,13 +327,35 @@ class TraceIndice(object):
             node_idx (int)
         """
         bias, input_node, weight = node.args
+        assert len(get_node_shape(bias)) == 1 and len(get_node_shape(weight)) == 2
         self._assign_indice_as_input(node, node_idx, input_node)
         self._inherit_indice(weight, 1, node, -1)
-        self._inherit_indice(bias, -1, node, -1)
+        self._inherit_more_indice_from_node_with_exclude(bias, node)

         self._mark_computation(node, node_idx, [-1])

+    def _assign_baddbmm_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for baddbmm(batch add and batch matmul) op.
+        add, matmul_left, matmul_right = args
+        out = add + (matmul_left x matmul_right)
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        add, matmul_left, matmul_right = node.args
+
+        assert get_node_shape(add) == get_node_shape(node)
+        assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
+        self._assign_indice_as_input(node, node_idx, matmul_left)
+        # matmul
+        self._inherit_indice(matmul_right, -1, node, -1)
+        self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-2, -1])
+        self._mark_computation(node, node_idx, [-1])
+        # add
+        self._inherit_more_indice_from_node_with_exclude(add, node)
+
     def _assign_matmul_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for matmul op.
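The new _assign_baddbmm_indice encodes the semantics of torch.baddbmm: a batched matmul whose result is added to a same-shaped (or broadcastable) tensor, so the last dimension of the output comes from the right matmul operand and the remaining dimensions follow the left operand and the added tensor. A quick numerical check of that identity (shapes assumed for illustration):

# torch.baddbmm(add, left, right) == add + left @ right for 3-D batched inputs:
# the last output dim comes from `right`, the other dims follow `left` / `add`.
import torch

add = torch.randn(4, 8, 16)
left = torch.randn(4, 8, 32)
right = torch.randn(4, 32, 16)
out = torch.baddbmm(add, left, right)
assert out.shape == (4, 8, 16)
assert torch.allclose(out, add + left @ right, atol=1e-5)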
@@ -349,11 +371,53 @@ class TraceIndice(object):

         assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
         self._assign_indice_as_input(node, node_idx, matmul_left)
-        self._inherit_indice(matmul_right, -1, node, -1)
-
-        self._inherit_more_indice_from_node(matmul_right, node, [-1, -2])
+        self._inherit_indice(matmul_right, -1, node, -1)
+        self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-1, -2])
         self._mark_computation(node, node_idx, [-1])

+    def _assign_conv2d_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for conv2d op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        # get conv module
+        node_targets = node.target.split(".")
+        conv_module = node.graph.owning_module
+        for i in node_targets:
+            conv_module = getattr(conv_module, i)
+        assert conv_module.dilation == (1, 1), "dilation for conv2d not implemented"
+
+        # get conv input
+        assert len(node.args) == 1
+        input_node = node.args[0]
+        assert len(get_node_shape(input_node)) == 4
+
+        # assgin index
+        self._assign_indice_as_input(node, node_idx, input_node)
+        self._del_dim(node_idx, 1)
+        self._add_dim(node_idx, 1)
+        self._mark_computation(node, node_idx, [1, 2, 3])
+
+    def _assign_interpolate_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for interpolate op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        # get conv input
+        assert node.kwargs['size'] is None
+        assert len(get_node_shape(node)) == 4
+
+        # assgin index
+        self._assign_indice_as_input(node, node_idx)
+        self._mark_computation(node, node_idx, [-1, -2])
+
     def _assign_layernorm_indice(self, node, idx):
         """
         Assign indice for layernorm op.
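_assign_conv2d_indice needs the concrete nn.Conv2d module to check its dilation, but a call_module node only stores the dotted target path, so the module is recovered by walking getattr from the owning GraphModule. The same lookup on a toy model (model and target path are made up for illustration):

# Resolving a call_module target such as "block.conv" to the module object,
# the same getattr walk used in _assign_conv2d_indice (toy model, made-up path).
import torch.nn as nn

model = nn.Sequential()
model.add_module("block", nn.Sequential())
model.block.add_module("conv", nn.Conv2d(3, 8, kernel_size=3))

target = "block.conv"
conv_module = model
for name in target.split("."):
    conv_module = getattr(conv_module, name)

assert isinstance(conv_module, nn.Conv2d)
print(conv_module.dilation)    # (1, 1) -- the only case the new code supports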
@@ -367,6 +431,18 @@ class TraceIndice(object):
         self._assign_indice_as_input(node, idx)
         self._mark_computation(node, idx, [-1])

+    def _assign_groupnorm_indice(self, node, idx):
+        """
+        Assign indice for groupnorm op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        assert len(get_node_shape(node)) == 4
+        self._assign_indice_as_input(node, idx)
+        self._mark_computation(node, idx, [-1, -2, -3])
+
     def _assign_elementwise_indice(self, node, idx):
         """
         Assign indice for element-wise op (eg. relu sigmoid add mul).

@@ -382,13 +458,13 @@ class TraceIndice(object):
         for node_in in node.args:
             if type(node_in) == type(node):
                 nodes_in.append(node_in)
-                self._inherit_more_indice_from_node(node_in, node)
+                self._inherit_more_indice_from_node_with_exclude(node_in, node)

     def _assgin_no_change_indice(self, node, idx):
         self._assign_indice_as_input(node, idx)
         for node_in in node.args:
             if type(node_in) == type(node):
-                self._inherit_more_indice_from_node(node_in, node)
+                self._inherit_more_indice_from_node_with_exclude(node_in, node)

     def _assign_einsum_indice(self, node, idx):
         """

@@ -469,17 +545,6 @@ class TraceIndice(object):
         dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
         self._add_dim(node_idx, dim_idx)

-    def _assign_ones_like_indice(self, node: Node, node_idx: int) -> None:
-        """
-        Assign indice for oneslike op.
-        1. assign new indice for all dim
-
-        Args:
-            node (node)
-            node_idx (int)
-        """
-        self._assign_all_indice(node, node_idx)
-
     def _assign_cat_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for cat op.

@@ -491,7 +556,7 @@ class TraceIndice(object):
         nodes_in = flat_list(node.args[0])
         self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
         for n in nodes_in[1:]:
-            self._inherit_more_indice_from_node(n, node)
+            self._inherit_more_indice_from_node_with_exclude(n, node)
         cat_dim = node.kwargs["dim"]
         self._del_dim(node_idx, cat_dim)
         self._add_dim(node_idx, cat_dim)

@@ -508,33 +573,10 @@ class TraceIndice(object):
         self._add_dim(node_idx, 0)
         self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
         for n in nodes_in[1:]:
-            self._inherit_more_indice_from_node(n, node)
+            self._inherit_more_indice_from_node_with_exclude(n, node)
         cat_dim = node.kwargs["dim"]
         self._del_dim(node_idx, cat_dim)

-    def _assign_arange_indice(self, node: Node, node_idx: int) -> None:
-        """
-        Assign indice for arange op.
-
-        Args:
-            node (node)
-            node_idx (int)
-        """
-        self._assign_all_indice(node, node_idx)
-
-    def _assign_tensor_indice(self, node: Node, node_idx: int) -> None:
-        """
-        Assign indice for tensor op.
-
-        Args:
-            node (node)
-            node_idx (int)
-        """
-        if len(get_node_shape(node)) == 0:
-            return
-        else:
-            raise NotImplementedError()
-
     def _assign_embedding_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for embedding op.

@@ -763,10 +805,10 @@ class TraceIndice(object):
                     self._assign_unsqueeze_indice(node, idx)
                 elif "split" == node_name:
                     self._assign_split_indice(node, idx)
-                elif any(i == node_name for i in ["to", "contiguous", "clone", "type"]):
+                elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
                     self._assgin_no_change_indice(node, idx)
                 elif "new_ones" == node_name:
-                    self._assign_ones_like_indice(node, idx)
+                    self._assign_all_indice(node, idx)
                 elif any(i == node_name for i in ["size"]):
                     continue
                 else:

@@ -776,25 +818,15 @@ class TraceIndice(object):
                     self._assign_linear_indice(node, idx)
                 elif "cat" == node_name:
                     self._assign_cat_indice(node, idx)
-                elif "matmul" == node_name:
+                elif any(n == node_name for n in ["matmul", "bmm"]):
                     self._assign_matmul_indice(node, idx)
                 elif "softmax" == node_name:
                     self._assign_softmax_indice(node, idx)
                 elif any(n == node_name for n in [
-                        "mul",
-                        "add",
-                        "sigmoid",
-                        "relu",
-                        "sub",
-                        "truediv",
-                        "pow",
-                        "dropout",
-                        "where",
-                        "tanh",
+                        "mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp",
+                        "sin", "cos"
                 ]):
                     self._assign_elementwise_indice(node, idx)
-                elif "ones_like" == node_name:
-                    self._assign_ones_like_indice(node, idx)
                 elif "einsum" == node_name:
                     self._assign_einsum_indice(node, idx)
                 elif "sum" == node_name:

@@ -805,10 +837,12 @@ class TraceIndice(object):
                     self._assign_getitem_indice(node, idx)
                 elif "addmm" == node_name:
                     self._assign_addmm_indice(node, idx)
-                elif "arange" == node_name:
-                    self._assign_arange_indice(node, idx)
-                elif "tensor" == node_name:
-                    self._assign_arange_indice(node, idx)
+                elif "baddbmm" == node_name:
+                    self._assign_baddbmm_indice(node, idx)
+                elif "interpolate" == node_name:
+                    self._assign_interpolate_indice(node, idx)
+                elif any(i == node_name for i in ["arange", "ones", "ones_like", "tensor", "empty"]):
+                    self._assign_all_indice(node, idx)
                 elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]):
                     continue
                 else:

@@ -817,9 +851,15 @@ class TraceIndice(object):
                 node_name = get_module_node_name(node)
                 if "layernorm" == node_name:
                     self._assign_layernorm_indice(node, idx)
+                elif "groupnorm" == node_name:
+                    self._assign_groupnorm_indice(node, idx)
                 elif "embedding" == node_name:
                     self._assign_embedding_indice(node, idx)
-                elif any(n == node_name for n in ["sigmoid", "dropout", "relu"]):
+                elif "linear" == node_name:
+                    self._assign_linear_indice(node, idx)
+                elif "conv2d" == node_name:
+                    self._assign_conv2d_indice(node, idx)
+                elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu"]):
                     self._assign_elementwise_indice(node, idx)
                 else:
                     raise NotImplementedError(node_name, "module not implemented yet!")
@@ -22,6 +22,7 @@ def assert_codegen_run(
     concrete_args: List = None,
     max_memory: int = None,
     print_mem: bool = False,
+    print_est_mem: bool = False,
     print_progress: bool = False,
     print_code: bool = False,
 ) -> List[Dict]:

@@ -35,13 +36,14 @@ def assert_codegen_run(
         meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
         concrete_args={k: v for k, v in concrete_args},
     )
+    model = model.cuda().eval()
     interp = MetaInfoProp(meta_graph)
     meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
     interp.propagate(*meta_tensors)
     codegen = AutoChunkCodeGen(
         meta_graph,
         max_memory=max_memory,
-        print_mem=print_mem,
+        print_mem=print_est_mem,
         print_progress=print_progress,
     )
     chunks = codegen.chunk_infos

@@ -61,17 +63,29 @@ def assert_codegen_run(
     code = graph.python_code("self").src
     if print_code:
         print(code)
-    assert "chunk_result = None; chunk_size = None;" in code
+    assert "chunk_size = None; " in code

     # assert result
     inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
     model.cuda().eval()
     gm.eval()
     with torch.no_grad():
-        out_gm = gm(*inputs)
-        out_model = model(*inputs)
+        if print_mem:
+            torch.cuda.reset_peak_memory_stats()
+            now_mem_gm = torch.cuda.memory_allocated() / 1024**2
+        out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        if print_mem:
+            max_mem_gm = torch.cuda.max_memory_allocated() / 1024**2
+            torch.cuda.reset_peak_memory_stats()
+            now_mem_ori = torch.cuda.memory_allocated() / 1024**2
+        out_model = model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        if print_mem:
+            max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2
+            print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm))
+
     assert torch.allclose(out_gm["sample"], out_model["sample"],
-                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                          atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                               torch.abs(out_gm["sample"] - out_model["sample"]))

     return chunks
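The reworked test above now measures real rather than estimated activation memory: it resets the CUDA peak-statistics counter, records the baseline allocation, runs one forward pass, and reports max_memory_allocated minus that baseline for both the chunked graph and the original model. The same pattern in isolation (the workload is a placeholder; any CUDA forward pass would do):

# Peak-activation measurement pattern used in the test: reset the peak counter,
# record the baseline allocation, run the forward pass, read the new peak.
# Requires a CUDA device.
import torch

def measure_peak_mb(fn) -> float:
    torch.cuda.reset_peak_memory_stats()
    baseline = torch.cuda.memory_allocated() / 1024**2
    with torch.no_grad():
        fn()
    return torch.cuda.max_memory_allocated() / 1024**2 - baseline

if torch.cuda.is_available():
    x = torch.randn(64, 3, 224, 224, device="cuda")
    print("peak: %.2fMB" % measure_peak_mb(lambda: (x * 2).relu()))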
@@ -82,9 +96,10 @@ def run_test(
     model: Any,
     data: tuple,
     max_memory: int,
-    print_code: bool,
-    print_mem: bool,
-    print_progress: bool,
+    print_code: bool = False,
+    print_mem: bool = False,
+    print_est_mem: bool = False,
+    print_progress: bool = False,
     get_chunk_target: Any = None,
 ) -> None:
     # launch colossalai

@@ -106,6 +121,7 @@ def run_test(
         max_memory=max_memory,
         print_code=print_code,
         print_mem=print_mem,
+        print_est_mem=print_est_mem,
         print_progress=print_progress,
     )
@@ -17,10 +17,9 @@ from test_autochunk_diffuser_utils import run_test

 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE

-BATCH_SIZE = 2
-SEQ_LENGTH = 5
-HEIGHT = 224
-WIDTH = 224
+BATCH_SIZE = 1
+HEIGHT = 448
+WIDTH = 448
 IN_CHANNELS = 3
 LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7)

@@ -34,26 +33,19 @@ def get_data(shape: tuple) -> Tuple[List, List]:
     return meta_args, concrete_args


-@pytest.mark.skipif(
-    True,
-    reason="not implemented",
-)
 @pytest.mark.skipif(
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
 )
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("shape", [LATENTS_SHAPE])
-@pytest.mark.parametrize("max_memory", [64])
+@pytest.mark.parametrize("max_memory", [None])
 def test_evoformer_block(model, shape, max_memory):
     run_func = partial(
         run_test,
         max_memory=max_memory,
         model=model,
         data=get_data(shape),
-        print_code=False,
-        print_mem=False,
-        print_progress=False,
     )
     mp.spawn(run_func, nprocs=1)

@@ -62,9 +54,10 @@ if __name__ == "__main__":
     run_test(
         rank=0,
         data=get_data(LATENTS_SHAPE),
-        max_memory=64,
+        max_memory=None,
         model=UNet2DModel,
         print_code=False,
         print_mem=False,
+        print_est_mem=False,
         print_progress=False,
     )