From 1a6d2a740be33d769111ed03104bb5fa73b2ad50 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 14:14:45 +0800 Subject: [PATCH] take apart chunk code gen --- colossalai/autochunk/autochunk_codegen.py | 497 ++++ colossalai/autochunk/chunk_codegen.py | 2364 ------------------- colossalai/autochunk/chunk_region_search.py | 211 ++ colossalai/autochunk/chunk_selector.py | 221 ++ colossalai/autochunk/index_tracer.py | 1056 +++++++++ colossalai/autochunk/memory_estiamtor.py | 318 +++ colossalai/autochunk/utils.py | 95 + tests/test_autochunk/benchmark_autochunk.py | 12 +- tests/test_autochunk/test_autochunk.py | 4 +- 9 files changed, 2408 insertions(+), 2370 deletions(-) create mode 100644 colossalai/autochunk/autochunk_codegen.py delete mode 100644 colossalai/autochunk/chunk_codegen.py create mode 100644 colossalai/autochunk/chunk_region_search.py create mode 100644 colossalai/autochunk/chunk_selector.py create mode 100644 colossalai/autochunk/index_tracer.py create mode 100644 colossalai/autochunk/memory_estiamtor.py create mode 100644 colossalai/autochunk/utils.py diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py new file mode 100644 index 000000000..58a8c3751 --- /dev/null +++ b/colossalai/autochunk/autochunk_codegen.py @@ -0,0 +1,497 @@ +from typing import Any, Callable, Dict, Iterable, List, Tuple + +import torch +from torch.fx.graph import ( + CodeGen, + PythonCode, + _custom_builtins, + _CustomBuiltin, + _format_target, + _is_from_torch, + _Namespace, + _origin_type_map, + inplace_methods, + magic_methods, +) +from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg + +import colossalai + +from .chunk_region_search import ChunkRegionSearch +from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape + +CODEGEN_AVAILABLE = True +__all__ = ["AutoChunkCodeGen"] + + +def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): + new_shape = "[" + for idx, i in 
enumerate(shape): + if idx == chunk_dim: + new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name) + else: + new_shape += ":" + new_shape += ", " + new_shape = new_shape[:-2] + "]" + return new_shape + + +def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2): + input_node = chunk_input[0] + out_shape = get_node_shape(chunk_output) + out_str = str(list(out_shape)) + context = ( + "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" + % (out_str, input_node.name, input_node.name, chunk_size) + ) + context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim]) + return context + + +def _gen_loop_end( + chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list +): + chunk_outputs_name = chunk_outputs.name + chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list) + chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape + chunk_slice = _gen_chunk_slice_dim( + chunk_outputs_dim, "chunk_idx", chunk_output_shape + ) + context = " chunk_result%s = %s; %s = None\n" % ( + chunk_slice, + chunk_outputs_name, + chunk_outputs_name, + ) + context += ( + chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None" + ) + + # determine if its the last use for chunk input + for chunk_input in chunk_inputs + chunk_non_compute_inputs: + if all( + [ + find_idx_by_name(user.name, node_list) <= chunk_outputs_idx + for user in chunk_input.users.keys() + ] + ): + context += "; %s = None" % chunk_input.name + + context += "\n" + return context + + +def _replace_name(context, name_from, name_to): + patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")] + for p in patterns: + source = p[0] + name_from + p[1] + target = p[0] + name_to + p[1] + if source in context: + context = context.replace(source, target) + return context + + +def _replace_reshape_size(context, node_name, reshape_size_dict): + if node_name not 
in reshape_size_dict: + return context + for size_name, size_value in reshape_size_dict[node_name].items(): + context = context.replace(size_name, size_value) + return context + + +def emit_code_with_chunk( + body, + nodes, + emit_node_func, + delete_unused_value_func, + chunk_region_search, + chunk_infos, +): + """Emit code with nested activation checkpoint + When we detect some of the node.activation_checkpoint is a List, we will use + this function to emit the activation checkpoint codes. + + Args: + body: forward code + ckpt_func: checkpoint functions code + nodes: graph.nodes + emit_node_func: function to emit node + delete_unused_value_func: function to remove the unused value + """ + node_list = list(nodes) + + chunk_regions = [i["region"] for i in chunk_infos] + chunk_starts = [i[0] for i in chunk_regions] + chunk_ends = [i[1] for i in chunk_regions] + + chunk_inputs = [i["inputs"] for i in chunk_infos] + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] + chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] + chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ + j.name for i in chunk_inputs_non_chunk for j in i + ] + + chunk_outputs = [i["outputs"][0] for i in chunk_infos] + chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos] + + node_list = chunk_region_search.index_tracer.reorder_node_list(node_list) + node_idx = 0 + region_idx = 0 + within_chunk_region = False + + while node_idx < len(node_list): + node = node_list[node_idx] + + if node_idx in chunk_starts: + within_chunk_region = True + region_idx = chunk_starts.index(node_idx) + body.append( + _gen_loop_start( + chunk_inputs[region_idx], + chunk_outputs[region_idx], + chunk_outputs_dim[region_idx], + chunk_infos[region_idx]["chunk_size"], + ) + ) + + if within_chunk_region: + emit_node_func(node, body) + # replace input var with chunk var + for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): + for idx, dim in 
chunk_inputs_dim[region_idx][input_node_idx].items(): + if idx == node_idx: + chunk_slice = _gen_chunk_slice_dim( + dim[0], "chunk_idx", get_node_shape(input_node) + ) + body[-1] = _replace_name( + body[-1], input_node.name, input_node.name + chunk_slice + ) + # ones like + if "ones_like" in node.name: + meta_node = chunk_region_search.index_tracer.node_list[node_idx] + chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][ + "chunk_dim" + ] + if get_node_shape(meta_node)[chunk_dim] != 1: + source_node = meta_node.args[0].args[0] + if ( + source_node not in chunk_infos[region_idx]["node_chunk_dim"] + or chunk_infos[region_idx]["node_chunk_dim"][source_node][ + "chunk_dim" + ] + is None + ): + chunk_slice = _gen_chunk_slice_dim( + chunk_dim, "chunk_idx", get_node_shape(node) + ) + body[-1] = _replace_name( + body[-1], node.args[0].name, node.args[0].name + chunk_slice + ) + body[-1] = _replace_reshape_size( + body[-1], node.name, chunk_infos[region_idx]["reshape_size"] + ) + body[-1] = " " + body[-1] + delete_unused_value_func(node, body, chunk_inputs_names) + else: + emit_node_func(node, body) + if node_idx not in chunk_inputs: + delete_unused_value_func(node, body, chunk_inputs_names) + + if node_idx in chunk_ends: + body.append( + _gen_loop_end( + chunk_inputs[region_idx], + chunk_inputs_non_chunk[region_idx], + chunk_outputs[region_idx], + chunk_outputs_dim[region_idx], + node_list, + ) + ) + within_chunk_region = False + + node_idx += 1 + + +if CODEGEN_AVAILABLE: + + class AutoChunkCodeGen(CodeGen): + def __init__(self, meta_graph, max_memory=None): + super().__init__() + self.meta_graph = meta_graph + self.max_memory = max_memory + self.meta_node = list(meta_graph.graph.nodes) + # find the chunk regions + self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory) + self.chunk_infos = self.chunk_region_search.search_region() + + def _gen_python_code( + self, nodes, root_module: str, namespace: _Namespace + ) -> PythonCode: + free_vars: 
List[str] = [] + body: List[str] = [] + globals_: Dict[str, Any] = {} + wrapped_fns: Dict[str, None] = {} + + # Wrap string in list to pass by reference + maybe_return_annotation: List[str] = [""] + + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. + + We call this for names that reference objects external to the + Graph, like functions or types. + + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if ( + _is_from_torch(obj) and obj != torch.device + ): # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. + return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] is obj + return global_name + globals_[global_name] = obj + return global_name + + # set _custom_builtins here so that we needn't import colossalai in forward + _custom_builtins["colossalai"] = _CustomBuiltin( + "import colossalai", colossalai + ) + + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) + + def type_repr(o: Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return "()" + + typename = _type_repr(o) + + if hasattr(o, "__origin__"): + # This is a generic type, e.g. typing.List[torch.Tensor] + origin_type = _origin_type_map.get(o.__origin__, o.__origin__) + origin_typename = add_global(_type_repr(origin_type), origin_type) + + if hasattr(o, "__args__"): + # Assign global names for each of the inner type variables. 
+ args = [type_repr(arg) for arg in o.__args__] + + if len(args) == 0: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python < 3.9 + return origin_typename + + return f'{origin_typename}[{",".join(args)}]' + else: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python 3.9+ + return origin_typename + + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + def _format_args( + args: Tuple[Argument, ...], kwargs: Dict[str, Argument] + ) -> str: + def _get_repr(arg): + # Handle NamedTuples (if it has `_fields`) via add_global. + if isinstance(arg, tuple) and hasattr(arg, "_fields"): + qualified_name = _get_qualified_name(type(arg)) + global_name = add_global(qualified_name, type(arg)) + return f"{global_name}{repr(tuple(arg))}" + return repr(arg) + + args_s = ", ".join(_get_repr(a) for a in args) + kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) + if args_s and kwargs_s: + return f"{args_s}, {kwargs_s}" + return args_s or kwargs_s + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + delete_free_var_from_last_use(user_to_last_uses) + + # NOTE: we add a variable to distinguish body and ckpt_func + def delete_unused_values(user: Node, body, to_keep=[]): + """ + Delete values after their last use. 
This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == "placeholder": + return + if user.op == "output": + body.append("\n") + return + nodes_to_delete = user_to_last_uses.get(user, []) + nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep] + if len(nodes_to_delete): + to_delete_str = " = ".join( + [repr(n) for n in nodes_to_delete] + ["None"] + ) + body.append(f"; {to_delete_str}\n") + else: + body.append("\n") + + # NOTE: we add a variable to distinguish body and ckpt_func + def emit_node(node: Node, body): + maybe_type_annotation = ( + "" if node.type is None else f" : {type_repr(node.type)}" + ) + if node.op == "placeholder": + assert isinstance(node.target, str) + maybe_default_arg = ( + "" if not node.args else f" = {repr(node.args[0])}" + ) + free_vars.append( + f"{node.target}{maybe_type_annotation}{maybe_default_arg}" + ) + raw_name = node.target.replace("*", "") + if raw_name != repr(node): + body.append(f"{repr(node)} = {raw_name}\n") + return + elif node.op == "call_method": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) + return + elif node.op == "call_function": + assert callable(node.target) + # pretty print operators + if ( + node.target.__module__ == "_operator" + and node.target.__name__ in magic_methods + ): + assert isinstance(node.args, tuple) + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) + return + + # pretty print inplace operators; required for jit.script to work properly + # not currently supported in normal FX graphs, but generated by torchdynamo + if ( + node.target.__module__ == "_operator" + and node.target.__name__ in inplace_methods + ): + body.append( + 
f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" + ) + return + + qualified_name = _get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + ) + return + body.append( + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if node.meta.get("is_wrapped", False): + wrapped_fns.setdefault(global_name) + return + elif node.op == "call_module": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) + return + elif node.op == "get_attr": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" + ) + return + elif node.op == "output": + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + body.append(self.generate_output(node.args[0])) + return + raise NotImplementedError(f"node: {node.op} {node.target}") + + # Modified for activation checkpointing + ckpt_func = [] + + # if any node has a list of labels for activation_checkpoint, we + # will use nested type of activation checkpoint codegen + emit_code_with_chunk( + body, + nodes, + emit_node, + delete_unused_values, + self.chunk_region_search, + self.chunk_infos, + ) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the 
body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append("pass\n") + + if len(wrapped_fns) > 0: + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join( + [f'{wrap_name}("{name}")' for name in wrapped_fns] + ) + else: + wrap_stmts = "" + + if self._body_transformer: + body = self._body_transformer(body) + + for name, value in self.additional_globals(): + add_global(name, value) + + # as we need colossalai.utils.checkpoint, we need to import colossalai + # in forward function + prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) + prologue = "".join(ckpt_func) + prologue + prologue = prologue + + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) + fn_code = f""" +{wrap_stmts} + +{prologue} +{code}""" + # print(fn_code) + return PythonCode(fn_code, globals_) diff --git a/colossalai/autochunk/chunk_codegen.py b/colossalai/autochunk/chunk_codegen.py deleted file mode 100644 index 7a5d06689..000000000 --- a/colossalai/autochunk/chunk_codegen.py +++ /dev/null @@ -1,2364 +0,0 @@ -import colossalai -import torch -import copy -from typing import List, Callable, Any, Tuple, Dict, Iterable - -from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name -from torch.fx.graph import ( - _Namespace, - PythonCode, - _custom_builtins, - _is_from_torch, - _format_target, - magic_methods, - CodeGen, - _origin_type_map, - inplace_methods, - _CustomBuiltin, -) -from colossalai.fx.profiler import ( - calculate_fwd_out, - calculate_fwd_tmp, - parameter_size, - activation_size, -) - -CODEGEN_AVAILABLE = True -__all__ = ["ChunkCodeGen"] - - -def _delete_free_var_from_last_use(user_to_last_uses): - for key, value in user_to_last_uses.items(): - for n in value: - if n.op == "placeholder": - user_to_last_uses[key].remove(n) - - -def _get_node_shape(node): - if hasattr(node.meta["tensor_meta"], "shape"): - return node.meta["tensor_meta"].shape - return 
None - - -def _is_non_compute_node(node): - if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): - return True - return False - - -def _is_non_compute_node_except_placeholder(node): - if any(i in node.op for i in ["get_attr", "output"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): - return True - return False - - -def _is_non_compute_node_except_placeholder_output(node): - if any(i in node.op for i in ["get_attr"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): - return True - return False - - -class IndexTracer(object): - def __init__(self, node_list) -> None: - self.node_list = node_list - self.idx_trace_list = self._init_idx_trace_list() - self.idx_trace_equal = [] - self.idx_view_list = {} - self.idx_count = -1 - self.all_reorder_map = {i: i for i in range(len(self.idx_trace_list))} - - def _init_idx_trace_list(self): - idx_trace_list = [] - for n in self.node_list: - if _get_node_shape(n) != None: - cur_trace = { - "idx": [None for _ in range(len(_get_node_shape(n)))], - "compute": [[] for _ in range(len(_get_node_shape(n)))], - "source": [{} for _ in range(len(_get_node_shape(n)))], - } - else: - cur_trace = {"idx": [], "compute": [], "source": []} - idx_trace_list.append(cur_trace) - return idx_trace_list - - def _add_index(self): - """ - Update the count and return it. To record the idx number. 
- - Returns: - idx_count: int - """ - self.idx_count += 1 - return self.idx_count - - def _del_dim(self, idx, dim_idx): - self.idx_trace_list[idx]["idx"].pop(dim_idx) - self.idx_trace_list[idx]["compute"].pop(dim_idx) - self.idx_trace_list[idx]["source"].pop(dim_idx) - - def _add_dim(self, node_idx, dim_idx): - self.idx_trace_list[node_idx]["idx"].insert(dim_idx, self._add_index()) - self.idx_trace_list[node_idx]["compute"].insert(dim_idx, []) - self.idx_trace_list[node_idx]["source"].insert(dim_idx, {}) - - def _transform_index(self, node, node_dim): - node_idx = self._find_idx_trace_from_node(node) - dims = list(range(len(node_idx))) - return dims[node_dim] - - def _inherit_index(self, node_from, node_from_dim, node_to, node_to_dim): - node_from_dim = self._transform_index(node_from, node_from_dim) - node_to_dim = self._transform_index(node_to, node_to_dim) - node_from_trace = self._find_trace_from_node(node_from) - node_to_trace = self._find_trace_from_node(node_to) - node_to_trace["idx"][node_to_dim] = node_from_trace["idx"][node_from_dim] - node_to_trace["compute"][node_to_dim] = copy.deepcopy( - node_from_trace["compute"][node_from_dim] - ) - self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True) - - def _inherit_all_computation(self, node_from, node_to): - node_from_compute = self._find_compute_trace_from_node(node_from) - node_to_compute = self._find_compute_trace_from_node(node_to) - assert len(node_from_compute) == len(node_to_compute) - for i in range(len(node_from_compute)): - self._add_source(node_from, i, node_to, i) - node_to_compute[i] = copy.deepcopy(node_from_compute[i]) - - def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False): - node_from_dim = self._transform_index(node_from, node_from_dim) - node_from_trace_source = self._find_source_trace_from_node(node_from) - node_to_dim = self._transform_index(node_to, node_to_dim) - node_to_trace_source = self._find_source_trace_from_node(node_to) - 
node_from_idx = _find_idx_by_name(node_from.name, self.node_list) - if init: - node_to_trace_source[node_to_dim] = {} - # add dim to cur new source - if node_from_idx not in node_to_trace_source[node_to_dim]: - node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim] - else: - if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]: - node_to_trace_source[node_to_dim][node_from_idx].append(node_from_dim) - # update inputs source - for node_idx, node_dim in node_from_trace_source[node_from_dim].items(): - if node_idx not in node_to_trace_source[node_to_dim]: - node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim) - else: - for d in node_dim: - if d not in node_to_trace_source[node_to_dim][node_idx]: - node_to_trace_source[node_to_dim][node_idx].append(d) - - def _mark_computation_from_node(self, node_from, node_to, exclude=None): - if exclude == None: - exclude = [] - else: - exclude = [self._transform_index(node_to, i) for i in exclude] - node_from_compute = self._find_compute_trace_from_node(node_from) - node_to_compute = self._find_compute_trace_from_node(node_to) - # assert len(node_from_compute) == len(node_to_compute) - for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1): - if self._transform_index(node_to, i) in exclude: - continue - self._add_source(node_from, i, node_to, i) - for j in node_from_compute[i]: - if j not in node_to_compute[i]: - node_to_compute[i].append(j) - - def _mark_idx_equal(self, node1, dim1, node2, dim2): - """ - Mark 2 index to be equal. - - Args: - idx1 (int): index count. - idx2 (int): index count. - """ - # node1_idx = _find_idx_by_name(node1.name, self.nodes_list) - # node2_idx = _find_idx_by_name(node2.name, self.nodes_list) - # if node1_idx > node2_idx: - # self._add_source(node2, dim2, node1, dim1) - # else: - # self._add_source(node1, dim1, node2, dim2) - - def _mark_computation(self, node, idx, dim): - """ - Mark some dims of node as computed. 
- - Args: - node (node) - idx (int): node index - dim (list or int): dims to be marked as computed - """ - if isinstance(dim, int): - dim = [dim] - dims = list(range(len(_get_node_shape(node)))) - for d in dim: - cur_dim = dims[d] - if idx not in self.idx_trace_list[idx]["compute"][cur_dim]: - self.idx_trace_list[idx]["compute"][cur_dim].append(idx) - - def _find_trace_from_node(self, node): - """ - Find node idx and compute trace by the node. - - Args: - node (node) - Returns: - idx (list): idx of the node - compute (list): computed idx of the node. - """ - node_idx = _find_idx_by_name(node.name, self.node_list) - node_dict = self.idx_trace_list[node_idx] - return node_dict - - def _find_source_trace_from_node(self, node): - """ - Find node source trace by the node. - - Args: - node (node) - Returns: - idx (list): idx of the node - compute (list): computed idx of the node. - """ - node_idx = _find_idx_by_name(node.name, self.node_list) - node_dict = self.idx_trace_list[node_idx] - return node_dict["source"] - - def _find_idx_trace_from_node(self, node): - """ - Find node idx trace by the node. - - Args: - node (node) - Returns: - idx (list): idx of the node - """ - node_idx = _find_idx_by_name(node.name, self.node_list) - return self.idx_trace_list[node_idx]["idx"] - - def _find_compute_trace_from_node(self, node): - """ - Find node compute trace by the node. - - Args: - node (node) - Returns: - compute (list): computed idx of the node. - """ - node_idx = _find_idx_by_name(node.name, self.node_list) - return self.idx_trace_list[node_idx]["compute"] - - def _assign_index_as_input(self, node, node_idx, input_node=None): - """ - Assign node's trace as its input node. 
- - Args: - node (node) - node_idx (int) - """ - if input_node == None: - input_node = node.args[0] - input_node_idx = _find_idx_by_name(input_node.name, self.node_list) - input_node_idx_trace = self.idx_trace_list[input_node_idx]["idx"] - - new_idx_trace = copy.deepcopy(input_node_idx_trace) - self.idx_trace_list[node_idx]["idx"] = new_idx_trace - - self._inherit_all_computation(input_node, node) - - def _assign_all_index(self, node, node_idx): - """ - Add new index for all node's dims. - - Args: - node (node) - node_idx (int) - """ - shape = node.meta["tensor_meta"].shape - new_trace = [] - for _ in shape: - new_trace.append(self._add_index()) - self.idx_trace_list[node_idx]["idx"] = new_trace - - def _assign_transpose_index(self, node, node_idx): - """ - Assign index for transpose op. - 1. swap input's dim according to transpose args - 2. inherit input's computation - - Args: - node (node) - node_idx (int) - """ - input_node = node.args[0] - tranpose_dim = node.args[1:] - - self._assign_index_as_input(node, node_idx, input_node) - self._inherit_index(input_node, tranpose_dim[1], node, tranpose_dim[0]) - self._inherit_index(input_node, tranpose_dim[0], node, tranpose_dim[1]) - - def _assign_permute_index(self, node, node_idx): - """ - Assign index for permute op. - 1. swap input's dim according to permute args - 2. inherit input's computation - - Args: - node (node) - node_idx (int) - """ - permute_dim = node.args[1:] - input_node = node.args[0] - - self._assign_index_as_input(node, node_idx, input_node) - for idx, d in enumerate(permute_dim): - self._inherit_index(input_node, d, node, idx) - - def _assign_linear_index(self, node, node_idx): - """ - Assign index for linear op. - 1. copy trace from input node and change last index accroding to weight - 2. mark equal for input node last index, weight first dim and bias dim. - 3. inherit input's computation, mark computation for last dim. 
- - Args: - node (node) - node_idx (int) - """ - if len(node.args) == 2: - input_node, weight = node.args - bias = None - else: - input_node, weight, bias = node.args - - self._assign_index_as_input(node, node_idx) - self._inherit_index(weight, 1, node, -1) - - self._mark_computation(node, node_idx, [-1]) - self._mark_idx_equal(input_node, -1, weight, 0) - - if bias: - self._mark_idx_equal(input_node, -1, bias, 0) - - def _assign_matmul_index(self, node, node_idx): - """ - Assign index for matmul op. - 1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length) - 2. mark equal for input matmul_left -1 index and matmul_right -2 dim. - 3. inherit matmul_left and matmul_right computation, mark computation for last dim. - - Args: - node (node) - node_idx (int) - """ - matmul_left, matmul_right = node.args - - assert len(_get_node_shape(matmul_left)) == len(_get_node_shape(matmul_right)) - self._assign_index_as_input(node, node_idx, matmul_left) - self._inherit_index(matmul_right, -1, node, -1) - - self._mark_computation_from_node(matmul_right, node, [-1, -2]) - self._mark_computation(node, node_idx, [-1]) - self._mark_idx_equal(matmul_left, -1, matmul_right, -2) - - def _assign_layernorm_index(self, node, idx): - """ - Assign index for layernorm op. - 1. assign index as input node - 2. inherit computation and mark last 2 dims as computed. - - Args: - node (node) - node_idx (int) - """ - self._assign_index_as_input(node, idx) - self._mark_computation(node, idx, [-1]) - - def _assign_elementwise_index(self, node, idx): - """ - Assign index for element-wise op (eg. relu sigmoid add mul). - 1. assign index as input node - 2. inherit computation from all input nodes. 
- - Args: - node (node) - node_idx (int) - """ - self._assign_index_as_input(node, idx) - nodes_in = [] - for node_in in node.args: - if type(node_in) == type(node): - nodes_in.append(node_in) - self._mark_computation_from_node(node_in, node) - assert len(nodes_in) <= 2 - if len(nodes_in) == 2: - node_in0_shape = _get_node_shape(nodes_in[0]) - node_in1_shape = _get_node_shape(nodes_in[1]) - for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1): - if node_in0_shape[i] == node_in1_shape[i]: - self._mark_idx_equal(nodes_in[0], i, nodes_in[1], i) - - def _assgin_no_change_index(self, node, idx): - self._assign_index_as_input(node, idx) - for node_in in node.args: - if type(node_in) == type(node): - self._mark_computation_from_node(node_in, node) - - def _assign_einsum_index(self, node, idx): - """ - Assign index for einsum op. - - Args: - node (node) - node_idx (int) - """ - patterns = node.args[0] - input_nodes = node.args[1:] - - patterns = patterns.replace(" ", "") - left, right = patterns.split("->") - left = left.split(",") - - all_index = [] - for i in left: - for c in i: - all_index.append(c) - all_index = set(all_index) - free_index = set([i for i in right]) - sum_index = all_index - free_index - - for right_idx, right_indice in enumerate(right): - for left_idx, left_str in enumerate(left): - if right_indice in left_str: - source_idx = left_str.index(right_indice) - self._inherit_index( - input_nodes[left_idx], source_idx, node, right_idx - ) - - # for i in sum_index: - # for left_idx, left_str in enumerate(left): - # if i in left_str: - # self._mark_computation(node, idx, left_str.index(i)) - # break - - def _assign_softmax_index(self, node, idx): - """ - Assign index for softmax op. - 1. assign index as input node - 2. inherit computation and mark softmax dim as computed. 
- - Args: - node (node) - node_idx (int) - """ - self._assign_index_as_input(node, idx) - self._mark_computation(node, idx, [node.kwargs["dim"]]) - - def _assign_unsqueeze_index(self, node, node_idx): - """ - Assign index for unsqueeze op. - 1. assign new index for unsqueeze dim - - Args: - node (node) - node_idx (int) - """ - self._del_dim(node_idx, -1) - self._assign_index_as_input(node, node_idx) - self._add_dim(node_idx, node.args[1]) - - def _assign_dropout_index(self, node, node_idx): - """ - Assign index for unsqueeze op. - 1. assign new index for unsqueeze dim - - Args: - node (node) - node_idx (int) - """ - self._assign_index_as_input(node, node_idx) - - def _assign_ones_like_index(self, node, node_idx): - """ - Assign index for oneslike op. - 1. assign new index for all dim - - Args: - node (node) - node_idx (int) - """ - self._assign_all_index(node, node_idx) - - def _assign_view_reshape_index(self, node, node_idx): - """ - Assign index for view and reshape op. - 1. get origin shape and target shape by meta info. - 2. compute the real value of -1 in target shape. - 3. determine changed dim, and assgin index for generated dim. - 4. log changed dim and generated dim for restore - 5. inherit computation. - 6. TODO: look into view list to see whether the view is associated with other, - if so assgin equal dim according to previous view. 
- - Args: - node (node) - node_idx (int) - """ - # get data, turn into number - origin_node = node.args[0] - origin_shape = origin_node.meta["tensor_meta"].shape - target_shape = [] - for i in range(1, len(node.args)): - if isinstance(node.args[i], int): - target_shape.append(node.args[i]) - else: - target_shape.append(node.args[i].meta["fwd_out"][0]) - - # compute the value of -1 - if -1 in target_shape: - origin_product = 1 - for i in origin_shape: - origin_product *= i - target_product = -1 - for i in target_shape: - target_product *= i - shape_idx = target_shape.index(-1) - target_shape[shape_idx] = origin_product // target_product - - # determine changed dim - len_diff = len(origin_shape) - len(target_shape) - if len_diff == 1: - # dim merge - dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)] - dim_to = [dim_equal.index(False)] - dim_from = [dim_equal.index(False), dim_equal.index(False) + 1] - self._add_dim(node_idx, -1) - elif len_diff == -1: - # dim expand - dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])] - dim_from = [dim_equal.index(False)] - dim_to = [dim_equal.index(False), dim_equal.index(False) + 1] - self._del_dim(node_idx, -1) - else: - raise NotImplementedError( - "shape" - + str(origin_shape) - + "and" - + str(target_shape) - + "view not implemented" - ) - - # get new index - origin_trace = self._find_idx_trace_from_node(origin_node) - self._assign_index_as_input(node, node_idx, origin_node) - dim_from.reverse() - for i in dim_from: - self._del_dim(node_idx, i) - for i in dim_to: - self._add_dim(node_idx, i) - - # inherit computation - compute_log = self._find_compute_trace_from_node(origin_node) - for i in dim_from: - if origin_trace[i] in compute_log: - for j in dim_to: - self._mark_computation(node, node_idx, [j]) - break - - # log view, not used now - view_dict = { - "idx_from": [origin_trace[i] for i in dim_from], - "dim_from": dim_from, - "idx_to": [self.idx_trace_list[node_idx]["idx"][i] for i in 
dim_to], - "dim_to": dim_to, - } - self.idx_view_list[node] = view_dict - - def _merge_equal_idx(self): - idx_equal = copy.deepcopy(self.idx_trace_equal) - idx_equal.reverse() - for idx in idx_equal: - merge_to = min(idx) - merge_from = max(idx) - for trace in self.idx_trace_list: - if merge_from in trace["idx"]: - trace["idx"] = [ - merge_to if i == merge_from else i for i in trace["idx"] - ] - - def trace_index(self): - for idx, node in enumerate(self.node_list): - if node.op == "placeholder": - self._assign_all_index(node, idx) - elif node.op == "call_method": - if "transpose" in node.name: - self._assign_transpose_index(node, idx) - elif "permute" in node.name: - self._assign_permute_index(node, idx) - elif "view" in node.name or "reshape" in node.name: - self._assign_view_reshape_index(node, idx) - elif "unsqueeze" in node.name: - self._assign_unsqueeze_index(node, idx) - elif any(i in node.name for i in ["to", "contiguous"]): - self._assgin_no_change_index(node, idx) - else: - raise NotImplementedError(node.name, "method not implemented yet!") - elif node.op == "call_function": - if "linear" in node.name: - self._assign_linear_index(node, idx) - elif "matmul" in node.name: - self._assign_matmul_index(node, idx) - elif "softmax" in node.name: - self._assign_softmax_index(node, idx) - elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]): - self._assign_elementwise_index(node, idx) - elif "ones_like" in node.name: - self._assign_ones_like_index(node, idx) - elif "dropout" in node.name: - self._assign_dropout_index(node, idx) - elif "einsum" in node.name: - self._assign_einsum_index(node, idx) - elif "getattr" in node.name: - continue # get attr like shape - elif "getitem" in node.name: - continue # get item in list - else: - raise NotImplementedError( - node.name, "function not implemented yet!" 
- ) - elif node.op == "call_module": - if any(n in node.name for n in ["layernorm", "norm"]): - self._assign_layernorm_index(node, idx) - else: - raise NotImplementedError(node.name, "module not implemented yet!") - elif node.op == "get_attr": - self._assign_all_index(node, idx) # get param - elif node.op == "output": - continue - else: - raise NotImplementedError(node.op, "op not implemented yet!") - # self._merge_equal_idx() - - def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node): - """ - Check 2 given index: one index should be source of the other - Args: - start_idx(int): start node chunk dim - start_node(node): start node - end_idx(int): end node chunk dim - end_node(node): end node - - Returns: - bool: True if check pass - """ - start_node_idx = _find_idx_by_name(start_node.name, self.node_list) - end_node_trace = self._find_trace_from_node(end_node) - end_node_trace_source = end_node_trace["source"][end_dim] - sorted_source = sorted( - end_node_trace_source.items(), key=lambda d: d[0], reverse=True - ) - for node_idx, node_dim in sorted_source: - if node_idx == start_node_idx and start_dim in node_dim: - return True - # it means we meet a node outside the loop, and the node is not input node - if node_idx < start_idx: - return False - return False - - def check_index_compute(self, start_idx, end_dim, end_node, end_idx): - """ - Check 2 given index: check they haven't been computed in the source trace. 
- Args: - start_idx(int): start node chunk dim - start_node(node): start node - end_idx(int): end node chunk dim - end_node(node): end node - - Returns: - bool: True if check pass - """ - end_node_trace = self._find_trace_from_node(end_node) - end_node_compute = end_node_trace["compute"][end_dim] - if any(start_idx <= i <= end_idx for i in end_node_compute): - return False - return True - - def get_node_chunk_dim(self, node_from, node_from_dim, node_to): - node_from_source = self._find_source_trace_from_node(node_from) - dim_source = node_from_source[node_from_dim] - node_to_idx = _find_idx_by_name(node_to.name, self.node_list) - for k, v in dim_source.items(): - if k == node_to_idx: - return v - return None - - def _find_inherit_dim(self, input_node, input_dim, node): - input_node_idx = _find_idx_by_name(input_node.name, self.node_list) - node_trace_source = self._find_source_trace_from_node(node) - for node_dim in range(len(_get_node_shape(node))): - if ( - input_node_idx in node_trace_source[node_dim] - and input_dim[0] in node_trace_source[node_dim][input_node_idx] - ): - return node_dim - return None - - def check_index_duplicate(self, chunk_infos, return_dim=False): - input_dim_after_node = {} - for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): - for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): - inherit_dim = self._find_inherit_dim(input_node, v, self.node_list[k]) - if inherit_dim: - input_dim_after_node[k] = inherit_dim - - for node in self.node_list[ - chunk_infos["region"][0] : chunk_infos["region"][1] + 1 - ]: - if _is_non_compute_node_except_placeholder(node): - continue - count = 0 - duplicate_dims = [] - node_trace_source = self._find_source_trace_from_node(node) - for node_dim in range(len(_get_node_shape(node))): - duplicate_dim = [] - duplicate_flag = False - dim_source = node_trace_source[node_dim] - for k, v in dim_source.items(): - if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]: - if k in 
input_dim_after_node and input_dim_after_node[k] in v: - duplicate_flag = True - duplicate_dim.append((k, v)) - duplicate_dims.append(duplicate_dim) - if duplicate_flag: - count += 1 - - if count > 1: - if return_dim: - return False, duplicate_dims - else: - return False - if return_dim: - return True, None - else: - return True - - def _assgin_single_node_flow( - self, - arg_node, - start_idx, - end_idx, - cur_node_dim, - cur_node_compute, - cur_node_source, - cur_node_fix_dim, - all_node_info, - next_node_list, - ): - arg_idx = _find_idx_by_name(arg_node.name, self.node_list) - # arg in chunk range or be inputs - if not (start_idx <= arg_idx < end_idx): - return True - - # find arg dim - if cur_node_dim is not None: - # dim is computed - if arg_idx in cur_node_compute[cur_node_dim]: - return False - if arg_idx not in cur_node_source[cur_node_dim]: - arg_dim = None - else: - arg_dim = cur_node_source[cur_node_dim][arg_idx][0] - else: - arg_dim = None - - # get fix dim - arg_fix_dim = [] - if cur_node_dim is not None: - for i in cur_node_fix_dim: - fix_dim_source = cur_node_source[i] - if arg_idx in fix_dim_source: - arg_fix_dim.append(fix_dim_source[arg_idx][0]) - - # if already in node_info, arg dim must be same - if arg_node in all_node_info: - if all_node_info[arg_node]["chunk_dim"] != arg_dim: - return False - all_node_info[arg_node]["fix_dim"] = list( - set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) - ) - # else add it to list - else: - all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim} - - next_node_list.append(arg_node) - return True - - def flow_search(self, start_idx, start_dim, end_idx, end_dim): - inputs, outputs = _find_chunk_compute_input_and_output_nodes( - self.node_list[start_idx : end_idx + 1] - ) - # only single ouput - if len(outputs) > 1: - return None - - cur_node_list = [self.node_list[end_idx]] # start from the last node - all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} - - while 
len(cur_node_list) > 0: - next_node_list = [] - - for cur_node in cur_node_list: - # get cur node info - cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] - cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] - cur_node_idx = _find_idx_by_name(cur_node.name, self.node_list) - if cur_node_chunk_dim: - cur_node_compute = self._find_compute_trace_from_node(cur_node) - cur_node_source = self._find_source_trace_from_node(cur_node) - else: - cur_node_compute = cur_node_source = None - - # get all valid args - arg_list = [] - for arg in cur_node.args: - if type(arg) != type(cur_node): - continue - if _is_non_compute_node(arg): - continue - arg_list.append(arg) - flow_flag = self._assgin_single_node_flow( - arg, - start_idx, - end_idx, - cur_node_chunk_dim, - cur_node_compute, - cur_node_source, - cur_node_fix_dim, - all_node_info, - next_node_list, - ) - if flow_flag == False: - return None - - if len(arg_list) == 2: - if any(i in cur_node.name for i in ["add", "mul"]): - for arg in arg_list: - if not ( - start_idx - <= _find_idx_by_name(arg.name, self.node_list) - < end_idx - ): - continue - arg_chunk_dim = all_node_info[arg]["chunk_dim"] - arg_fix_dim = all_node_info[arg]["fix_dim"] - arg_shape = _get_node_shape(arg) - # add all dim as fix dim except chunk dim - for i, shape in enumerate(arg_shape): - if shape != 1 and i != cur_node_chunk_dim: - if i == arg_chunk_dim: - return None - if i not in arg_fix_dim: - arg_fix_dim.append(i) - elif "einsum" in cur_node.name: - pass - elif "matmul" in cur_node.name: - pass - else: - raise NotImplementedError() - cur_node_list = next_node_list - - inputs_dim = [] - remove_inputs = [] - for input_node in inputs: - input_dict = {} - input_node_idx = _find_idx_by_name(input_node.name, self.node_list) - for user in input_node.users.keys(): - if _is_non_compute_node(user): - continue - user_idx = _find_idx_by_name(user.name, self.node_list) - if start_idx <= user_idx <= end_idx: - chunk_dim = all_node_info[user]["chunk_dim"] - 
if chunk_dim is not None: - user_source = self._find_source_trace_from_node(user)[chunk_dim] - if input_node_idx in user_source: - input_dict[user_idx] = user_source[input_node_idx] - else: - return None - if len(input_dict) == 0: - remove_inputs.append(input_node) - else: - inputs_dim.append(input_dict) - for i in remove_inputs: - if i in inputs: - inputs.remove(i) - - chunk_info = { - "region": (start_idx, end_idx), - "inputs": inputs, - "inputs_non_chunk": [], - "inputs_dim": inputs_dim, - "outputs": outputs, - "outputs_dim": end_dim, - "node_chunk_dim": all_node_info, - "args": {}, - } - - # move useless nodes ahead of loop - # get all possible prepose nodes - maybe_prepose_nodes = [] - for node, node_info in all_node_info.items(): - if node_info["chunk_dim"] is None: - maybe_prepose_nodes.append(node) - maybe_prepose_nodes.sort( - key=lambda x: _find_idx_by_name(x.name, self.node_list), - reverse=True, - ) # from last node to first node - prepose_nodes = [] - # set every node as root, search its args, if all legal, turn root and args as prepose nodes - while len(maybe_prepose_nodes) > 0: - tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]] - tmp_cur_related_prepose_nodes = [] - prepose_flag = True - - # loop cur node's all arg until out of chunk - while len(tmp_cur_prepose_nodes) > 0: - if prepose_flag == False: - break - tmp_next_prepose_nodes = [] - tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes) - for cur_prepose_node in tmp_cur_prepose_nodes: - if prepose_flag == False: - break - for cur_prepose_node_arg in cur_prepose_node.args: - if type(cur_prepose_node_arg) != type(cur_prepose_node): - continue - # out of loop - if not ( - start_idx - <= _find_idx_by_name( - cur_prepose_node_arg.name, self.node_list - ) - < end_idx - ): - continue - # compute op in loop - elif cur_prepose_node_arg in all_node_info: - if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None: - tmp_next_prepose_nodes.append(cur_prepose_node_arg) - else: - prepose_flag = 
False - break - # non compute op - else: - tmp_next_prepose_nodes.append(cur_prepose_node_arg) - tmp_cur_prepose_nodes = tmp_next_prepose_nodes - - if prepose_flag == False: - maybe_prepose_nodes.remove(maybe_prepose_nodes[0]) - continue - else: - for n in tmp_cur_related_prepose_nodes: - if n not in prepose_nodes: - prepose_nodes.append(n) - if n in maybe_prepose_nodes: - maybe_prepose_nodes.remove(n) - # sort by index - prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, self.node_list)) - chunk_info["args"]["prepose_nodes"] = prepose_nodes - - # we need to log input nodes to avoid deleteing them in the loop - chunk_node_list = self.node_list[start_idx : end_idx + 1] - # also need to get some prepose node's arg out of non_chunk_inputs - for n in prepose_nodes: - chunk_node_list.remove(n) - non_chunk_inputs = _find_chunk_all_input_nodes(chunk_node_list) - for i in non_chunk_inputs: - if i not in chunk_info["inputs"]: - chunk_info["inputs_non_chunk"].append(i) - - # reassgin reshape size, some size may have changed due to chunk - chunk_info = self._reassgin_reshape_size(chunk_info) - - return chunk_info - - def _reassgin_reshape_size(self, chunk_info): - chunk_region = chunk_info["region"] - reshape_size = {} - chunk_shape = _get_node_shape(chunk_info["outputs"][0])[ - chunk_info["outputs_dim"] - ] - for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]: - if any(i in node.name for i in ["reshape", "view"]): - reshape_args = node.args[1:] - reshape_log = self.idx_view_list[node] - chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] - reshape_size[node.name] = {} - for reshape_arg_dim, reshape_arg in enumerate(reshape_args): - if reshape_arg_dim in reshape_log["dim_to"]: - continue - if reshape_arg_dim == chunk_dim: - reshape_size[node.name][reshape_arg.name] = ( - "min(chunk_size, %d - chunk_idx)" % chunk_shape - ) - chunk_info["reshape_size"] = reshape_size - return chunk_info - - def _get_reorder_map(self, chunk_info): - reorder_map = 
{i: i for i in range(len(self.node_list))} - - chunk_region_start = chunk_info["region"][0] - chunk_region_end = chunk_info["region"][1] - chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] - chunk_prepose_nodes_idx = [ - _find_idx_by_name(i.name, self.node_list) for i in chunk_prepose_nodes - ] - # put prepose nodes ahead - for idx, n in enumerate(chunk_prepose_nodes): - n_idx = chunk_prepose_nodes_idx[idx] - reorder_map[n_idx] = chunk_region_start + idx - # put other nodes after prepose nodes - for n in self.node_list[chunk_region_start : chunk_region_end + 1]: - if n in chunk_prepose_nodes: - continue - n_idx = _find_idx_by_name(n.name, self.node_list) - pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) - reorder_map[n_idx] = n_idx + pos - - return reorder_map - - def _reorder_chunk_info(self, chunk_info, reorder_map): - # update chunk info - chunk_info["region"] = ( - chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]), - chunk_info["region"][1], - ) - new_inputs_dim = [] - for idx, input_dim in enumerate(chunk_info["inputs_dim"]): - new_input_dim = {} - for k, v in input_dim.items(): - new_input_dim[reorder_map[k]] = v - new_inputs_dim.append(new_input_dim) - chunk_info["inputs_dim"] = new_inputs_dim - return chunk_info - - def _update_all_reorder_map(self, reorder_map): - for origin_idx, map_idx in self.all_reorder_map.items(): - self.all_reorder_map[origin_idx] = reorder_map[map_idx] - - def _reorder_self_node_list(self, reorder_map): - new_node_list = [None for _ in range(len(self.node_list))] - for old_idx, new_idx in reorder_map.items(): - new_node_list[new_idx] = self.node_list[old_idx] - self.node_list = new_node_list - - def _reorder_idx_trace(self, reorder_map): - # reorder list - new_idx_trace_list = [None for _ in range(len(self.idx_trace_list))] - for old_idx, new_idx in reorder_map.items(): - new_idx_trace_list[new_idx] = self.idx_trace_list[old_idx] - self.idx_trace_list = new_idx_trace_list - # update compute - for 
idx_trace in self.idx_trace_list: - compute = idx_trace["compute"] - for dim_compute in compute: - for idx, i in enumerate(dim_compute): - dim_compute[idx] = reorder_map[i] - # update source - for idx_trace in self.idx_trace_list: - source = idx_trace["source"] - for dim_idx, dim_source in enumerate(source): - new_dim_source = {} - for k, v in dim_source.items(): - new_dim_source[reorder_map[k]] = v - source[dim_idx] = new_dim_source - - def reorder_all(self, chunk_info): - if chunk_info is None: - return chunk_info - if len(chunk_info["args"]["prepose_nodes"]) == 0: - return chunk_info - reorder_map = self._get_reorder_map(chunk_info) - self._update_all_reorder_map(reorder_map) - self._reorder_idx_trace(reorder_map) - self._reorder_self_node_list(reorder_map) - chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) - return chunk_info - - def reorder_node_list(self, node_list): - new_node_list = [None for _ in range(len(node_list))] - for old_idx, new_idx in self.all_reorder_map.items(): - new_node_list[new_idx] = node_list[old_idx] - return new_node_list - - def tmp_reorder(self, node_list, chunk_info): - if len(chunk_info["args"]["prepose_nodes"]) == 0: - return node_list, chunk_info - reorder_map = self._get_reorder_map(chunk_info) - - # new tmp node list - new_node_list = [None for _ in range(len(node_list))] - for old_idx, new_idx in reorder_map.items(): - new_node_list[new_idx] = node_list[old_idx] - - chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) - return new_node_list, chunk_info - - -class MemoryEstimator(object): - def __init__(self, index_tracer: IndexTracer) -> None: - pass - - def _get_meta_node_size(self, x): - x = x.meta["tensor_meta"] - x = x.numel * torch.tensor([], dtype=x.dtype).element_size() - return x - - def _get_output_node(self, n): - fwd_out = { - x.uuid: x - for x in n.meta["fwd_out"] - if isinstance(x, torch.Tensor) and hasattr(x, "uuid") - } - out_size = activation_size(fwd_out) - out_node = [n.name] if 
out_size > 0 else [] - # if any(i in n.name for i in ['transpose', 'permute', 'view']): - # out_size = 0 - return out_size, out_node - - def _get_output_node_size(self, n): - return self._get_output_node(n)[0] - - def _add_active_node(self, n, active_list): - new_active = self._get_output_node(n)[1] - if n.op == "placeholder": - new_active.append(n.name) - for i in new_active: - if i not in active_list: - active_list.append(i) - - def _get_delete_node(self, user, user_to_last_uses, to_keep=None): - delete_size = 0 - delete_node = [] - if user.op not in ("output",): - nodes_to_delete = user_to_last_uses.get(user, []) - if to_keep is not None: - keep_list = [] - for n in nodes_to_delete: - if n.name in to_keep: - keep_list.append(n) - for n in keep_list: - if n in nodes_to_delete: - nodes_to_delete.remove(n) - if len(nodes_to_delete): - out_node = [self._get_output_node(i) for i in nodes_to_delete] - delete_size = sum([i[0] for i in out_node]) - for i in range(len(out_node)): - if out_node[i][0] > 0: - delete_node.append(out_node[i][1][0]) - elif nodes_to_delete[i].op == "placeholder": - delete_node.append(nodes_to_delete[i].name) - # elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']): - # delete_node.append(nodes_to_delete[i].name) - return delete_size, delete_node - - def _get_delete_node_size(self, user, user_to_last_uses, to_keep): - return self._get_delete_node(user, user_to_last_uses, to_keep)[0] - - def _remove_deactive_node(self, user, user_to_last_uses, active_list): - delete_node = self._get_delete_node(user, user_to_last_uses)[1] - for i in delete_node: - if i in active_list: - active_list.remove(i) - - def _get_chunk_inputs_size( - self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx - ): - nodes_to_delete = [] - for chunk_input in chunk_inputs + chunk_inputs_non_chunk: - chunk_input_users = chunk_input.users.keys() - chunk_input_users_idx = [ - _find_idx_by_name(i.name, node_list) for i in chunk_input_users - 
] - if all(i <= chunk_end_idx for i in chunk_input_users_idx): - if chunk_input not in nodes_to_delete: - nodes_to_delete.append(chunk_input) - out_node = [self._get_output_node(i) for i in nodes_to_delete] - delete_size = sum([i[0] for i in out_node]) - return delete_size - - def _get_last_usr(self, nodes): - node_to_last_use: Dict[Node, Node] = {} - user_to_last_uses: Dict[Node, List[Node]] = {} - - def register_last_uses(n: Node, user: Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - user_to_last_uses.setdefault(user, []).append(n) - - for node in reversed(nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - return user_to_last_uses - - def _get_contiguous_memory(self, node, not_contiguous_list, delete=False): - mem = 0 - not_contiguous_ops = ["permute"] - inherit_contiguous_ops = ["transpose", "view"] - - if node.op == "call_function" and any( - n in node.name for n in ["matmul", "reshape"] - ): - for n in node.args: - if n in not_contiguous_list: - # matmul won't change origin tensor, but create a tmp copy - mem += self._get_output_node_size(n) - elif node.op == "call_module": - for n in node.args: - if n in not_contiguous_list: - # module will just make origin tensor to contiguous - if delete: - not_contiguous_list.remove(n) - elif node.op == "call_method" and any( - i in node.name for i in not_contiguous_ops - ): - if node not in not_contiguous_list: - not_contiguous_list.append(node) - return mem - - def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size): - if node not in chunk_node_dim: - return 1.0 - node_shape = _get_node_shape(node) - chunk_dim = chunk_node_dim[node]["chunk_dim"] - if chunk_dim is None: - return 1.0 - else: - return float(chunk_size) / node_shape[chunk_dim] - - def _get_chunk_delete_node_size( - self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names - ): - # if any(j in user.name for j in ['transpose', 'permute', 'view']): - 
# return 0 - if user.op in ("placeholder", "output"): - return 0 - nodes_to_delete = user_to_last_uses.get(user, []) - delete_size = 0 - for n in nodes_to_delete: - if n.name in chunk_inputs_names: - continue - delete_size += self._get_output_node_size(n) * chunk_ratio - return delete_size - - def _print_mem_log(self, log, nodes, title=None): - if title: - print(title) - for idx, (l, n) in enumerate(zip(log, nodes)): - print("%s:%.2f \t" % (n.name, l), end="") - if (idx + 1) % 3 == 0: - print("") - print("\n") - - def _print_compute_op_mem_log(self, log, nodes, title=None): - if title: - print(title) - for idx, (l, n) in enumerate(zip(log, nodes)): - if n.op in ["placeholder", "get_attr", "output"]: - continue - if any(i in n.name for i in ["getitem", "getattr"]): - continue - print("%s:%.2f \t" % (n.name, l), end="") - if (idx + 1) % 3 == 0: - print("") - print("\n") - - def estimate_chunk_inference_mem( - self, - node_list, - chunk_infos=None, - print_mem=False, - ): - act_memory = 0.0 - act_memory_peak_log = [] - act_memory_after_node_log = [] - active_node_list = [] - active_node_list_log = [] - not_contiguous_list = [] - user_to_last_uses = self._get_last_usr(node_list) - user_to_last_uses_no_free_var = self._get_last_usr(node_list) - _delete_free_var_from_last_use(user_to_last_uses_no_free_var) - - use_chunk = True if chunk_infos is not None else False - chunk_within = False - chunk_region_idx = None - chunk_ratio = 1 # use it to estimate chunk mem - chunk_inputs_names = [] - - if use_chunk: - chunk_regions = [i["region"] for i in chunk_infos] - chunk_starts = [i[0] for i in chunk_regions] - chunk_ends = [i[1] for i in chunk_regions] - chunk_inputs = [i["inputs"] for i in chunk_infos] - chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] - chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ - j.name for i in chunk_inputs_non_chunk for j in i - ] - chunk_outputs = [i["outputs"][0] for i in chunk_infos] - chunk_node_dim = 
[i["node_chunk_dim"] for i in chunk_infos] - chunk_sizes = [ - i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos - ] - - for idx, node in enumerate(node_list): - # if node in chunk start nodes, change chunk ratio and add chunk_tensor - if use_chunk and idx in chunk_starts: - chunk_within = True - chunk_region_idx = chunk_starts.index(idx) - act_memory += self._get_output_node_size( - chunk_outputs[chunk_region_idx] - ) / (1024**2) - - # determine chunk ratio for current node - if chunk_within: - chunk_ratio = self._get_chunk_ratio( - node, - chunk_node_dim[chunk_region_idx], - chunk_sizes[chunk_region_idx], - ) - - # if node is placeholder, just add the size of the node - if node.op == "placeholder": - act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2) - act_memory_peak_log.append(act_memory) - # skip output - elif node.op == "output": - continue - # no change for non compute node - elif _is_non_compute_node_except_placeholder(node): - act_memory_peak_log.append(act_memory) - # node is a compute op - # calculate tmp, output node and delete node memory - else: - # forward memory - # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose - act_memory += ( - self._get_contiguous_memory(node, not_contiguous_list) - * chunk_ratio - / (1024**2) - ) - act_memory += ( - self._get_output_node_size(node) * chunk_ratio / (1024**2) - ) - # record max act memory - act_memory_peak_log.append(act_memory) - # delete useless memory - act_memory -= ( - self._get_contiguous_memory(node, not_contiguous_list, delete=True) - * chunk_ratio - / (1024**2) - ) - # delete unused vars not in chunk_input_list - # we can't delete input nodes until chunk ends - if chunk_within: - act_memory -= self._get_chunk_delete_node_size( - node, - user_to_last_uses_no_free_var, - chunk_ratio, - chunk_inputs_names, - ) / (1024**2) - else: - act_memory -= self._get_delete_node_size( - node, user_to_last_uses_no_free_var, chunk_inputs_names - ) / 
(1024**2) - - # log active node, only effective without chunk - self._add_active_node(node, active_node_list) - self._remove_deactive_node(node, user_to_last_uses, active_node_list) - - # if node in chunk end nodes, restore chunk settings - if use_chunk and idx in chunk_ends: - act_memory -= ( - self._get_output_node_size(node) * chunk_ratio / (1024**2) - ) - act_memory -= self._get_chunk_inputs_size( - chunk_inputs[chunk_region_idx], - chunk_inputs_non_chunk[chunk_region_idx], - node_list, - chunk_regions[chunk_region_idx][1], - ) / (1024**2) - chunk_within = False - chunk_ratio = 1 - chunk_region_idx = None - - act_memory_after_node_log.append(act_memory) - active_node_list_log.append(copy.deepcopy(active_node_list)) - - if print_mem: - print("with chunk" if use_chunk else "without chunk") - # self._print_mem_log(act_memory_peak_log, node_list, "peak") - # self._print_mem_log(act_memory_after_node_log, node_list, "after") - self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak") - # self._print_compute_op_mem_log( - # act_memory_after_node_log, node_list, "after" - # ) - - # param_memory = parameter_size(gm) - # all_memory = act_memory + param_memory - return act_memory_peak_log, act_memory_after_node_log, active_node_list_log - - -class ChunkSelector(object): - def __init__( - self, - index_tracer: IndexTracer, - memory_estimator: MemoryEstimator, - max_memory=None, - ): - self.index_tracer = index_tracer - self.memory_estimator = memory_estimator - if max_memory is not None: - self.stratge = "fit_memory" - self.max_memory = max_memory # MB - else: - self.stratge = "min_memory" - - def _select_best_chunk_region( - self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak - ): - if self.stratge == "min_memory": - best_region = self._select_min_memory_chunk_region( - possible_chunk_regions, - chunk_infos, - peak_node, - max_chunk_region, - mem_peak, - ) - elif self.stratge == "fit_memory": - best_region = 
self._select_fit_memory_chunk_region( - possible_chunk_regions, - chunk_infos, - peak_node, - max_chunk_region, - mem_peak, - ) - else: - raise RuntimeError() - return best_region - - def _select_fit_memory_chunk_region( - self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak - ): - # stop chunk if max memory satisfy memory limit - if max(mem_peak) < self.max_memory: - return None - - # remove illegal regions - illegal_regions = [] - for i in possible_chunk_regions: - if not self._is_legal_region(i, chunk_infos): - illegal_regions.append(i) - for i in illegal_regions: - if i in possible_chunk_regions: - possible_chunk_regions.remove(i) - - if len(possible_chunk_regions) == 0: - return None - - # get mem for chunk region - regions_dict = [] - for region in possible_chunk_regions: - cur_region = region.copy() - cur_node_list, cur_region = self.index_tracer.tmp_reorder( - self.index_tracer.node_list, cur_region - ) - cur_chunk_infos = chunk_infos + [cur_region] - cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - cur_node_list, cur_chunk_infos - )[0] - cur_chunk_region_peak = cur_mem_peak[ - max_chunk_region[0] : max_chunk_region[1] + 1 - ] - cur_chunk_region_max_peak = max(cur_chunk_region_peak) - if cur_chunk_region_max_peak < self.max_memory: - regions_dict.append( - { - "chunk_info": region, - "chunk_max_mem": cur_chunk_region_max_peak, - "chunk_len": self._get_compute_node_num( - region["region"][0], region["region"][1] - ), - "reorder_chunk_info": cur_region, - "reorder_node_list": cur_node_list, - } - ) - # no region found - if len(regions_dict) == 0: - raise RuntimeError("Search failed. 
Try a larger memory threshold.") - - # select the min chunk len - chunk_len = [i["chunk_len"] for i in regions_dict] - best_region_idx = chunk_len.index(min(chunk_len)) - best_region = regions_dict[best_region_idx] - - # get max chunk size - best_region = self._get_fit_chunk_size(best_region, chunk_infos) - return best_region - - def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos): - chunk_size = 1 - reorder_chunk_info = chunk_region_dict["reorder_chunk_info"] - reorder_chunk_info["chunk_size"] = chunk_size - cur_chunk_max_mem = 0 - # search a region - while cur_chunk_max_mem < self.max_memory: - chunk_size *= 2 - reorder_chunk_info["chunk_size"] = chunk_size - cur_chunk_infos = chunk_infos + [reorder_chunk_info] - cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - chunk_region_dict["reorder_node_list"], cur_chunk_infos - )[0] - cur_chunk_max_mem = max( - cur_mem_peak[ - reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] - + 1 - ] - ) - # search exact size - chunk_info = chunk_region_dict["chunk_info"] - chunk_info["chunk_size"] = self._chunk_size_binary_search( - chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos - ) - return chunk_info - - def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos): - if l >= 16: - gap = 4 - else: - gap = 1 - chunk_info = chunk_region_dict["reorder_chunk_info"] - while r >= l + gap: - mid = int((l + r) / 2 + 0.5) - chunk_info["chunk_size"] = mid - cur_chunk_infos = chunk_infos + [chunk_info] - cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - chunk_region_dict["reorder_node_list"], cur_chunk_infos - )[0] - cur_chunk_max_mem = max( - cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] - ) - if cur_chunk_max_mem >= self.max_memory: - r = mid - gap - else: - l = mid + gap - return l - - def _get_compute_node_num(self, start, end): - count = 0 - for i in self.index_tracer.node_list[start : end + 1]: - if not _is_non_compute_node(i): - 
count += 1 - return count - - def _select_min_memory_chunk_region( - self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak - ): - # remove illegal regions - illegal_regions = [] - for i in possible_chunk_regions: - if not self._is_legal_region(i, chunk_infos): - illegal_regions.append(i) - for i in illegal_regions: - if i in possible_chunk_regions: - possible_chunk_regions.remove(i) - - if len(possible_chunk_regions) == 0: - return None - - # get mem for chunk region - regions_dict = [] - for region in possible_chunk_regions: - cur_region = region.copy() - cur_node_list, cur_region = self.index_tracer.tmp_reorder( - self.index_tracer.node_list, cur_region - ) - cur_chunk_infos = chunk_infos + [cur_region] - cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - cur_node_list, cur_chunk_infos - )[0] - cur_chunk_region_peak = cur_mem_peak[ - max_chunk_region[0] : max_chunk_region[1] + 1 - ] - cur_chunk_region_max_peak = max(cur_chunk_region_peak) - regions_dict.append( - { - "chunk_info": region, - "chunk_max_mem": cur_chunk_region_max_peak, - "chunk_len": self._get_compute_node_num( - region["region"][0], region["region"][1] - ), - "reorder_chunk_info": cur_region, - "reorder_node_list": cur_node_list, - } - ) - - # select the min mem - chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict] - best_region_idx = chunk_max_mem.index(min(chunk_max_mem)) - best_region = regions_dict[best_region_idx]["chunk_info"] - if best_region is not None: - best_region["chunk_size"] = 1 - return best_region - - def _is_legal_region(self, cur_chunk_info, chunk_infos): - (chunk_region_start, chunk_region_end) = cur_chunk_info["region"] - if cur_chunk_info in chunk_infos: - return False - if chunk_region_end < chunk_region_start: - return False - for i in chunk_infos: - region = i["region"] - if not ( - (chunk_region_start > region[1] and chunk_region_end > region[1]) - or (chunk_region_start < region[0] and chunk_region_end < region[0]) - ): 
- return False - return True - - -class ChunkRegionSearch(object): - def __init__(self, gm, max_memory=None) -> None: - self.gm = gm - self.index_tracer = IndexTracer(list(gm.graph.nodes)) - self.index_tracer.trace_index() - self.memory_estimator = MemoryEstimator(self.index_tracer) - self.chunk_selector = ChunkSelector( - self.index_tracer, self.memory_estimator, max_memory=max_memory - ) - - def _find_peak_node(self, mem_peak): - max_value = max(mem_peak) - max_idx = mem_peak.index(max_value) - return max_idx - - def _get_free_var(self): - free_var_idx = [] - for idx, n in enumerate(self.index_tracer.node_list): - if n.op == "placeholder": - free_var_idx.append(idx) - return free_var_idx - - def _get_min_free_var(self, active_node_list, free_vars): - min_len = 999 - for idx, n in enumerate(active_node_list): - if idx in free_vars: - continue - if len(n) < min_len: - min_len = len(n) - return min_len - - def _search_max_chunk_region(self, active_node, peak_node, chunk_regions): - free_vars = self._get_free_var() - free_var_num = len(free_vars) - active_node_num = [len(i) for i in active_node] - min_active_node_num = min(active_node_num[free_var_num:]) - threshold = max(free_var_num, min_active_node_num) - - # from peak_node to free_var - inside_flag = False - chunk_region_start = free_var_num - for i in range(peak_node, -1, -1): - if active_node_num[i] <= threshold: - inside_flag = True - if inside_flag and active_node_num[i] > threshold: - chunk_region_start = i + 1 - break - - # from peak_node to len-2 - inside_flag = False - chunk_region_end = len(active_node) - 1 - for i in range(peak_node, len(active_node)): - if active_node_num[i] <= threshold: - inside_flag = True - if inside_flag and active_node_num[i] > threshold: - chunk_region_end = i - break - - for i in chunk_regions: - region = i["region"] - if chunk_region_start >= region[0] and chunk_region_end <= region[1]: - return None - elif ( - region[0] <= chunk_region_start <= region[1] - and 
chunk_region_end > region[1] - ): - chunk_region_start = region[1] + 1 - elif ( - region[0] <= chunk_region_end <= region[1] - and chunk_region_start < region[0] - ): - chunk_region_end = region[0] - 1 - return chunk_region_start, chunk_region_end - - def _is_not_compute(self, trace, chunk_range, dim_idx): - if trace["idx"][dim_idx] not in trace["compute"]: - return True - if trace["idx"][dim_idx] in trace["compute"] and all( - i < chunk_range[0] or i > chunk_range[1] - for i in trace["compute"][trace["idx"][dim_idx]] - ): - return True - return False - - def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): - start_traces = input_trace[start_idx] - end_trace = output_trace[end_idx] - end_node = self.index_tracer.node_list[end_idx] - chunk_infos = [] - for end_dim, _ in enumerate(end_trace["idx"]): - if len(start_traces) > 1: - continue - for start_node, start_trace in start_traces.items(): - for start_dim, _ in enumerate(start_trace["idx"]): - # dim size cannot be 1 - if ( - _get_node_shape(end_node)[end_dim] == 1 - or _get_node_shape(start_node)[start_dim] == 1 - ): - continue - # check index source align - if not self.index_tracer.check_index_source( - start_dim, start_node, start_idx, end_dim, end_node - ): - continue - # check index copmute - if not self.index_tracer.check_index_compute( - start_idx, end_dim, end_node, end_idx - ): - continue - # flow search - chunk_info = self.index_tracer.flow_search( - start_idx, start_dim, end_idx, end_dim - ) - if chunk_info is None: - continue - # check index copmute - if not self.index_tracer.check_index_duplicate(chunk_info): - continue - chunk_infos.append(chunk_info) - return chunk_infos - - def _search_possible_chunk_regions(self, max_chunk_region, peak_node): - possible_chunk_region = [] - output_trace = copy.deepcopy(self.index_tracer.idx_trace_list) - input_trace = [] # trace of a node's input nodes - for _, n in enumerate(self.index_tracer.node_list): - cur_trace = {} - for arg in n.args: - if 
type(arg) == type(n) and not _is_non_compute_node_except_placeholder( - arg - ): - cur_trace[arg] = self.index_tracer._find_trace_from_node(arg) - input_trace.append(cur_trace) - - for start_idx in range(max_chunk_region[0], peak_node + 1): - for end_idx in range(peak_node, max_chunk_region[1] + 1): - # skip non compute nodes - if _is_non_compute_node( - self.index_tracer.node_list[start_idx] - ) or _is_non_compute_node(self.index_tracer.node_list[end_idx]): - continue - - # select free dim - chunk_info = self._find_free_dim( - input_trace, output_trace, start_idx, end_idx - ) - if len(chunk_info) > 0: - possible_chunk_region.extend(chunk_info) - return possible_chunk_region - - def _step_search(self, mem_peak, active_node, chunk_regions): - peak_node = self._find_peak_node(mem_peak) - max_chunk_region = self._search_max_chunk_region( - active_node, peak_node, chunk_regions - ) - if max_chunk_region == None: - return None - possible_chunk_regions = self._search_possible_chunk_regions( - max_chunk_region, peak_node - ) - best_chunk_region = self.chunk_selector._select_best_chunk_region( - possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak - ) - best_chunk_region = self.index_tracer.reorder_all(best_chunk_region) - return best_chunk_region - - def _stop_search(self, init_mem_peak, mem_peak): - sorted_init_mem_peak = sorted(init_mem_peak) - if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]: - return True - return False - - def search_region(self): - chunk_infos = [] - ( - init_mem_peak, - _, - active_node, - ) = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list - ) - mem_peak = init_mem_peak - - while True: - chunk_info = self._step_search(mem_peak, active_node, chunk_infos) - if chunk_info is None: - break - chunk_infos.append(chunk_info) - - ( - mem_peak, - _, - active_node, - ) = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, chunk_infos - ) - 
if self._stop_search(init_mem_peak, mem_peak): - break - self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, chunk_infos, print_mem=True - ) - return chunk_infos - - -def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): - new_shape = "[" - for idx, i in enumerate(shape): - if idx == chunk_dim: - new_shape += "%s:%s + chunk_size" % (chunk_idx_name, chunk_idx_name) - else: - new_shape += ":" - new_shape += ", " - new_shape = new_shape[:-2] + "]" - return new_shape - - -def _gen_loop_start(chunk_input, chunk_output, chunk_ouput_dim, chunk_size=2): - input_node = chunk_input[0] - out_shape = _get_node_shape(chunk_output) - out_str = str(list(out_shape)) - context = ( - "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" - % (out_str, input_node.name, input_node.name, chunk_size) - ) - context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim]) - return context - - -def _gen_loop_end( - chunk_inputs, chunk_non_compute_inputs, chunk_outputs, chunk_outputs_dim, node_list -): - chunk_outputs_name = chunk_outputs.name - chunk_outputs_idx = _find_idx_by_name(chunk_outputs_name, node_list) - chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape - chunk_slice = _gen_chunk_slice_dim( - chunk_outputs_dim, "chunk_idx", chunk_output_shape - ) - context = " chunk_result%s = %s; %s = None\n" % ( - chunk_slice, - chunk_outputs_name, - chunk_outputs_name, - ) - context += ( - chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None" - ) - - # determine if its the last use for chunk input - for chunk_input in chunk_inputs + chunk_non_compute_inputs: - if all( - [ - _find_idx_by_name(user.name, node_list) <= chunk_outputs_idx - for user in chunk_input.users.keys() - ] - ): - context += "; %s = None" % chunk_input.name - - context += "\n" - return context - - -def _find_chunk_all_input_nodes(nodes: List[Node]): - """ - Find non-compute input and output node 
names. - input nodes are nodes used in the list - output nodes are nodes will use nodes in the list - """ - input_nodes = [] - for node in nodes: - for input_node in node._input_nodes.keys(): - if input_node not in nodes and input_node not in input_nodes: - input_nodes.append(input_node) - return input_nodes - - -def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]): - """ - Find non-compute input and output node names. - input nodes are nodes used in the list - output nodes are nodes will use nodes in the list - """ - input_nodes = [] - output_nodes = [] - - # if a node has an input node which is not in the node list - # we treat that input node as the input of the checkpoint function - for node in nodes: - for input_node in node._input_nodes.keys(): - if ( - input_node not in nodes - and input_node not in input_nodes - and not _is_non_compute_node_except_placeholder(input_node) - ): - input_nodes.append(input_node) - - # if a node has a user node which is not in the node list - # we treat that user node as the node receiving the current node output - for node in nodes: - for output_node in node.users.keys(): - if ( - output_node not in nodes - and node not in output_nodes - and not _is_non_compute_node_except_placeholder_output(output_node) - ): - output_nodes.append(node) - - return input_nodes, output_nodes - - -def _find_idx_by_name(name, nodes_list): - for idx, node in enumerate(nodes_list): - if node.name == name: - return idx - raise RuntimeError("name %s not found in node list" % name) - - -def _replace_name(context, name_from, name_to): - patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")] - for p in patterns: - source = p[0] + name_from + p[1] - target = p[0] + name_to + p[1] - if source in context: - context = context.replace(source, target) - return context - - -def _replace_reshape_size(context, node_name, reshape_size_dict): - if node_name not in reshape_size_dict: - return context - for size_name, 
size_value in reshape_size_dict[node_name].items(): - context = context.replace(size_name, size_value) - return context - - -def emit_code_with_chunk( - body, - nodes, - emit_node_func, - delete_unused_value_func, - chunk_region_search, - chunk_infos -): - """Emit code with nested activation checkpoint - When we detect some of the node.activation_checkpoint is a List, we will use - this function to emit the activation checkpoint codes. - - Args: - body: forward code - ckpt_func: checkpoint functions code - nodes: graph.nodes - emit_node_func: function to emit node - delete_unused_value_func: function to remove the unused value - """ - node_list = list(nodes) - - chunk_regions = [i["region"] for i in chunk_infos] - chunk_starts = [i[0] for i in chunk_regions] - chunk_ends = [i[1] for i in chunk_regions] - - chunk_inputs = [i["inputs"] for i in chunk_infos] - chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] - chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] - chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ - j.name for i in chunk_inputs_non_chunk for j in i - ] - - chunk_outputs = [i["outputs"][0] for i in chunk_infos] - chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos] - - node_list = chunk_region_search.index_tracer.reorder_node_list(node_list) - node_idx = 0 - region_idx = 0 - within_chunk_region = False - - while node_idx < len(node_list): - node = node_list[node_idx] - - if node_idx in chunk_starts: - within_chunk_region = True - region_idx = chunk_starts.index(node_idx) - body.append( - _gen_loop_start( - chunk_inputs[region_idx], - chunk_outputs[region_idx], - chunk_outputs_dim[region_idx], - chunk_infos[region_idx]["chunk_size"], - ) - ) - - if within_chunk_region: - emit_node_func(node, body) - # replace input var with chunk var - for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): - for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): - if idx == node_idx: - 
chunk_slice = _gen_chunk_slice_dim( - dim[0], "chunk_idx", _get_node_shape(input_node) - ) - body[-1] = _replace_name( - body[-1], input_node.name, input_node.name + chunk_slice - ) - # ones like - if "ones_like" in node.name: - meta_node = chunk_region_search.index_tracer.node_list[node_idx] - chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][ - "chunk_dim" - ] - if _get_node_shape(meta_node)[chunk_dim] != 1: - source_node = meta_node.args[0].args[0] - if ( - source_node not in chunk_infos[region_idx]["node_chunk_dim"] - or chunk_infos[region_idx]["node_chunk_dim"][source_node][ - "chunk_dim" - ] - is None - ): - chunk_slice = _gen_chunk_slice_dim( - chunk_dim, "chunk_idx", _get_node_shape(node) - ) - body[-1] = _replace_name( - body[-1], node.args[0].name, node.args[0].name + chunk_slice - ) - body[-1] = _replace_reshape_size( - body[-1], node.name, chunk_infos[region_idx]["reshape_size"] - ) - body[-1] = " " + body[-1] - delete_unused_value_func(node, body, chunk_inputs_names) - else: - emit_node_func(node, body) - if node_idx not in chunk_inputs: - delete_unused_value_func(node, body, chunk_inputs_names) - - if node_idx in chunk_ends: - body.append( - _gen_loop_end( - chunk_inputs[region_idx], - chunk_inputs_non_chunk[region_idx], - chunk_outputs[region_idx], - chunk_outputs_dim[region_idx], - node_list, - ) - ) - within_chunk_region = False - - node_idx += 1 - - -if CODEGEN_AVAILABLE: - - class ChunkCodeGen(CodeGen): - def __init__(self, meta_graph, max_memory=None): - super().__init__() - self.meta_graph = meta_graph - self.max_memory = max_memory - self.meta_node = list(meta_graph.graph.nodes) - # find the chunk regions - self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory) - self.chunk_infos = self.chunk_region_search.search_region() - - def _gen_python_code( - self, nodes, root_module: str, namespace: _Namespace - ) -> PythonCode: - free_vars: List[str] = [] - body: List[str] = [] - globals_: Dict[str, Any] = {} - 
wrapped_fns: Dict[str, None] = {} - - # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [""] - - def add_global(name_hint: str, obj: Any): - """Add an obj to be tracked as a global. - - We call this for names that reference objects external to the - Graph, like functions or types. - - Returns: the global name that should be used to reference 'obj' in generated source. - """ - if ( - _is_from_torch(obj) and obj != torch.device - ): # to support registering torch.device - # HACK: workaround for how torch custom ops are registered. We - # can't import them like normal modules so they must retain their - # fully qualified name. - return _get_qualified_name(obj) - - # normalize the name hint to get a proper identifier - global_name = namespace.create_name(name_hint, obj) - - if global_name in globals_: - assert globals_[global_name] is obj - return global_name - globals_[global_name] = obj - return global_name - - # set _custom_builtins here so that we needn't import colossalai in forward - _custom_builtins["colossalai"] = _CustomBuiltin( - "import colossalai", colossalai - ) - - # Pre-fill the globals table with registered builtins. - for name, (_, obj) in _custom_builtins.items(): - add_global(name, obj) - - def type_repr(o: Any): - if o == (): - # Empty tuple is used for empty tuple type annotation Tuple[()] - return "()" - - typename = _type_repr(o) - - if hasattr(o, "__origin__"): - # This is a generic type, e.g. typing.List[torch.Tensor] - origin_type = _origin_type_map.get(o.__origin__, o.__origin__) - origin_typename = add_global(_type_repr(origin_type), origin_type) - - if hasattr(o, "__args__"): - # Assign global names for each of the inner type variables. 
- args = [type_repr(arg) for arg in o.__args__] - - if len(args) == 0: - # Bare type, such as `typing.Tuple` with no subscript - # This code-path used in Python < 3.9 - return origin_typename - - return f'{origin_typename}[{",".join(args)}]' - else: - # Bare type, such as `typing.Tuple` with no subscript - # This code-path used in Python 3.9+ - return origin_typename - - # Common case: this is a regular module name like 'foo.bar.baz' - return add_global(typename, o) - - def _format_args( - args: Tuple[Argument, ...], kwargs: Dict[str, Argument] - ) -> str: - def _get_repr(arg): - # Handle NamedTuples (if it has `_fields`) via add_global. - if isinstance(arg, tuple) and hasattr(arg, "_fields"): - qualified_name = _get_qualified_name(type(arg)) - global_name = add_global(qualified_name, type(arg)) - return f"{global_name}{repr(tuple(arg))}" - return repr(arg) - - args_s = ", ".join(_get_repr(a) for a in args) - kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) - if args_s and kwargs_s: - return f"{args_s}, {kwargs_s}" - return args_s or kwargs_s - - # Run through reverse nodes and record the first instance of a use - # of a given node. This represents the *last* use of the node in the - # execution order of the program, which we will use to free unused - # values - node_to_last_use: Dict[Node, Node] = {} - user_to_last_uses: Dict[Node, List[Node]] = {} - - def register_last_uses(n: Node, user: Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - user_to_last_uses.setdefault(user, []).append(n) - - for node in reversed(nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - - _delete_free_var_from_last_use(user_to_last_uses) - - # NOTE: we add a variable to distinguish body and ckpt_func - def delete_unused_values(user: Node, body, to_keep=[]): - """ - Delete values after their last use. 
This ensures that values that are - not used in the remainder of the code are freed and the memory usage - of the code is optimal. - """ - if user.op == "placeholder": - return - if user.op == "output": - body.append("\n") - return - nodes_to_delete = user_to_last_uses.get(user, []) - nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep] - if len(nodes_to_delete): - to_delete_str = " = ".join( - [repr(n) for n in nodes_to_delete] + ["None"] - ) - body.append(f"; {to_delete_str}\n") - else: - body.append("\n") - - # NOTE: we add a variable to distinguish body and ckpt_func - def emit_node(node: Node, body): - maybe_type_annotation = ( - "" if node.type is None else f" : {type_repr(node.type)}" - ) - if node.op == "placeholder": - assert isinstance(node.target, str) - maybe_default_arg = ( - "" if not node.args else f" = {repr(node.args[0])}" - ) - free_vars.append( - f"{node.target}{maybe_type_annotation}{maybe_default_arg}" - ) - raw_name = node.target.replace("*", "") - if raw_name != repr(node): - body.append(f"{repr(node)} = {raw_name}\n") - return - elif node.op == "call_method": - assert isinstance(node.target, str) - body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" - f"({_format_args(node.args[1:], node.kwargs)})" - ) - return - elif node.op == "call_function": - assert callable(node.target) - # pretty print operators - if ( - node.target.__module__ == "_operator" - and node.target.__name__ in magic_methods - ): - assert isinstance(node.args, tuple) - body.append( - f"{repr(node)}{maybe_type_annotation} = " - f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" - ) - return - - # pretty print inplace operators; required for jit.script to work properly - # not currently supported in normal FX graphs, but generated by torchdynamo - if ( - node.target.__module__ == "_operator" - and node.target.__name__ in inplace_methods - ): - body.append( - 
f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " - f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" - ) - return - - qualified_name = _get_qualified_name(node.target) - global_name = add_global(qualified_name, node.target) - # special case for getattr: node.args could be 2-argument or 3-argument - # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if ( - global_name == "getattr" - and isinstance(node.args, tuple) - and isinstance(node.args[1], str) - and node.args[1].isidentifier() - and len(node.args) == 2 - ): - body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" - ) - return - body.append( - f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" - ) - if node.meta.get("is_wrapped", False): - wrapped_fns.setdefault(global_name) - return - elif node.op == "call_module": - assert isinstance(node.target, str) - body.append( - f"{repr(node)}{maybe_type_annotation} = " - f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" - ) - return - elif node.op == "get_attr": - assert isinstance(node.target, str) - body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" - ) - return - elif node.op == "output": - if node.type is not None: - maybe_return_annotation[0] = f" -> {type_repr(node.type)}" - body.append(self.generate_output(node.args[0])) - return - raise NotImplementedError(f"node: {node.op} {node.target}") - - # Modified for activation checkpointing - ckpt_func = [] - - # if any node has a list of labels for activation_checkpoint, we - # will use nested type of activation checkpoint codegen - emit_code_with_chunk( - body, - nodes, - emit_node, - delete_unused_values, - self.chunk_region_search, - self.chunk_infos - ) - - if len(body) == 0: - # If the Graph has no non-placeholder nodes, no lines for the 
body - # have been emitted. To continue to have valid Python code, emit a - # single pass statement - body.append("pass\n") - - if len(wrapped_fns) > 0: - wrap_name = add_global("wrap", torch.fx.wrap) - wrap_stmts = "\n".join( - [f'{wrap_name}("{name}")' for name in wrapped_fns] - ) - else: - wrap_stmts = "" - - if self._body_transformer: - body = self._body_transformer(body) - - for name, value in self.additional_globals(): - add_global(name, value) - - # as we need colossalai.utils.checkpoint, we need to import colossalai - # in forward function - prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) - prologue = "".join(ckpt_func) + prologue - prologue = prologue - - code = "".join(body) - code = "\n".join(" " + line for line in code.split("\n")) - fn_code = f""" -{wrap_stmts} - -{prologue} -{code}""" - # print(fn_code) - return PythonCode(fn_code, globals_) diff --git a/colossalai/autochunk/chunk_region_search.py b/colossalai/autochunk/chunk_region_search.py new file mode 100644 index 000000000..0d0825f25 --- /dev/null +++ b/colossalai/autochunk/chunk_region_search.py @@ -0,0 +1,211 @@ +from .index_tracer import IndexTracer +from .memory_estiamtor import MemoryEstimator +from .chunk_selector import ChunkSelector +import copy +from .utils import is_non_compute_node, is_non_compute_node_except_placeholder, get_node_shape + + +class ChunkRegionSearch(object): + def __init__(self, gm, max_memory=None) -> None: + self.gm = gm + self.index_tracer = IndexTracer(list(gm.graph.nodes)) + self.index_tracer.trace_index() + self.memory_estimator = MemoryEstimator(self.index_tracer) + self.chunk_selector = ChunkSelector( + self.index_tracer, self.memory_estimator, max_memory=max_memory + ) + + def _find_peak_node(self, mem_peak): + max_value = max(mem_peak) + max_idx = mem_peak.index(max_value) + return max_idx + + def _get_free_var(self): + free_var_idx = [] + for idx, n in enumerate(self.index_tracer.node_list): + if n.op == "placeholder": + 
free_var_idx.append(idx) + return free_var_idx + + def _get_min_free_var(self, active_node_list, free_vars): + min_len = 999 + for idx, n in enumerate(active_node_list): + if idx in free_vars: + continue + if len(n) < min_len: + min_len = len(n) + return min_len + + def _search_max_chunk_region(self, active_node, peak_node, chunk_regions): + free_vars = self._get_free_var() + free_var_num = len(free_vars) + active_node_num = [len(i) for i in active_node] + min_active_node_num = min(active_node_num[free_var_num:]) + threshold = max(free_var_num, min_active_node_num) + + # from peak_node to free_var + inside_flag = False + chunk_region_start = free_var_num + for i in range(peak_node, -1, -1): + if active_node_num[i] <= threshold: + inside_flag = True + if inside_flag and active_node_num[i] > threshold: + chunk_region_start = i + 1 + break + + # from peak_node to len-2 + inside_flag = False + chunk_region_end = len(active_node) - 1 + for i in range(peak_node, len(active_node)): + if active_node_num[i] <= threshold: + inside_flag = True + if inside_flag and active_node_num[i] > threshold: + chunk_region_end = i + break + + for i in chunk_regions: + region = i["region"] + if chunk_region_start >= region[0] and chunk_region_end <= region[1]: + return None + elif ( + region[0] <= chunk_region_start <= region[1] + and chunk_region_end > region[1] + ): + chunk_region_start = region[1] + 1 + elif ( + region[0] <= chunk_region_end <= region[1] + and chunk_region_start < region[0] + ): + chunk_region_end = region[0] - 1 + return chunk_region_start, chunk_region_end + + def _is_not_compute(self, trace, chunk_range, dim_idx): + if trace["idx"][dim_idx] not in trace["compute"]: + return True + if trace["idx"][dim_idx] in trace["compute"] and all( + i < chunk_range[0] or i > chunk_range[1] + for i in trace["compute"][trace["idx"][dim_idx]] + ): + return True + return False + + def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx): + start_traces = 
input_trace[start_idx] + end_trace = output_trace[end_idx] + end_node = self.index_tracer.node_list[end_idx] + chunk_infos = [] + for end_dim, _ in enumerate(end_trace["idx"]): + if len(start_traces) > 1: + continue + for start_node, start_trace in start_traces.items(): + for start_dim, _ in enumerate(start_trace["idx"]): + # dim size cannot be 1 + if ( + get_node_shape(end_node)[end_dim] == 1 + or get_node_shape(start_node)[start_dim] == 1 + ): + continue + # check index source align + if not self.index_tracer.check_index_source( + start_dim, start_node, start_idx, end_dim, end_node + ): + continue + # check index copmute + if not self.index_tracer.check_index_compute( + start_idx, end_dim, end_node, end_idx + ): + continue + # flow search + chunk_info = self.index_tracer.flow_search( + start_idx, start_dim, end_idx, end_dim + ) + if chunk_info is None: + continue + # check index copmute + if not self.index_tracer.check_index_duplicate(chunk_info): + continue + chunk_infos.append(chunk_info) + return chunk_infos + + def _search_possible_chunk_regions(self, max_chunk_region, peak_node): + possible_chunk_region = [] + output_trace = copy.deepcopy(self.index_tracer.idx_trace_list) + input_trace = [] # trace of a node's input nodes + for _, n in enumerate(self.index_tracer.node_list): + cur_trace = {} + for arg in n.args: + if type(arg) == type(n) and not is_non_compute_node_except_placeholder( + arg + ): + cur_trace[arg] = self.index_tracer._find_trace_from_node(arg) + input_trace.append(cur_trace) + + for start_idx in range(max_chunk_region[0], peak_node + 1): + for end_idx in range(peak_node, max_chunk_region[1] + 1): + # skip non compute nodes + if is_non_compute_node( + self.index_tracer.node_list[start_idx] + ) or is_non_compute_node(self.index_tracer.node_list[end_idx]): + continue + + # select free dim + chunk_info = self._find_free_dim( + input_trace, output_trace, start_idx, end_idx + ) + if len(chunk_info) > 0: + possible_chunk_region.extend(chunk_info) + 
return possible_chunk_region + + def _step_search(self, mem_peak, active_node, chunk_regions): + peak_node = self._find_peak_node(mem_peak) + max_chunk_region = self._search_max_chunk_region( + active_node, peak_node, chunk_regions + ) + if max_chunk_region == None: + return None + possible_chunk_regions = self._search_possible_chunk_regions( + max_chunk_region, peak_node + ) + best_chunk_region = self.chunk_selector._select_best_chunk_region( + possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak + ) + best_chunk_region = self.index_tracer.reorder_all(best_chunk_region) + return best_chunk_region + + def _stop_search(self, init_mem_peak, mem_peak): + sorted_init_mem_peak = sorted(init_mem_peak) + if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]: + return True + return False + + def search_region(self): + chunk_infos = [] + ( + init_mem_peak, + _, + active_node, + ) = self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list + ) + mem_peak = init_mem_peak + + while True: + chunk_info = self._step_search(mem_peak, active_node, chunk_infos) + if chunk_info is None: + break + chunk_infos.append(chunk_info) + + ( + mem_peak, + _, + active_node, + ) = self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list, chunk_infos + ) + if self._stop_search(init_mem_peak, mem_peak): + break + self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list, chunk_infos, print_mem=True + ) + return chunk_infos + diff --git a/colossalai/autochunk/chunk_selector.py b/colossalai/autochunk/chunk_selector.py new file mode 100644 index 000000000..f84322082 --- /dev/null +++ b/colossalai/autochunk/chunk_selector.py @@ -0,0 +1,221 @@ +from .index_tracer import IndexTracer +from .memory_estiamtor import MemoryEstimator +from .utils import is_non_compute_node + + +class ChunkSelector(object): + def __init__( + self, + index_tracer: IndexTracer, + memory_estimator: 
MemoryEstimator, + max_memory=None, + ): + self.index_tracer = index_tracer + self.memory_estimator = memory_estimator + if max_memory is not None: + self.stratge = "fit_memory" + self.max_memory = max_memory # MB + else: + self.stratge = "min_memory" + + def _select_best_chunk_region( + self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak + ): + if self.stratge == "min_memory": + best_region = self._select_min_memory_chunk_region( + possible_chunk_regions, + chunk_infos, + peak_node, + max_chunk_region, + mem_peak, + ) + elif self.stratge == "fit_memory": + best_region = self._select_fit_memory_chunk_region( + possible_chunk_regions, + chunk_infos, + peak_node, + max_chunk_region, + mem_peak, + ) + else: + raise RuntimeError() + return best_region + + def _select_fit_memory_chunk_region( + self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak + ): + # stop chunk if max memory satisfy memory limit + if max(mem_peak) < self.max_memory: + return None + + # remove illegal regions + illegal_regions = [] + for i in possible_chunk_regions: + if not self._is_legal_region(i, chunk_infos): + illegal_regions.append(i) + for i in illegal_regions: + if i in possible_chunk_regions: + possible_chunk_regions.remove(i) + + if len(possible_chunk_regions) == 0: + return None + + # get mem for chunk region + regions_dict = [] + for region in possible_chunk_regions: + cur_region = region.copy() + cur_node_list, cur_region = self.index_tracer.tmp_reorder( + self.index_tracer.node_list, cur_region + ) + cur_chunk_infos = chunk_infos + [cur_region] + cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( + cur_node_list, cur_chunk_infos + )[0] + cur_chunk_region_peak = cur_mem_peak[ + max_chunk_region[0] : max_chunk_region[1] + 1 + ] + cur_chunk_region_max_peak = max(cur_chunk_region_peak) + if cur_chunk_region_max_peak < self.max_memory: + regions_dict.append( + { + "chunk_info": region, + "chunk_max_mem": 
cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num( + region["region"][0], region["region"][1] + ), + "reorder_chunk_info": cur_region, + "reorder_node_list": cur_node_list, + } + ) + # no region found + if len(regions_dict) == 0: + raise RuntimeError("Search failed. Try a larger memory threshold.") + + # select the min chunk len + chunk_len = [i["chunk_len"] for i in regions_dict] + best_region_idx = chunk_len.index(min(chunk_len)) + best_region = regions_dict[best_region_idx] + + # get max chunk size + best_region = self._get_fit_chunk_size(best_region, chunk_infos) + return best_region + + def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos): + chunk_size = 1 + reorder_chunk_info = chunk_region_dict["reorder_chunk_info"] + reorder_chunk_info["chunk_size"] = chunk_size + cur_chunk_max_mem = 0 + # search a region + while cur_chunk_max_mem < self.max_memory: + chunk_size *= 2 + reorder_chunk_info["chunk_size"] = chunk_size + cur_chunk_infos = chunk_infos + [reorder_chunk_info] + cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( + chunk_region_dict["reorder_node_list"], cur_chunk_infos + )[0] + cur_chunk_max_mem = max( + cur_mem_peak[ + reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + + 1 + ] + ) + # search exact size + chunk_info = chunk_region_dict["chunk_info"] + chunk_info["chunk_size"] = self._chunk_size_binary_search( + chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos + ) + return chunk_info + + def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos): + if l >= 16: + gap = 4 + else: + gap = 1 + chunk_info = chunk_region_dict["reorder_chunk_info"] + while r >= l + gap: + mid = int((l + r) / 2 + 0.5) + chunk_info["chunk_size"] = mid + cur_chunk_infos = chunk_infos + [chunk_info] + cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( + chunk_region_dict["reorder_node_list"], cur_chunk_infos + )[0] + cur_chunk_max_mem = max( + 
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] + ) + if cur_chunk_max_mem >= self.max_memory: + r = mid - gap + else: + l = mid + gap + return l + + def _get_compute_node_num(self, start, end): + count = 0 + for i in self.index_tracer.node_list[start : end + 1]: + if not is_non_compute_node(i): + count += 1 + return count + + def _select_min_memory_chunk_region( + self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak + ): + # remove illegal regions + illegal_regions = [] + for i in possible_chunk_regions: + if not self._is_legal_region(i, chunk_infos): + illegal_regions.append(i) + for i in illegal_regions: + if i in possible_chunk_regions: + possible_chunk_regions.remove(i) + + if len(possible_chunk_regions) == 0: + return None + + # get mem for chunk region + regions_dict = [] + for region in possible_chunk_regions: + cur_region = region.copy() + cur_node_list, cur_region = self.index_tracer.tmp_reorder( + self.index_tracer.node_list, cur_region + ) + cur_chunk_infos = chunk_infos + [cur_region] + cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( + cur_node_list, cur_chunk_infos + )[0] + cur_chunk_region_peak = cur_mem_peak[ + max_chunk_region[0] : max_chunk_region[1] + 1 + ] + cur_chunk_region_max_peak = max(cur_chunk_region_peak) + regions_dict.append( + { + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num( + region["region"][0], region["region"][1] + ), + "reorder_chunk_info": cur_region, + "reorder_node_list": cur_node_list, + } + ) + + # select the min mem + chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict] + best_region_idx = chunk_max_mem.index(min(chunk_max_mem)) + best_region = regions_dict[best_region_idx]["chunk_info"] + if best_region is not None: + best_region["chunk_size"] = 1 + return best_region + + def _is_legal_region(self, cur_chunk_info, chunk_infos): + (chunk_region_start, chunk_region_end) = 
cur_chunk_info["region"] + if cur_chunk_info in chunk_infos: + return False + if chunk_region_end < chunk_region_start: + return False + for i in chunk_infos: + region = i["region"] + if not ( + (chunk_region_start > region[1] and chunk_region_end > region[1]) + or (chunk_region_start < region[0] and chunk_region_end < region[0]) + ): + return False + return True diff --git a/colossalai/autochunk/index_tracer.py b/colossalai/autochunk/index_tracer.py new file mode 100644 index 000000000..7a86f3c99 --- /dev/null +++ b/colossalai/autochunk/index_tracer.py @@ -0,0 +1,1056 @@ +import copy + +from .utils import ( + find_chunk_all_input_nodes, + find_chunk_compute_input_and_output_nodes, + find_idx_by_name, + get_node_shape, + is_non_compute_node, + is_non_compute_node_except_placeholder, +) + + +class IndexTracer(object): + def __init__(self, node_list) -> None: + self.node_list = node_list + self.idx_trace_list = self._init_idx_trace_list() + self.idx_trace_equal = [] + self.idx_view_list = {} + self.idx_count = -1 + self.all_reorder_map = {i: i for i in range(len(self.idx_trace_list))} + + def _init_idx_trace_list(self): + idx_trace_list = [] + for n in self.node_list: + if get_node_shape(n) != None: + cur_trace = { + "idx": [None for _ in range(len(get_node_shape(n)))], + "compute": [[] for _ in range(len(get_node_shape(n)))], + "source": [{} for _ in range(len(get_node_shape(n)))], + } + else: + cur_trace = {"idx": [], "compute": [], "source": []} + idx_trace_list.append(cur_trace) + return idx_trace_list + + def _add_index(self): + """ + Update the count and return it. To record the idx number. 
+ + Returns: + idx_count: int + """ + self.idx_count += 1 + return self.idx_count + + def _del_dim(self, idx, dim_idx): + self.idx_trace_list[idx]["idx"].pop(dim_idx) + self.idx_trace_list[idx]["compute"].pop(dim_idx) + self.idx_trace_list[idx]["source"].pop(dim_idx) + + def _add_dim(self, node_idx, dim_idx): + self.idx_trace_list[node_idx]["idx"].insert(dim_idx, self._add_index()) + self.idx_trace_list[node_idx]["compute"].insert(dim_idx, []) + self.idx_trace_list[node_idx]["source"].insert(dim_idx, {}) + + def _transform_index(self, node, node_dim): + node_idx = self._find_idx_trace_from_node(node) + dims = list(range(len(node_idx))) + return dims[node_dim] + + def _inherit_index(self, node_from, node_from_dim, node_to, node_to_dim): + node_from_dim = self._transform_index(node_from, node_from_dim) + node_to_dim = self._transform_index(node_to, node_to_dim) + node_from_trace = self._find_trace_from_node(node_from) + node_to_trace = self._find_trace_from_node(node_to) + node_to_trace["idx"][node_to_dim] = node_from_trace["idx"][node_from_dim] + node_to_trace["compute"][node_to_dim] = copy.deepcopy( + node_from_trace["compute"][node_from_dim] + ) + self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True) + + def _inherit_all_computation(self, node_from, node_to): + node_from_compute = self._find_compute_trace_from_node(node_from) + node_to_compute = self._find_compute_trace_from_node(node_to) + assert len(node_from_compute) == len(node_to_compute) + for i in range(len(node_from_compute)): + self._add_source(node_from, i, node_to, i) + node_to_compute[i] = copy.deepcopy(node_from_compute[i]) + + def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False): + node_from_dim = self._transform_index(node_from, node_from_dim) + node_from_trace_source = self._find_source_trace_from_node(node_from) + node_to_dim = self._transform_index(node_to, node_to_dim) + node_to_trace_source = self._find_source_trace_from_node(node_to) + 
node_from_idx = find_idx_by_name(node_from.name, self.node_list) + if init: + node_to_trace_source[node_to_dim] = {} + # add dim to cur new source + if node_from_idx not in node_to_trace_source[node_to_dim]: + node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim] + else: + if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]: + node_to_trace_source[node_to_dim][node_from_idx].append(node_from_dim) + # update inputs source + for node_idx, node_dim in node_from_trace_source[node_from_dim].items(): + if node_idx not in node_to_trace_source[node_to_dim]: + node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim) + else: + for d in node_dim: + if d not in node_to_trace_source[node_to_dim][node_idx]: + node_to_trace_source[node_to_dim][node_idx].append(d) + + def _mark_computation_from_node(self, node_from, node_to, exclude=None): + if exclude == None: + exclude = [] + else: + exclude = [self._transform_index(node_to, i) for i in exclude] + node_from_compute = self._find_compute_trace_from_node(node_from) + node_to_compute = self._find_compute_trace_from_node(node_to) + # assert len(node_from_compute) == len(node_to_compute) + for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1): + if self._transform_index(node_to, i) in exclude: + continue + self._add_source(node_from, i, node_to, i) + for j in node_from_compute[i]: + if j not in node_to_compute[i]: + node_to_compute[i].append(j) + + def _mark_idx_equal(self, node1, dim1, node2, dim2): + """ + Mark 2 index to be equal. + + Args: + idx1 (int): index count. + idx2 (int): index count. + """ + # node1_idx = _find_idx_by_name(node1.name, self.nodes_list) + # node2_idx = _find_idx_by_name(node2.name, self.nodes_list) + # if node1_idx > node2_idx: + # self._add_source(node2, dim2, node1, dim1) + # else: + # self._add_source(node1, dim1, node2, dim2) + + def _mark_computation(self, node, idx, dim): + """ + Mark some dims of node as computed. 
+ + Args: + node (node) + idx (int): node index + dim (list or int): dims to be marked as computed + """ + if isinstance(dim, int): + dim = [dim] + dims = list(range(len(get_node_shape(node)))) + for d in dim: + cur_dim = dims[d] + if idx not in self.idx_trace_list[idx]["compute"][cur_dim]: + self.idx_trace_list[idx]["compute"][cur_dim].append(idx) + + def _find_trace_from_node(self, node): + """ + Find node idx and compute trace by the node. + + Args: + node (node) + Returns: + idx (list): idx of the node + compute (list): computed idx of the node. + """ + node_idx = find_idx_by_name(node.name, self.node_list) + node_dict = self.idx_trace_list[node_idx] + return node_dict + + def _find_source_trace_from_node(self, node): + """ + Find node source trace by the node. + + Args: + node (node) + Returns: + idx (list): idx of the node + compute (list): computed idx of the node. + """ + node_idx = find_idx_by_name(node.name, self.node_list) + node_dict = self.idx_trace_list[node_idx] + return node_dict["source"] + + def _find_idx_trace_from_node(self, node): + """ + Find node idx trace by the node. + + Args: + node (node) + Returns: + idx (list): idx of the node + """ + node_idx = find_idx_by_name(node.name, self.node_list) + return self.idx_trace_list[node_idx]["idx"] + + def _find_compute_trace_from_node(self, node): + """ + Find node compute trace by the node. + + Args: + node (node) + Returns: + compute (list): computed idx of the node. + """ + node_idx = find_idx_by_name(node.name, self.node_list) + return self.idx_trace_list[node_idx]["compute"] + + def _assign_index_as_input(self, node, node_idx, input_node=None): + """ + Assign node's trace as its input node. 
+ + Args: + node (node) + node_idx (int) + """ + if input_node == None: + input_node = node.args[0] + input_node_idx = find_idx_by_name(input_node.name, self.node_list) + input_node_idx_trace = self.idx_trace_list[input_node_idx]["idx"] + + new_idx_trace = copy.deepcopy(input_node_idx_trace) + self.idx_trace_list[node_idx]["idx"] = new_idx_trace + + self._inherit_all_computation(input_node, node) + + def _assign_all_index(self, node, node_idx): + """ + Add new index for all node's dims. + + Args: + node (node) + node_idx (int) + """ + shape = node.meta["tensor_meta"].shape + new_trace = [] + for _ in shape: + new_trace.append(self._add_index()) + self.idx_trace_list[node_idx]["idx"] = new_trace + + def _assign_transpose_index(self, node, node_idx): + """ + Assign index for transpose op. + 1. swap input's dim according to transpose args + 2. inherit input's computation + + Args: + node (node) + node_idx (int) + """ + input_node = node.args[0] + tranpose_dim = node.args[1:] + + self._assign_index_as_input(node, node_idx, input_node) + self._inherit_index(input_node, tranpose_dim[1], node, tranpose_dim[0]) + self._inherit_index(input_node, tranpose_dim[0], node, tranpose_dim[1]) + + def _assign_permute_index(self, node, node_idx): + """ + Assign index for permute op. + 1. swap input's dim according to permute args + 2. inherit input's computation + + Args: + node (node) + node_idx (int) + """ + permute_dim = node.args[1:] + input_node = node.args[0] + + self._assign_index_as_input(node, node_idx, input_node) + for idx, d in enumerate(permute_dim): + self._inherit_index(input_node, d, node, idx) + + def _assign_linear_index(self, node, node_idx): + """ + Assign index for linear op. + 1. copy trace from input node and change last index accroding to weight + 2. mark equal for input node last index, weight first dim and bias dim. + 3. inherit input's computation, mark computation for last dim. 
+ + Args: + node (node) + node_idx (int) + """ + if len(node.args) == 2: + input_node, weight = node.args + bias = None + else: + input_node, weight, bias = node.args + + self._assign_index_as_input(node, node_idx) + self._inherit_index(weight, 1, node, -1) + + self._mark_computation(node, node_idx, [-1]) + self._mark_idx_equal(input_node, -1, weight, 0) + + if bias: + self._mark_idx_equal(input_node, -1, bias, 0) + + def _assign_matmul_index(self, node, node_idx): + """ + Assign index for matmul op. + 1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length) + 2. mark equal for input matmul_left -1 index and matmul_right -2 dim. + 3. inherit matmul_left and matmul_right computation, mark computation for last dim. + + Args: + node (node) + node_idx (int) + """ + matmul_left, matmul_right = node.args + + assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right)) + self._assign_index_as_input(node, node_idx, matmul_left) + self._inherit_index(matmul_right, -1, node, -1) + + self._mark_computation_from_node(matmul_right, node, [-1, -2]) + self._mark_computation(node, node_idx, [-1]) + self._mark_idx_equal(matmul_left, -1, matmul_right, -2) + + def _assign_layernorm_index(self, node, idx): + """ + Assign index for layernorm op. + 1. assign index as input node + 2. inherit computation and mark last 2 dims as computed. + + Args: + node (node) + node_idx (int) + """ + self._assign_index_as_input(node, idx) + self._mark_computation(node, idx, [-1]) + + def _assign_elementwise_index(self, node, idx): + """ + Assign index for element-wise op (eg. relu sigmoid add mul). + 1. assign index as input node + 2. inherit computation from all input nodes. 
+ + Args: + node (node) + node_idx (int) + """ + self._assign_index_as_input(node, idx) + nodes_in = [] + for node_in in node.args: + if type(node_in) == type(node): + nodes_in.append(node_in) + self._mark_computation_from_node(node_in, node) + assert len(nodes_in) <= 2 + if len(nodes_in) == 2: + node_in0_shape = get_node_shape(nodes_in[0]) + node_in1_shape = get_node_shape(nodes_in[1]) + for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1): + if node_in0_shape[i] == node_in1_shape[i]: + self._mark_idx_equal(nodes_in[0], i, nodes_in[1], i) + + def _assgin_no_change_index(self, node, idx): + self._assign_index_as_input(node, idx) + for node_in in node.args: + if type(node_in) == type(node): + self._mark_computation_from_node(node_in, node) + + def _assign_einsum_index(self, node, idx): + """ + Assign index for einsum op. + + Args: + node (node) + node_idx (int) + """ + patterns = node.args[0] + input_nodes = node.args[1:] + + patterns = patterns.replace(" ", "") + left, right = patterns.split("->") + left = left.split(",") + + all_index = [] + for i in left: + for c in i: + all_index.append(c) + all_index = set(all_index) + free_index = set([i for i in right]) + sum_index = all_index - free_index + + for right_idx, right_indice in enumerate(right): + for left_idx, left_str in enumerate(left): + if right_indice in left_str: + source_idx = left_str.index(right_indice) + self._inherit_index( + input_nodes[left_idx], source_idx, node, right_idx + ) + + # for i in sum_index: + # for left_idx, left_str in enumerate(left): + # if i in left_str: + # self._mark_computation(node, idx, left_str.index(i)) + # break + + def _assign_softmax_index(self, node, idx): + """ + Assign index for softmax op. + 1. assign index as input node + 2. inherit computation and mark softmax dim as computed. 
+ + Args: + node (node) + node_idx (int) + """ + self._assign_index_as_input(node, idx) + self._mark_computation(node, idx, [node.kwargs["dim"]]) + + def _assign_unsqueeze_index(self, node, node_idx): + """ + Assign index for unsqueeze op. + 1. assign new index for unsqueeze dim + + Args: + node (node) + node_idx (int) + """ + self._del_dim(node_idx, -1) + self._assign_index_as_input(node, node_idx) + self._add_dim(node_idx, node.args[1]) + + def _assign_dropout_index(self, node, node_idx): + """ + Assign index for unsqueeze op. + 1. assign new index for unsqueeze dim + + Args: + node (node) + node_idx (int) + """ + self._assign_index_as_input(node, node_idx) + + def _assign_ones_like_index(self, node, node_idx): + """ + Assign index for oneslike op. + 1. assign new index for all dim + + Args: + node (node) + node_idx (int) + """ + self._assign_all_index(node, node_idx) + + def _assign_view_reshape_index(self, node, node_idx): + """ + Assign index for view and reshape op. + 1. get origin shape and target shape by meta info. + 2. compute the real value of -1 in target shape. + 3. determine changed dim, and assgin index for generated dim. + 4. log changed dim and generated dim for restore + 5. inherit computation. + 6. TODO: look into view list to see whether the view is associated with other, + if so assgin equal dim according to previous view. 
+ + Args: + node (node) + node_idx (int) + """ + # get data, turn into number + origin_node = node.args[0] + origin_shape = origin_node.meta["tensor_meta"].shape + target_shape = [] + for i in range(1, len(node.args)): + if isinstance(node.args[i], int): + target_shape.append(node.args[i]) + else: + target_shape.append(node.args[i].meta["fwd_out"][0]) + + # compute the value of -1 + if -1 in target_shape: + origin_product = 1 + for i in origin_shape: + origin_product *= i + target_product = -1 + for i in target_shape: + target_product *= i + shape_idx = target_shape.index(-1) + target_shape[shape_idx] = origin_product // target_product + + # determine changed dim + len_diff = len(origin_shape) - len(target_shape) + if len_diff == 1: + # dim merge + dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)] + dim_to = [dim_equal.index(False)] + dim_from = [dim_equal.index(False), dim_equal.index(False) + 1] + self._add_dim(node_idx, -1) + elif len_diff == -1: + # dim expand + dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])] + dim_from = [dim_equal.index(False)] + dim_to = [dim_equal.index(False), dim_equal.index(False) + 1] + self._del_dim(node_idx, -1) + else: + raise NotImplementedError( + "shape" + + str(origin_shape) + + "and" + + str(target_shape) + + "view not implemented" + ) + + # get new index + origin_trace = self._find_idx_trace_from_node(origin_node) + self._assign_index_as_input(node, node_idx, origin_node) + dim_from.reverse() + for i in dim_from: + self._del_dim(node_idx, i) + for i in dim_to: + self._add_dim(node_idx, i) + + # inherit computation + compute_log = self._find_compute_trace_from_node(origin_node) + for i in dim_from: + if origin_trace[i] in compute_log: + for j in dim_to: + self._mark_computation(node, node_idx, [j]) + break + + # log view, not used now + view_dict = { + "idx_from": [origin_trace[i] for i in dim_from], + "dim_from": dim_from, + "idx_to": [self.idx_trace_list[node_idx]["idx"][i] for i in 
dim_to], + "dim_to": dim_to, + } + self.idx_view_list[node] = view_dict + + def _merge_equal_idx(self): + idx_equal = copy.deepcopy(self.idx_trace_equal) + idx_equal.reverse() + for idx in idx_equal: + merge_to = min(idx) + merge_from = max(idx) + for trace in self.idx_trace_list: + if merge_from in trace["idx"]: + trace["idx"] = [ + merge_to if i == merge_from else i for i in trace["idx"] + ] + + def trace_index(self): + for idx, node in enumerate(self.node_list): + if node.op == "placeholder": + self._assign_all_index(node, idx) + elif node.op == "call_method": + if "transpose" in node.name: + self._assign_transpose_index(node, idx) + elif "permute" in node.name: + self._assign_permute_index(node, idx) + elif "view" in node.name or "reshape" in node.name: + self._assign_view_reshape_index(node, idx) + elif "unsqueeze" in node.name: + self._assign_unsqueeze_index(node, idx) + elif any(i in node.name for i in ["to", "contiguous"]): + self._assgin_no_change_index(node, idx) + else: + raise NotImplementedError(node.name, "method not implemented yet!") + elif node.op == "call_function": + if "linear" in node.name: + self._assign_linear_index(node, idx) + elif "matmul" in node.name: + self._assign_matmul_index(node, idx) + elif "softmax" in node.name: + self._assign_softmax_index(node, idx) + elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]): + self._assign_elementwise_index(node, idx) + elif "ones_like" in node.name: + self._assign_ones_like_index(node, idx) + elif "dropout" in node.name: + self._assign_dropout_index(node, idx) + elif "einsum" in node.name: + self._assign_einsum_index(node, idx) + elif "getattr" in node.name: + continue # get attr like shape + elif "getitem" in node.name: + continue # get item in list + else: + raise NotImplementedError( + node.name, "function not implemented yet!" 
+ ) + elif node.op == "call_module": + if any(n in node.name for n in ["layernorm", "norm"]): + self._assign_layernorm_index(node, idx) + else: + raise NotImplementedError(node.name, "module not implemented yet!") + elif node.op == "get_attr": + self._assign_all_index(node, idx) # get param + elif node.op == "output": + continue + else: + raise NotImplementedError(node.op, "op not implemented yet!") + # self._merge_equal_idx() + + def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node): + """ + Check 2 given index: one index should be source of the other + Args: + start_idx(int): start node chunk dim + start_node(node): start node + end_idx(int): end node chunk dim + end_node(node): end node + + Returns: + bool: True if check pass + """ + start_node_idx = find_idx_by_name(start_node.name, self.node_list) + end_node_trace = self._find_trace_from_node(end_node) + end_node_trace_source = end_node_trace["source"][end_dim] + sorted_source = sorted( + end_node_trace_source.items(), key=lambda d: d[0], reverse=True + ) + for node_idx, node_dim in sorted_source: + if node_idx == start_node_idx and start_dim in node_dim: + return True + # it means we meet a node outside the loop, and the node is not input node + if node_idx < start_idx: + return False + return False + + def check_index_compute(self, start_idx, end_dim, end_node, end_idx): + """ + Check 2 given index: check they haven't been computed in the source trace. 
+ Args: + start_idx(int): start node chunk dim + start_node(node): start node + end_idx(int): end node chunk dim + end_node(node): end node + + Returns: + bool: True if check pass + """ + end_node_trace = self._find_trace_from_node(end_node) + end_node_compute = end_node_trace["compute"][end_dim] + if any(start_idx <= i <= end_idx for i in end_node_compute): + return False + return True + + def get_node_chunk_dim(self, node_from, node_from_dim, node_to): + node_from_source = self._find_source_trace_from_node(node_from) + dim_source = node_from_source[node_from_dim] + node_to_idx = find_idx_by_name(node_to.name, self.node_list) + for k, v in dim_source.items(): + if k == node_to_idx: + return v + return None + + def _find_inherit_dim(self, input_node, input_dim, node): + input_node_idx = find_idx_by_name(input_node.name, self.node_list) + node_trace_source = self._find_source_trace_from_node(node) + for node_dim in range(len(get_node_shape(node))): + if ( + input_node_idx in node_trace_source[node_dim] + and input_dim[0] in node_trace_source[node_dim][input_node_idx] + ): + return node_dim + return None + + def check_index_duplicate(self, chunk_infos, return_dim=False): + input_dim_after_node = {} + for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): + for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): + inherit_dim = self._find_inherit_dim(input_node, v, self.node_list[k]) + if inherit_dim: + input_dim_after_node[k] = inherit_dim + + for node in self.node_list[ + chunk_infos["region"][0] : chunk_infos["region"][1] + 1 + ]: + if is_non_compute_node_except_placeholder(node): + continue + count = 0 + duplicate_dims = [] + node_trace_source = self._find_source_trace_from_node(node) + for node_dim in range(len(get_node_shape(node))): + duplicate_dim = [] + duplicate_flag = False + dim_source = node_trace_source[node_dim] + for k, v in dim_source.items(): + if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]: + if k in 
input_dim_after_node and input_dim_after_node[k] in v: + duplicate_flag = True + duplicate_dim.append((k, v)) + duplicate_dims.append(duplicate_dim) + if duplicate_flag: + count += 1 + + if count > 1: + if return_dim: + return False, duplicate_dims + else: + return False + if return_dim: + return True, None + else: + return True + + def _assgin_single_node_flow( + self, + arg_node, + start_idx, + end_idx, + cur_node_dim, + cur_node_compute, + cur_node_source, + cur_node_fix_dim, + all_node_info, + next_node_list, + ): + arg_idx = find_idx_by_name(arg_node.name, self.node_list) + # arg in chunk range or be inputs + if not (start_idx <= arg_idx < end_idx): + return True + + # find arg dim + if cur_node_dim is not None: + # dim is computed + if arg_idx in cur_node_compute[cur_node_dim]: + return False + if arg_idx not in cur_node_source[cur_node_dim]: + arg_dim = None + else: + arg_dim = cur_node_source[cur_node_dim][arg_idx][0] + else: + arg_dim = None + + # get fix dim + arg_fix_dim = [] + if cur_node_dim is not None: + for i in cur_node_fix_dim: + fix_dim_source = cur_node_source[i] + if arg_idx in fix_dim_source: + arg_fix_dim.append(fix_dim_source[arg_idx][0]) + + # if already in node_info, arg dim must be same + if arg_node in all_node_info: + if all_node_info[arg_node]["chunk_dim"] != arg_dim: + return False + all_node_info[arg_node]["fix_dim"] = list( + set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) + ) + # else add it to list + else: + all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim} + + next_node_list.append(arg_node) + return True + + def flow_search(self, start_idx, start_dim, end_idx, end_dim): + inputs, outputs = find_chunk_compute_input_and_output_nodes( + self.node_list[start_idx : end_idx + 1] + ) + # only single ouput + if len(outputs) > 1: + return None + + cur_node_list = [self.node_list[end_idx]] # start from the last node + all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} + + while 
len(cur_node_list) > 0: + next_node_list = [] + + for cur_node in cur_node_list: + # get cur node info + cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] + cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] + cur_node_idx = find_idx_by_name(cur_node.name, self.node_list) + if cur_node_chunk_dim: + cur_node_compute = self._find_compute_trace_from_node(cur_node) + cur_node_source = self._find_source_trace_from_node(cur_node) + else: + cur_node_compute = cur_node_source = None + + # get all valid args + arg_list = [] + for arg in cur_node.args: + if type(arg) != type(cur_node): + continue + if is_non_compute_node(arg): + continue + arg_list.append(arg) + flow_flag = self._assgin_single_node_flow( + arg, + start_idx, + end_idx, + cur_node_chunk_dim, + cur_node_compute, + cur_node_source, + cur_node_fix_dim, + all_node_info, + next_node_list, + ) + if flow_flag == False: + return None + + if len(arg_list) == 2: + if any(i in cur_node.name for i in ["add", "mul"]): + for arg in arg_list: + if not ( + start_idx + <= find_idx_by_name(arg.name, self.node_list) + < end_idx + ): + continue + arg_chunk_dim = all_node_info[arg]["chunk_dim"] + arg_fix_dim = all_node_info[arg]["fix_dim"] + arg_shape = get_node_shape(arg) + # add all dim as fix dim except chunk dim + for i, shape in enumerate(arg_shape): + if shape != 1 and i != cur_node_chunk_dim: + if i == arg_chunk_dim: + return None + if i not in arg_fix_dim: + arg_fix_dim.append(i) + elif "einsum" in cur_node.name: + pass + elif "matmul" in cur_node.name: + pass + else: + raise NotImplementedError() + cur_node_list = next_node_list + + inputs_dim = [] + remove_inputs = [] + for input_node in inputs: + input_dict = {} + input_node_idx = find_idx_by_name(input_node.name, self.node_list) + for user in input_node.users.keys(): + if is_non_compute_node(user): + continue + user_idx = find_idx_by_name(user.name, self.node_list) + if start_idx <= user_idx <= end_idx: + chunk_dim = all_node_info[user]["chunk_dim"] + if 
chunk_dim is not None: + user_source = self._find_source_trace_from_node(user)[chunk_dim] + if input_node_idx in user_source: + input_dict[user_idx] = user_source[input_node_idx] + else: + return None + if len(input_dict) == 0: + remove_inputs.append(input_node) + else: + inputs_dim.append(input_dict) + for i in remove_inputs: + if i in inputs: + inputs.remove(i) + + chunk_info = { + "region": (start_idx, end_idx), + "inputs": inputs, + "inputs_non_chunk": [], + "inputs_dim": inputs_dim, + "outputs": outputs, + "outputs_dim": end_dim, + "node_chunk_dim": all_node_info, + "args": {}, + } + + # move useless nodes ahead of loop + # get all possible prepose nodes + maybe_prepose_nodes = [] + for node, node_info in all_node_info.items(): + if node_info["chunk_dim"] is None: + maybe_prepose_nodes.append(node) + maybe_prepose_nodes.sort( + key=lambda x: find_idx_by_name(x.name, self.node_list), + reverse=True, + ) # from last node to first node + prepose_nodes = [] + # set every node as root, search its args, if all legal, turn root and args as prepose nodes + while len(maybe_prepose_nodes) > 0: + tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]] + tmp_cur_related_prepose_nodes = [] + prepose_flag = True + + # loop cur node's all arg until out of chunk + while len(tmp_cur_prepose_nodes) > 0: + if prepose_flag == False: + break + tmp_next_prepose_nodes = [] + tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes) + for cur_prepose_node in tmp_cur_prepose_nodes: + if prepose_flag == False: + break + for cur_prepose_node_arg in cur_prepose_node.args: + if type(cur_prepose_node_arg) != type(cur_prepose_node): + continue + # out of loop + if not ( + start_idx + <= find_idx_by_name( + cur_prepose_node_arg.name, self.node_list + ) + < end_idx + ): + continue + # compute op in loop + elif cur_prepose_node_arg in all_node_info: + if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None: + tmp_next_prepose_nodes.append(cur_prepose_node_arg) + else: + prepose_flag = False + 
    def _reassgin_reshape_size(self, chunk_info):
        """Rewrite the size arguments of reshape/view nodes inside the chunk region.

        After chunking, the chunked dimension no longer has its full size, so a
        reshape/view whose size argument covers that dimension must use
        ``min(chunk_size, total - chunk_idx)`` at runtime (smaller on the last,
        possibly partial, chunk).  The mapping
        ``{node name: {size-arg name: replacement expression}}`` is stored in
        ``chunk_info["reshape_size"]`` for the code generator.
        """
        chunk_region = chunk_info["region"]
        reshape_size = {}
        # full size of the chunked dimension of the chunk output
        chunk_shape = get_node_shape(chunk_info["outputs"][0])[
            chunk_info["outputs_dim"]
        ]
        for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]:
            if any(i in node.name for i in ["reshape", "view"]):
                reshape_args = node.args[1:]
                reshape_log = self.idx_view_list[node]
                chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
                reshape_size[node.name] = {}
                for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
                    # skip dims recorded in the view log's "dim_to"
                    # (presumably dims created by the view) — TODO confirm
                    if reshape_arg_dim in reshape_log["dim_to"]:
                        continue
                    if reshape_arg_dim == chunk_dim:
                        reshape_size[node.name][reshape_arg.name] = (
                            "min(chunk_size, %d - chunk_idx)" % chunk_shape
                        )
        chunk_info["reshape_size"] = reshape_size
        return chunk_info

    def _get_reorder_map(self, chunk_info):
        """Build an {old index: new index} map that moves the prepose nodes to
        the front of the chunk region and shifts the remaining region nodes
        back by the number of prepose nodes ahead of them."""
        reorder_map = {i: i for i in range(len(self.node_list))}

        chunk_region_start = chunk_info["region"][0]
        chunk_region_end = chunk_info["region"][1]
        chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
        chunk_prepose_nodes_idx = [
            find_idx_by_name(i.name, self.node_list) for i in chunk_prepose_nodes
        ]
        # put prepose nodes ahead
        for idx, n in enumerate(chunk_prepose_nodes):
            n_idx = chunk_prepose_nodes_idx[idx]
            reorder_map[n_idx] = chunk_region_start + idx
        # put other nodes after prepose nodes
        for n in self.node_list[chunk_region_start : chunk_region_end + 1]:
            if n in chunk_prepose_nodes:
                continue
            n_idx = find_idx_by_name(n.name, self.node_list)
            # shift by the number of prepose nodes that originally sat after n
            pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
            reorder_map[n_idx] = n_idx + pos

        return reorder_map

    def _reorder_chunk_info(self, chunk_info, reorder_map):
        """Update chunk_info after reordering: shift the region start past the
        prepose nodes and remap the node indices stored in inputs_dim."""
        chunk_info["region"] = (
            chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]),
            chunk_info["region"][1],
        )
        new_inputs_dim = []
        for idx, input_dim in enumerate(chunk_info["inputs_dim"]):
            new_input_dim = {}
            for k, v in input_dim.items():
                new_input_dim[reorder_map[k]] = v
            new_inputs_dim.append(new_input_dim)
        chunk_info["inputs_dim"] = new_inputs_dim
        return chunk_info

    def _update_all_reorder_map(self, reorder_map):
        """Compose the new reorder map into the accumulated all_reorder_map."""
        for origin_idx, map_idx in self.all_reorder_map.items():
            self.all_reorder_map[origin_idx] = reorder_map[map_idx]

    def _reorder_self_node_list(self, reorder_map):
        """Permute self.node_list according to reorder_map."""
        new_node_list = [None for _ in range(len(self.node_list))]
        for old_idx, new_idx in reorder_map.items():
            new_node_list[new_idx] = self.node_list[old_idx]
        self.node_list = new_node_list

    def _reorder_idx_trace(self, reorder_map):
        """Permute idx_trace_list and remap every node index it stores
        (both the per-dim "compute" lists and the "source" dicts)."""
        # reorder list
        new_idx_trace_list = [None for _ in range(len(self.idx_trace_list))]
        for old_idx, new_idx in reorder_map.items():
            new_idx_trace_list[new_idx] = self.idx_trace_list[old_idx]
        self.idx_trace_list = new_idx_trace_list
        # update compute
        for idx_trace in self.idx_trace_list:
            compute = idx_trace["compute"]
            for dim_compute in compute:
                for idx, i in enumerate(dim_compute):
                    dim_compute[idx] = reorder_map[i]
        # update source
        for idx_trace in self.idx_trace_list:
            source = idx_trace["source"]
            for dim_idx, dim_source in enumerate(source):
                new_dim_source = {}
                for k, v in dim_source.items():
                    new_dim_source[reorder_map[k]] = v
                source[dim_idx] = new_dim_source

    def reorder_all(self, chunk_info):
        """Apply the prepose reordering everywhere: accumulated map, index
        traces, self.node_list and chunk_info itself.  No-op when there are no
        prepose nodes."""
        if chunk_info is None:
            return chunk_info
        if len(chunk_info["args"]["prepose_nodes"]) == 0:
            return chunk_info
        reorder_map = self._get_reorder_map(chunk_info)
        self._update_all_reorder_map(reorder_map)
        self._reorder_idx_trace(reorder_map)
        self._reorder_self_node_list(reorder_map)
        chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
        return chunk_info

    def reorder_node_list(self, node_list):
        """Return node_list permuted by the accumulated all_reorder_map."""
        new_node_list = [None for _ in range(len(node_list))]
        for old_idx, new_idx in self.all_reorder_map.items():
            new_node_list[new_idx] = node_list[old_idx]
        return new_node_list

    def tmp_reorder(self, node_list, chunk_info):
        """Reorder a temporary node list (and chunk_info) without touching any
        state on self; used to preview a candidate reordering."""
        if len(chunk_info["args"]["prepose_nodes"]) == 0:
            return node_list, chunk_info
        reorder_map = self._get_reorder_map(chunk_info)

        # new tmp node list
        new_node_list = [None for _ in range(len(node_list))]
        for old_idx, new_idx in reorder_map.items():
            new_node_list[new_idx] = node_list[old_idx]

        chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
        return new_node_list, chunk_info
diff --git a/colossalai/autochunk/memory_estiamtor.py b/colossalai/autochunk/memory_estiamtor.py
new file mode 100644
index 000000000..c3d8b1803
--- /dev/null
+++ b/colossalai/autochunk/memory_estiamtor.py
@@ -0,0 +1,318 @@
import copy
from typing import Any, Callable, Dict, Iterable, List, Tuple

import torch
from torch.fx.node import Node, map_arg

from colossalai.fx.profiler import activation_size, parameter_size

from .index_tracer import IndexTracer
from
.utils import (
    delete_free_var_from_last_use,
    find_idx_by_name,
    get_node_shape,
    is_non_compute_node_except_placeholder,
)


class MemoryEstimator(object):
    """Estimate activation memory of a traced FX graph node by node.

    The estimator walks a topologically ordered node list and simulates the
    activation memory inference would hold after each node, optionally taking
    chunked regions into account (nodes inside a chunk are charged only a
    chunk-sized fraction of their output).  Accumulated values are in MB
    (byte counts divided by 1024**2).
    """

    def __init__(self, index_tracer: IndexTracer) -> None:
        # NOTE(review): the index_tracer argument is currently unused.
        pass

    def _get_meta_node_size(self, x):
        """Return the byte size of a node from its recorded tensor_meta."""
        x = x.meta["tensor_meta"]
        x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
        return x

    def _get_output_node(self, n):
        """Return (output bytes, [node name] if it owns any output).

        Only tensors in ``meta["fwd_out"]`` carrying a ``uuid`` attribute are
        counted, deduplicated by uuid.
        """
        fwd_out = {
            x.uuid: x
            for x in n.meta["fwd_out"]
            if isinstance(x, torch.Tensor) and hasattr(x, "uuid")
        }
        out_size = activation_size(fwd_out)
        out_node = [n.name] if out_size > 0 else []
        return out_size, out_node

    def _get_output_node_size(self, n):
        """Byte size of a node's forward output."""
        return self._get_output_node(n)[0]

    def _add_active_node(self, n, active_list):
        """Record node names that become live after executing ``n``."""
        new_active = self._get_output_node(n)[1]
        if n.op == "placeholder":
            new_active.append(n.name)
        for i in new_active:
            if i not in active_list:
                active_list.append(i)

    def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
        """Return (bytes freed, names freed) once ``user`` has executed.

        ``user_to_last_uses`` maps a node to the nodes it is the last user of;
        names listed in ``to_keep`` are kept alive and excluded.
        """
        delete_size = 0
        delete_node = []
        if user.op not in ("output",):
            nodes_to_delete = user_to_last_uses.get(user, [])
            if to_keep is not None:
                # NOTE(review): this mutates the shared last-use list in place,
                # so kept nodes stay excluded on later calls as well.
                keep_list = []
                for n in nodes_to_delete:
                    if n.name in to_keep:
                        keep_list.append(n)
                for n in keep_list:
                    if n in nodes_to_delete:
                        nodes_to_delete.remove(n)
            if len(nodes_to_delete):
                out_node = [self._get_output_node(i) for i in nodes_to_delete]
                delete_size = sum([i[0] for i in out_node])
                for i in range(len(out_node)):
                    if out_node[i][0] > 0:
                        delete_node.append(out_node[i][1][0])
                    elif nodes_to_delete[i].op == "placeholder":
                        delete_node.append(nodes_to_delete[i].name)
        return delete_size, delete_node

    def _get_delete_node_size(self, user, user_to_last_uses, to_keep):
        """Bytes freed after ``user`` executes (see ``_get_delete_node``)."""
        return self._get_delete_node(user, user_to_last_uses, to_keep)[0]

    def _remove_deactive_node(self, user, user_to_last_uses, active_list):
        """Drop names freed by ``user`` from the active-node list."""
        delete_node = self._get_delete_node(user, user_to_last_uses)[1]
        for i in delete_node:
            if i in active_list:
                active_list.remove(i)

    def _get_chunk_inputs_size(
        self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
    ):
        """Bytes of chunk inputs freeable when the chunk region ends.

        An input is freeable when all of its users lie at or before
        ``chunk_end_idx``.
        """
        nodes_to_delete = []
        for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
            chunk_input_users = chunk_input.users.keys()
            chunk_input_users_idx = [
                find_idx_by_name(i.name, node_list) for i in chunk_input_users
            ]
            if all(i <= chunk_end_idx for i in chunk_input_users_idx):
                if chunk_input not in nodes_to_delete:
                    nodes_to_delete.append(chunk_input)
        out_node = [self._get_output_node(i) for i in nodes_to_delete]
        delete_size = sum([i[0] for i in out_node])
        return delete_size

    def _get_last_usr(self, nodes):
        """Map each node to the list of nodes for which it is the last user."""
        node_to_last_use: Dict[Node, Node] = {}
        user_to_last_uses: Dict[Node, List[Node]] = {}

        def register_last_uses(n: Node, user: Node):
            if n not in node_to_last_use:
                node_to_last_use[n] = user
                user_to_last_uses.setdefault(user, []).append(n)

        # walk backwards so the first user encountered is the last in order
        for node in reversed(nodes):
            map_arg(node.args, lambda n: register_last_uses(n, node))
            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
        return user_to_last_uses

    def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
        """Temporary bytes needed to make non-contiguous inputs contiguous.

        Nodes whose method name contains "permute" are tracked in
        ``not_contiguous_list``; feeding such a tensor into matmul/reshape is
        charged a temporary copy.  With ``delete=True`` entries consumed by a
        module call are dropped from the list.
        """
        mem = 0
        not_contiguous_ops = ["permute"]
        # NOTE(review): inherit_contiguous_ops is never used below.
        inherit_contiguous_ops = ["transpose", "view"]

        if node.op == "call_function" and any(
            n in node.name for n in ["matmul", "reshape"]
        ):
            for n in node.args:
                if n in not_contiguous_list:
                    # matmul won't change origin tensor, but create a tmp copy
                    mem += self._get_output_node_size(n)
        elif node.op == "call_module":
            for n in node.args:
                if n in not_contiguous_list:
                    # module will just make origin tensor to contiguous
                    if delete:
                        not_contiguous_list.remove(n)
        elif node.op == "call_method" and any(
            i in node.name for i in not_contiguous_ops
        ):
            if node not in not_contiguous_list:
                not_contiguous_list.append(node)
        return mem

    def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size):
        """Fraction of the node actually held per chunk step
        (chunk_size / size of its chunked dim), or 1.0 when unchunked."""
        if node not in chunk_node_dim:
            return 1.0
        node_shape = get_node_shape(node)
        chunk_dim = chunk_node_dim[node]["chunk_dim"]
        if chunk_dim is None:
            return 1.0
        else:
            return float(chunk_size) / node_shape[chunk_dim]

    def _get_chunk_delete_node_size(
        self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
    ):
        """Bytes freed after ``user`` inside a chunk, scaled by chunk_ratio.

        Chunk inputs are excluded — they must survive the whole chunk loop.
        """
        if user.op in ("placeholder", "output"):
            return 0
        nodes_to_delete = user_to_last_uses.get(user, [])
        delete_size = 0
        for n in nodes_to_delete:
            if n.name in chunk_inputs_names:
                continue
            delete_size += self._get_output_node_size(n) * chunk_ratio
        return delete_size

    def _print_mem_log(self, log, nodes, title=None):
        """Pretty-print a per-node memory log, three entries per row."""
        if title:
            print(title)
        for idx, (l, n) in enumerate(zip(log, nodes)):
            print("%s:%.2f \t" % (n.name, l), end="")
            if (idx + 1) % 3 == 0:
                print("")
        print("\n")

    def _print_compute_op_mem_log(self, log, nodes, title=None):
        """Like ``_print_mem_log`` but skips placeholder/get_attr/output and
        getitem/getattr helper nodes."""
        if title:
            print(title)
        for idx, (l, n) in enumerate(zip(log, nodes)):
            if n.op in ["placeholder", "get_attr", "output"]:
                continue
            if any(i in n.name for i in ["getitem", "getattr"]):
                continue
            print("%s:%.2f \t" % (n.name, l), end="")
            if (idx + 1) % 3 == 0:
                print("")
        print("\n")

    def estimate_chunk_inference_mem(
        self,
        node_list,
        chunk_infos=None,
        print_mem=False,
    ):
        """Simulate inference activation memory over ``node_list``.

        Args:
            node_list: topologically ordered FX nodes.
            chunk_infos: optional list of chunk-region dicts; nodes inside a
                region are charged only a chunk-sized fraction of their output.
            print_mem: print the per-compute-op peak log.

        Returns:
            Tuple of (peak log, after-node log, active-node-name log), one
            entry per simulated step, memory values in MB.
        """
        act_memory = 0.0
        act_memory_peak_log = []
        act_memory_after_node_log = []
        active_node_list = []
        active_node_list_log = []
        not_contiguous_list = []
        user_to_last_uses = self._get_last_usr(node_list)
        user_to_last_uses_no_free_var = self._get_last_usr(node_list)
        delete_free_var_from_last_use(user_to_last_uses_no_free_var)

        use_chunk = True if chunk_infos is not None else False
        chunk_within = False
        chunk_region_idx = None
        chunk_ratio = 1  # use it to estimate chunk mem
        chunk_inputs_names = []

        if use_chunk:
            chunk_regions = [i["region"] for i in chunk_infos]
            chunk_starts = [i[0] for i in chunk_regions]
            chunk_ends = [i[1] for i in chunk_regions]
            chunk_inputs = [i["inputs"] for i in chunk_infos]
            chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
            chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
                j.name for i in chunk_inputs_non_chunk for j in i
            ]
            chunk_outputs = [i["outputs"][0] for i in chunk_infos]
            chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
            chunk_sizes = [
                i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos
            ]

        for idx, node in enumerate(node_list):
            # if node in chunk start nodes, change chunk ratio and add chunk_tensor
            if use_chunk and idx in chunk_starts:
                chunk_within = True
                chunk_region_idx = chunk_starts.index(idx)
                act_memory += self._get_output_node_size(
                    chunk_outputs[chunk_region_idx]
                ) / (1024**2)

            # determine chunk ratio for current node
            if chunk_within:
                chunk_ratio = self._get_chunk_ratio(
                    node,
                    chunk_node_dim[chunk_region_idx],
                    chunk_sizes[chunk_region_idx],
                )

            # if node is placeholder, just add the size of the node
            if node.op == "placeholder":
                act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2)
                act_memory_peak_log.append(act_memory)
            # skip output (note: this also skips the after-node logs below)
            elif node.op == "output":
                continue
            # no change for non compute node
            elif is_non_compute_node_except_placeholder(node):
                act_memory_peak_log.append(act_memory)
            # node is a compute op
            # calculate tmp, output node and delete node memory
            else:
                # forward memory
                # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
                act_memory += (
                    self._get_contiguous_memory(node, not_contiguous_list)
                    * chunk_ratio
                    / (1024**2)
                )
                act_memory += (
                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
                )
                # record max act memory
                act_memory_peak_log.append(act_memory)
                # delete useless memory
                act_memory -= (
                    self._get_contiguous_memory(node, not_contiguous_list, delete=True)
                    * chunk_ratio
                    / (1024**2)
                )
                # delete unused vars not in chunk_input_list
                # we can't delete input nodes until chunk ends
                if chunk_within:
                    act_memory -= self._get_chunk_delete_node_size(
                        node,
                        user_to_last_uses_no_free_var,
                        chunk_ratio,
                        chunk_inputs_names,
                    ) / (1024**2)
                else:
                    act_memory -= self._get_delete_node_size(
                        node, user_to_last_uses_no_free_var, chunk_inputs_names
                    ) / (1024**2)

            # log active node, only effective without chunk
            self._add_active_node(node, active_node_list)
            self._remove_deactive_node(node, user_to_last_uses, active_node_list)

            # if node in chunk end nodes, restore chunk settings
            if use_chunk and idx in chunk_ends:
                act_memory -= (
                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
                )
                act_memory -= self._get_chunk_inputs_size(
                    chunk_inputs[chunk_region_idx],
                    chunk_inputs_non_chunk[chunk_region_idx],
                    node_list,
                    chunk_regions[chunk_region_idx][1],
                ) / (1024**2)
                chunk_within = False
                chunk_ratio = 1
                chunk_region_idx = None

            act_memory_after_node_log.append(act_memory)
            active_node_list_log.append(copy.deepcopy(active_node_list))

        if print_mem:
            print("with chunk" if use_chunk else "without chunk")
            self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")

        return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py
new file mode 100644
index 000000000..b62a6600a
--- /dev/null
+++ 
from typing import Any, Callable, Dict, Iterable, List, Tuple

from torch.fx.node import Node


def is_non_compute_node(node):
    """Return True if ``node`` does no real computation.

    Covers placeholder/get_attr/output ops and getitem/getattr helper nodes
    (matched by substring on the node name, following tracer naming).
    """
    if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
        i in node.name for i in ["getitem", "getattr"]
    ):
        return True
    return False


def get_node_shape(node):
    """Return the shape recorded in the node's tensor_meta, or None when the
    meta entry has no ``shape`` attribute (non-tensor outputs)."""
    if hasattr(node.meta["tensor_meta"], "shape"):
        return node.meta["tensor_meta"].shape
    return None


def is_non_compute_node_except_placeholder(node):
    """Like ``is_non_compute_node`` but treats placeholders as compute nodes
    (graph inputs still occupy memory and must be tracked)."""
    if any(i in node.op for i in ["get_attr", "output"]) or any(
        i in node.name for i in ["getitem", "getattr"]
    ):
        return True
    return False


def is_non_compute_node_except_placeholder_output(node):
    """Like ``is_non_compute_node`` but treats both placeholders and output
    nodes as compute nodes."""
    if any(i in node.op for i in ["get_attr"]) or any(
        i in node.name for i in ["getitem", "getattr"]
    ):
        return True
    return False


def find_idx_by_name(name, nodes_list):
    """Return the index of the node named ``name`` in ``nodes_list``.

    Raises:
        RuntimeError: if no node with that name exists.
    """
    for idx, node in enumerate(nodes_list):
        if node.name == name:
            return idx
    raise RuntimeError("name %s not found in node list" % name)


def delete_free_var_from_last_use(user_to_last_uses):
    """Remove placeholder (free-variable) nodes from every last-use list.

    Graph inputs must never be freed, so they are stripped from the
    candidate-for-deletion lists.  Lists are updated in place (via slice
    assignment) so callers holding references to them see the change.
    """
    for value in user_to_last_uses.values():
        # BUGFIX: the previous implementation removed elements while iterating
        # the same list, which skips the element after each removal (e.g. two
        # consecutive placeholders left the second one behind).
        value[:] = [n for n in value if n.op != "placeholder"]


def find_chunk_all_input_nodes(nodes: List[Node]):
    """Return every node outside ``nodes`` that feeds a node inside it.

    Unlike ``find_chunk_compute_input_and_output_nodes`` this includes
    non-compute producers as well.
    """
    input_nodes = []
    for node in nodes:
        for input_node in node._input_nodes.keys():
            if input_node not in nodes and input_node not in input_nodes:
                input_nodes.append(input_node)
    return input_nodes


def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
    """Find the compute input and output nodes of a chunk region.

    input nodes are nodes outside the list whose results are used inside it
    output nodes are nodes inside the list whose results are used outside it
    Non-compute producers/consumers (get_attr, getitem, ...) are ignored.
    """
    input_nodes = []
    output_nodes = []

    # if a node has an input node which is not in the node list
    # we treat that input node as the input of the checkpoint function
    for node in nodes:
        for input_node in node._input_nodes.keys():
            if (
                input_node not in nodes
                and input_node not in input_nodes
                and not is_non_compute_node_except_placeholder(input_node)
            ):
                input_nodes.append(input_node)

    # if a node has a user node which is not in the node list
    # we treat that user node as the node receiving the current node output
    for node in nodes:
        for output_node in node.users.keys():
            if (
                output_node not in nodes
                and node not in output_nodes
                and not is_non_compute_node_except_placeholder_output(output_node)
            ):
                output_nodes.append(node)

    return input_nodes, output_nodes
MetaTensor(pair, fake_device="cuda:0") ) + # set code_gen - codegen = ChunkCodeGen(gm_prop, max_memory) + codegen = AutoChunkCodeGen(gm_prop, max_memory) graph.set_codegen(codegen) gm = ColoGraphModule(model, graph) gm.recompile() + # print - code = graph.python_code("self").src - print(code) + # code = graph.python_code("self").src + # print(code) return gm diff --git a/tests/test_autochunk/test_autochunk.py b/tests/test_autochunk/test_autochunk.py index caa2d9a80..85a162084 100644 --- a/tests/test_autochunk/test_autochunk.py +++ b/tests/test_autochunk/test_autochunk.py @@ -4,7 +4,7 @@ import torch.fx import torch.multiprocessing as mp import colossalai -from colossalai.autochunk.chunk_codegen import ChunkCodeGen +from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule @@ -82,7 +82,7 @@ def _run_offload_codegen(rank): MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0") ) - codegen = ChunkCodeGen(gm_prop) + codegen = AutoChunkCodeGen(gm_prop) graph.set_codegen(codegen) gm = ColoGraphModule(model, graph) gm.recompile()