diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py
index ceccb9a9f..de5e7356b 100644
--- a/colossalai/autochunk/autochunk_codegen.py
+++ b/colossalai/autochunk/autochunk_codegen.py
@@ -123,12 +123,13 @@ def _replace_name(context: str, name_from: str, name_to: str) -> str:
     """
     replace node name
     """
-    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
+    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"), (" ", ""), ("", " ")]
     for p in patterns:
         source = p[0] + name_from + p[1]
         target = p[0] + name_to + p[1]
         if source in context:
             context = context.replace(source, target)
+            break
     return context
@@ -138,8 +139,7 @@ def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict)
     """
     """
     if node_name not in reshape_size_dict:
         return context
-    for size_name, size_value in reshape_size_dict[node_name].items():
-        context = context.replace(size_name, size_value)
+    context = context.replace(reshape_size_dict[node_name][0], reshape_size_dict[node_name][1])
     return context
diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py
index e001423f1..d38625385 100644
--- a/colossalai/autochunk/estimate_memory.py
+++ b/colossalai/autochunk/estimate_memory.py
@@ -37,10 +37,10 @@ class EstimateMemory(object):

     def _add_active_node(self, n, active_list):
         new_active = self._get_output_node(n)[1]
-        if n.op == "placeholder":
+        if n.op == "placeholder" and get_node_shape(n) is not None:
             new_active.append(n.name)
         for i in new_active:
-            if i not in active_list:
+            if i not in active_list and get_node_shape(n) is not None:
                 active_list.append(i)

     def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
@@ -77,15 +77,11 @@ class EstimateMemory(object):
             if i in active_list:
                 active_list.remove(i)

-    def _get_chunk_inputs_size(
-        self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
-    ):
+    def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
         nodes_to_delete = []
         for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
             chunk_input_users = chunk_input.users.keys()
-            chunk_input_users_idx = [
-                find_idx_by_name(i.name, node_list) for i in chunk_input_users
-            ]
+            chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
             if all(i <= chunk_end_idx for i in chunk_input_users_idx):
                 if chunk_input not in nodes_to_delete:
                     nodes_to_delete.append(chunk_input)
@@ -112,9 +108,7 @@ class EstimateMemory(object):
         not_contiguous_ops = ["permute"]
         inherit_contiguous_ops = ["transpose", "view"]

-        if node.op == "call_function" and any(
-            n in node.name for n in ["matmul", "reshape"]
-        ):
+        if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]):
             for n in node.args:
                 if n in not_contiguous_list:
                     # matmul won't change origin tensor, but create a tmp copy
@@ -125,9 +119,7 @@ class EstimateMemory(object):
                     # module will just make origin tensor to contiguous
             if delete:
                 not_contiguous_list.remove(n)
-        elif node.op == "call_method" and any(
-            i in node.name for i in not_contiguous_ops
-        ):
+        elif node.op == "call_method" and any(i in node.name for i in not_contiguous_ops):
             if node not in not_contiguous_list:
                 not_contiguous_list.append(node)
         return mem
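A standalone sketch of the delimiter-pair idea behind `_replace_name` (illustrative only, not part of the patch): a node name is only rewritten when it is wrapped by one of the listed delimiter pairs, so substrings of longer identifiers are left alone, and the new `break` stops after the first pattern that matches.

def replace_name_sketch(context: str, name_from: str, name_to: str) -> str:
    # same pattern table as the patched _replace_name above
    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"), (" ", ""), ("", " ")]
    for left, right in patterns:
        source = left + name_from + right
        if source in context:
            # rewrite only delimiter-wrapped occurrences, then stop
            context = context.replace(source, left + name_to + right)
            break
    return context

assert replace_name_sketch("out = add(x, y)", "x", "x_chunk") == "out = add(x_chunk, y)"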
@@ -142,9 +134,7 @@ class EstimateMemory(object):
         else:
             return float(chunk_size) / node_shape[chunk_dim]

-    def _get_chunk_delete_node_size(
-        self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
-    ):
+    def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names):
         # if any(j in user.name for j in ['transpose', 'permute', 'view']):
         #     return 0
         if user.op in ("placeholder", "output"):
@@ -196,7 +186,7 @@ class EstimateMemory(object):
         Returns:
             act_memory_peak_log (List): peak memory of every node
             act_memory_after_node_log (List): memory after executing every node
-            active_node_list_log (List): active nodes of every node. active nodes refer to 
+            active_node_list_log (List): active nodes of every node. active nodes refer to
                 nodes generated but not deleted.
         """
         act_memory = 0.0
@@ -212,7 +202,7 @@ class EstimateMemory(object):
         use_chunk = True if chunk_infos is not None else False
         chunk_within = False
         chunk_region_idx = None
-        chunk_ratio = 1  # use it to estimate chunk mem
+        chunk_ratio = 1    # use it to estimate chunk mem
         chunk_inputs_names = []

         if use_chunk:
@@ -221,23 +211,18 @@ class EstimateMemory(object):
             chunk_ends = [i[1] for i in chunk_regions]
             chunk_inputs = [i["inputs"] for i in chunk_infos]
             chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
-            chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
-                j.name for i in chunk_inputs_non_chunk for j in i
-            ]
+            chunk_inputs_names = [j.name for i in chunk_inputs for j in i
+                                  ] + [j.name for i in chunk_inputs_non_chunk for j in i]
             chunk_outputs = [i["outputs"][0] for i in chunk_infos]
             chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
-            chunk_sizes = [
-                i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos
-            ]
+            chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]

         for idx, node in enumerate(node_list):
             # if node in chunk start nodes, change chunk ratio and add chunk_tensor
             if use_chunk and idx in chunk_starts:
                 chunk_within = True
                 chunk_region_idx = chunk_starts.index(idx)
-                act_memory += self._get_output_node_size(
-                    chunk_outputs[chunk_region_idx]
-                ) / (1024**2)
+                act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)

             # determine chunk ratio for current node
             if chunk_within:
@@ -262,22 +247,13 @@ class EstimateMemory(object):
             else:
                 # forward memory
                 # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
-                act_memory += (
-                    self._get_contiguous_memory(node, not_contiguous_list)
-                    * chunk_ratio
-                    / (1024**2)
-                )
-                act_memory += (
-                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
-                )
+                act_memory += (self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2))
+                act_memory += (self._get_output_node_size(node) * chunk_ratio / (1024**2))
                 # record max act memory
                 act_memory_peak_log.append(act_memory)
             # delete useless memory
-            act_memory -= (
-                self._get_contiguous_memory(node, not_contiguous_list, delete=True)
-                * chunk_ratio
-                / (1024**2)
-            )
+            act_memory -= (self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio /
+                           (1024**2))
             # delete unused vars not in chunk_input_list
             # we can't delete input nodes until chunk ends
             if chunk_within:
@@ -288,9 +264,8 @@ class EstimateMemory(object):
                     chunk_inputs_names,
                 ) / (1024**2)
             else:
-                act_memory -= self._get_delete_node_size(
-                    node, user_to_last_uses_no_free_var, chunk_inputs_names
-                ) / (1024**2)
+                act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var,
+                                                         chunk_inputs_names) / (1024**2)

             # log active node, only effective without chunk
             self._add_active_node(node, active_node_list)
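The chunk-ratio bookkeeping above scales a node's memory by the fraction of it that is alive inside a chunked loop. A quick arithmetic check on illustrative numbers (the helper mirrors the else-branch of the ratio computation shown above, it is not the class method):

def chunk_ratio(chunk_size, node_shape, chunk_dim):
    # fraction of the tensor alive per chunk iteration
    return float(chunk_size) / node_shape[chunk_dim]

full_mb = 512 * 1024 * 4 / (1024**2)      # a (512, 1024) fp32 activation: 2.0 MB
ratio = chunk_ratio(128, (512, 1024), 0)  # chunking dim 0 in slices of 128 -> 0.25
assert full_mb * ratio == 0.5             # only 0.5 MB is counted while chunked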
@@ -298,9 +273,7 @@
             # if node in chunk end nodes, restore chunk settings
             if use_chunk and idx in chunk_ends:
-                act_memory -= (
-                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
-                )
+                act_memory -= (self._get_output_node_size(node) * chunk_ratio / (1024**2))
                 act_memory -= self._get_chunk_inputs_size(
                     chunk_inputs[chunk_region_idx],
                     chunk_inputs_non_chunk[chunk_region_idx],
diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py
index c9e5e5172..236f9697d 100644
--- a/colossalai/autochunk/search_chunk.py
+++ b/colossalai/autochunk/search_chunk.py
@@ -8,11 +8,7 @@ from .reorder_graph import ReorderGraph
 from .select_chunk import SelectChunk
 from .trace_flow import TraceFlow
 from .trace_indice import TraceIndice
-from .utils import (
-    get_node_shape,
-    is_non_compute_node,
-    is_non_compute_node_except_placeholder,
-)
+from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder


 class SearchChunk(object):
@@ -73,13 +69,11 @@ class SearchChunk(object):
         """
         free_var_idx = []
         for idx, n in enumerate(self.trace_indice.node_list):
-            if n.op == "placeholder":
+            if n.op == "placeholder" and get_node_shape(n) is not None:
                 free_var_idx.append(idx)
         return free_var_idx

-    def _search_max_chunk_region(
-        self, active_node: List, peak_node: Node, chunk_regions: List
-    ) -> Tuple:
+    def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
         """
         Search max chunk region according to peak memory node

@@ -124,15 +118,9 @@ class SearchChunk(object):
             region = i["region"]
             if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
                 return None
-            elif (
-                region[0] <= chunk_region_start <= region[1]
-                and chunk_region_end > region[1]
-            ):
+            elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
                 chunk_region_start = region[1] + 1
-            elif (
-                region[0] <= chunk_region_end <= region[1]
-                and chunk_region_start < region[0]
-            ):
+            elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
                 chunk_region_end = region[0] - 1
         return chunk_region_start, chunk_region_end

@@ -164,25 +152,16 @@ class SearchChunk(object):
         for start_node, start_trace in start_traces.items():
             for start_dim, _ in enumerate(start_trace["indice"]):
                 # dim size cannot be 1
-                if (
-                    get_node_shape(end_node)[end_dim] == 1
-                    or get_node_shape(start_node)[start_dim] == 1
-                ):
+                if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
                     continue
                 # check index source align
-                if not self.trace_flow.check_index_source(
-                    start_dim, start_node, start_idx, end_dim, end_node
-                ):
+                if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
                     continue
                 # check index compute
-                if not self.trace_flow.check_index_compute(
-                    start_idx, end_dim, end_node, end_idx
-                ):
+                if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
                     continue
                 # flow search
-                chunk_info = self.trace_flow.flow_search(
-                    start_idx, start_dim, end_idx, end_dim
-                )
+                chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
                 if chunk_info is None:
                     continue
                 # check index compute
@@ -191,9 +170,7 @@ class SearchChunk(object):
                 chunk_infos.append(chunk_info)
         return chunk_infos

-    def _search_possible_chunk_regions(
-        self, max_chunk_region: Tuple, peak_node: Node
-    ) -> List:
+    def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
         """
         Search every possible region within the max chunk region.
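The two `elif` arms above clip a candidate region against every already-chunked region instead of letting them intersect. A compact sketch of the same rule (hypothetical helper, not the class method):

def clip_region(start, end, existing):
    for r0, r1 in existing:
        if start >= r0 and end <= r1:
            return None          # fully covered by an existing region: give up
        elif r0 <= start <= r1 and end > r1:
            start = r1 + 1       # overlaps on the left: shift the start past it
        elif r0 <= end <= r1 and start < r0:
            end = r0 - 1         # overlaps on the right: pull the end before it
    return start, end

assert clip_region(5, 20, [(0, 9)]) == (10, 20)
assert clip_region(5, 8, [(0, 9)]) is None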
@@ -206,28 +183,23 @@ class SearchChunk(object):
         """
         possible_chunk_region = []
         output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
-        input_trace = []  # trace of a node's input nodes
+        input_trace = []    # trace of a node's input nodes
         for _, n in enumerate(self.trace_indice.node_list):
             cur_trace = {}
             for arg in n.args:
-                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
-                    arg
-                ):
+                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
                     cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
             input_trace.append(cur_trace)

         for start_idx in range(max_chunk_region[0], peak_node + 1):
             for end_idx in range(peak_node, max_chunk_region[1] + 1):
                 # skip non compute nodes
-                if is_non_compute_node(
-                    self.trace_indice.node_list[start_idx]
-                ) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
+                if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
+                        self.trace_indice.node_list[end_idx]):
                     continue

                 # select free dim
-                chunk_info = self._find_chunk_info(
-                    input_trace, output_trace, start_idx, end_idx
-                )
+                chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
                 if len(chunk_info) > 0:
                     possible_chunk_region.extend(chunk_info)
         return possible_chunk_region
@@ -256,17 +228,12 @@ class SearchChunk(object):
             best_chunk_region (Dict)
         """
         peak_node = self._find_peak_node(mem_peak)
-        max_chunk_region = self._search_max_chunk_region(
-            active_node, peak_node, chunk_infos
-        )
+        max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
         if max_chunk_region == None:
             return None
-        possible_chunk_regions = self._search_possible_chunk_regions(
-            max_chunk_region, peak_node
-        )
-        best_chunk_region = self.select_chunk._select_best_chunk_region(
-            possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
-        )
+        possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
+        best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node,
+                                                                        max_chunk_region, mem_peak)
         best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
         return best_chunk_region
@@ -291,9 +258,7 @@ class SearchChunk(object):
             init_mem_peak,
             _,
             active_node,
-        ) = self.estimate_memory.estimate_chunk_inference_mem(
-            self.trace_indice.node_list
-        )
+        ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
         mem_peak = init_mem_peak

         while True:
@@ -306,14 +271,10 @@ class SearchChunk(object):
                 mem_peak,
                 _,
                 active_node,
-            ) = self.estimate_memory.estimate_chunk_inference_mem(
-                self.trace_indice.node_list, chunk_infos
-            )
+            ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
             if self._stop_search(init_mem_peak, mem_peak):
                 break
         if self.print_mem:
             self.print_mem = False
-            self.estimate_memory.estimate_chunk_inference_mem(
-                self.trace_indice.node_list, chunk_infos, print_mem=True
-            )
+            self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
         return chunk_infos
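`_search_possible_chunk_regions` above enumerates every (start, end) pair that brackets the peak-memory node inside the max chunk region. A toy version of that enumeration (hypothetical helper, same loop bounds as the method):

def candidate_regions(max_region, peak):
    lo, hi = max_region
    # every candidate starts at or before the peak node and ends at or after it
    return [(s, e) for s in range(lo, peak + 1) for e in range(peak, hi + 1)]

assert candidate_regions((2, 6), 4) == [(s, e) for s in (2, 3, 4) for e in (4, 5, 6)]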
diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py
index ec1e012be..04fa2b3bb 100644
--- a/colossalai/autochunk/trace_flow.py
+++ b/colossalai/autochunk/trace_flow.py
@@ -1,8 +1,13 @@
+from typing import Dict, List, Tuple
+
+from torch.fx.node import Node
+
 from .trace_indice import TraceIndice
 from .utils import (
     find_chunk_all_input_nodes,
     find_chunk_compute_input_and_output_nodes,
     find_idx_by_name,
+    flat_list,
     get_node_shape,
     is_non_compute_node,
     is_non_compute_node_except_placeholder,
@@ -171,7 +176,7 @@ class TraceFlow(object):
         # get cur node info
         cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
         cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
-        if cur_node_chunk_dim:
+        if cur_node_chunk_dim is not None:
             cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
             cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
         else:
@@ -223,15 +228,32 @@ class TraceFlow(object):
             cur_node_list = next_node_list
         return all_node_info

-    def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
+    def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, all_node_info: Dict) -> Tuple:
+        """
+        Get the chunk dim of every input node for each of its users, and remove unchunked nodes
+
+        Args:
+            inputs (List[Node]): input nodes
+            all_node_info (Dict): describes every node's chunk dim and fix dim
+            start_idx (int): chunk start idx
+            end_idx (int): chunk end idx
+
+        Returns:
+            inputs (List[Node]): new inputs
+            inputs_dim (List): chunk dim for inputs
+        """
         inputs_dim = []
         remove_inputs = []
         for input_node in inputs:
             input_dict = {}
             input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
             for user in input_node.users.keys():
+                # skip non compute
                 if is_non_compute_node(user):
                     continue
+                # untraced node, mostly non compute
+                if user not in all_node_info:
+                    continue
                 user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
                 if start_idx <= user_idx <= end_idx:
                     chunk_dim = all_node_info[user]["chunk_dim"]
@@ -245,12 +267,24 @@ class TraceFlow(object):
                 remove_inputs.append(input_node)
             else:
                 inputs_dim.append(input_dict)
+        # remove unchunked inputs
         for i in remove_inputs:
             if i in inputs:
                 inputs.remove(i)
         return inputs, inputs_dim

-    def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
+    def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
+        """
+        Get all useless nodes in the chunk region and prepose them
+
+        Args:
+            all_node_info (Dict): describes every node's chunk dim and fix dim
+            start_idx (int): chunk start idx
+            end_idx (int): chunk end idx
+
+        Returns:
+            List[Node]: all nodes to be preposed
+        """
         # get all possible prepose nodes
         maybe_prepose_nodes = []
         for node, node_info in all_node_info.items():
@@ -276,7 +310,7 @@ class TraceFlow(object):
             for cur_prepose_node in tmp_cur_prepose_nodes:
                 if prepose_flag == False:
                     break
-                for cur_prepose_node_arg in cur_prepose_node.args:
+                for cur_prepose_node_arg in cur_prepose_node.all_input_nodes:
                     if type(cur_prepose_node_arg) != type(cur_prepose_node):
                         continue
                     # out of loop
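The `is not None` fix above matters because a chunk dim of 0 is falsy. A minimal illustration (string keys stand in for the real fx Node keys):

all_node_info = {"matmul_1": {"chunk_dim": 0, "fix_dim": [2]}}
chunk_dim = all_node_info["matmul_1"]["chunk_dim"]
assert not chunk_dim            # 0 is falsy: the old `if cur_node_chunk_dim:` skipped it
assert chunk_dim is not None    # the new test treats dim 0 as chunked, as intended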
@@ -360,19 +394,28 @@ class TraceFlow(object):
         return chunk_info

     def _reassgin_reshape_size(self, chunk_info):
+        """
+        Some shape args in reshape may have changed due to chunk;
+        reassign those changed shapes
+        """
         chunk_region = chunk_info["region"]
         reshape_size = {}
         chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
         for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
             if any(i in node.name for i in ["reshape", "view"]):
-                reshape_args = node.args[1:]
-                reshape_log = self.trace_indice.indice_view_list[node]
+                reshape_args = flat_list(node.args[1:])
                 chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
-                reshape_size[node.name] = {}
+                new_shape = ""
                 for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
-                    if reshape_arg_dim in reshape_log["dim_to"]:
-                        continue
                     if reshape_arg_dim == chunk_dim:
-                        reshape_size[node.name][reshape_arg.name] = ("min(chunk_size, %d - chunk_idx)" % chunk_shape)
+                        new_shape += "min(chunk_size, %d - chunk_idx), " % chunk_shape
+                    else:
+                        if isinstance(reshape_arg, int):
+                            new_shape += "%s, " % str(reshape_arg)
+                        else:
+                            new_shape += "%s, " % reshape_arg.name
+                new_shape = new_shape[:-2]
+                origin_shape = str(reshape_args)[1:-1]
+                reshape_size[node.name] = [origin_shape, new_shape]
         chunk_info["reshape_size"] = reshape_size
         return chunk_info
diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py
index 5a5d15e0a..862cd6b99 100644
--- a/colossalai/autochunk/trace_indice.py
+++ b/colossalai/autochunk/trace_indice.py
@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple

 from torch.fx.node import Node

-from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list
+from .utils import find_first_tensor_arg, find_idx_by_name, flat_list, get_node_shape


 class TraceIndice(object):
@@ -28,7 +28,7 @@ class TraceIndice(object):
         node_list (List)
     """

-    def __init__(self, node_list: List) -> None:
+    def __init__(self, node_list: List[Node]) -> None:
         self.node_list = node_list
         self.indice_trace_list = self._init_indice_trace_list()
         self.indice_view_list = {}
@@ -198,7 +198,7 @@ class TraceIndice(object):
         node_idx = find_idx_by_name(node.name, self.node_list)
         return self.indice_trace_list[node_idx]["compute"]

-    def _assign_indice_as_input(self, node, node_idx, input_node=None):
+    def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None):
         """
         Assign node's trace as its input node.

@@ -216,7 +216,7 @@ class TraceIndice(object):

         self._inherit_all_computation(input_node, node)

-    def _assign_all_indice(self, node, node_idx):
+    def _assign_all_indice(self, node: Node, node_idx: int):
         """
         Add new indice for all node's dims.

@@ -232,7 +232,7 @@ class TraceIndice(object):
             new_trace.append(self._add_indice())
         self.indice_trace_list[node_idx]["indice"] = new_trace

-    def _assign_transpose_indice(self, node, node_idx):
+    def _assign_transpose_indice(self, node: Node, node_idx: int):
         """
         Assign indice for transpose op.
         1. swap input's dim according to transpose args
@@ -249,7 +249,7 @@ class TraceIndice(object):
         self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
         self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])

-    def _assign_permute_indice(self, node, node_idx):
+    def _assign_permute_indice(self, node: Node, node_idx: int):
         """
         Assign indice for permute op.
         1. swap input's dim according to permute args
@@ -259,14 +259,14 @@ class TraceIndice(object):
             node (node)
             node_idx (int)
         """
-        permute_dim = unflat_list(node.args[1:])
+        permute_dim = flat_list(node.args[1:])
         input_node = node.args[0]

         self._assign_indice_as_input(node, node_idx, input_node)
         for idx, d in enumerate(permute_dim):
             self._inherit_indice(input_node, d, node, idx)

-    def _assign_linear_indice(self, node, node_idx):
+    def _assign_linear_indice(self, node: Node, node_idx: int):
         """
         Assign indice for linear op.
         1. copy trace from input node and change last indice according to weight
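What the rewritten `_reassgin_reshape_size` records, on illustrative numbers (assuming the chunk output has size 128 along the chunked dim): it now stores one [origin_shape, new_shape] string pair per reshape, which `_replace_reshape_size` in autochunk_codegen.py substitutes with a single `str.replace`.

reshape_args = [2, 128, 64]     # e.g. x.reshape(2, 128, 64), chunked along dim 1
chunk_dim, chunk_shape = 1, 128
new_shape = ""
for dim, arg in enumerate(reshape_args):
    new_shape += ("min(chunk_size, %d - chunk_idx), " % chunk_shape) if dim == chunk_dim else "%s, " % arg
origin_shape = str(reshape_args)[1:-1]      # "[2, 128, 64]" with brackets stripped
assert origin_shape == "2, 128, 64"
assert new_shape[:-2] == "2, min(chunk_size, 128 - chunk_idx), 64"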
@@ -287,7 +287,7 @@ class TraceIndice(object):

         self._mark_computation(node, node_idx, [-1])

-    def _assign_matmul_indice(self, node, node_idx):
+    def _assign_matmul_indice(self, node: Node, node_idx: int):
         """
         Assign indice for matmul op.
         1. copy trace from matmul_left and change last indice according to matmul_right.
            (assert they have same length)
@@ -393,7 +393,7 @@ class TraceIndice(object):
             self._assign_indice_as_input(node, idx)
             self._mark_computation(node, idx, [node.kwargs["dim"]])

-    def _assign_unsqueeze_indice(self, node, node_idx):
+    def _assign_unsqueeze_indice(self, node: Node, node_idx: int):
         """
         Assign indice for unsqueeze op.
         1. assign new indice for unsqueeze dim
@@ -404,9 +404,13 @@ class TraceIndice(object):
         """
         self._del_dim(node_idx, -1)
         self._assign_indice_as_input(node, node_idx)
-        self._add_dim(node_idx, node.args[1])
+        dim_idx = node.args[1]
+        # a negative dim counts from the end of the output shape, e.g. unsqueeze(-1) -> len(shape) - 1
+        if dim_idx < 0:
+            dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
+        self._add_dim(node_idx, dim_idx)

-    def _assign_dropout_indice(self, node, node_idx):
+    def _assign_dropout_indice(self, node: Node, node_idx: int):
         """
         Assign indice for dropout op.
         1. assign indice as input node
@@ -417,7 +421,7 @@ class TraceIndice(object):
         """
         self._assign_indice_as_input(node, node_idx)

-    def _assign_ones_like_indice(self, node, node_idx):
+    def _assign_ones_like_indice(self, node: Node, node_idx: int):
         """
         Assign indice for oneslike op.
         1. assign new indice for all dim
@@ -428,7 +432,47 @@ class TraceIndice(object):
         """
         self._assign_all_indice(node, node_idx)

-    def _assign_view_reshape_indice(self, node, node_idx):
+    def _assign_getitem_indice(self, node: Node, node_idx: int):
+        """
+        Assign indice for getitem.
+        getitem can act like slice sometimes
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        node_args = flat_list(node.args[1:])
+        if not any(i == str(node_arg) for i in ["None", "Ellipsis"] for node_arg in node_args):
+            return
+
+        # node args should be like [Ellipsis, slice(start, step, end), None]
+        node_shape = get_node_shape(node)
+        origin_idx_count = 0
+        new_idx_count = 0
+        new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
+        for _ in range(new_dim_num):
+            self._del_dim(node_idx, 0)
+        self._assign_indice_as_input(node, node_idx)
+
+        for _, node_arg in enumerate(node_args):
+            node_arg_str = str(node_arg)
+            # Ellipsis means [..., ]
+            if "Ellipsis" == node_arg_str:
+                shape_gap = len(node_shape) - len(node_args) + 1
+                origin_idx_count += shape_gap
+                new_idx_count += shape_gap
+            # slice(None, None, None) means all indexes; other slices are not supported
+            elif "slice(None, None, None)" == node_arg_str:
+                origin_idx_count += 1
+                new_idx_count += 1
+            # None means a new dim
+            elif "None" == node_arg_str:
+                self._add_dim(node_idx, new_idx_count)
+                new_idx_count += 1
+            else:
+                raise NotImplementedError()
+
+    def _assign_view_reshape_indice(self, node: Node, node_idx: int):
         """
         Assign indice for view and reshape op.
         1. get origin shape and target shape by meta info.
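To see the argument layout the new `_assign_getitem_indice` decodes, trace a slicing getitem with fx (runnable sketch, independent of this patch):

import torch
import torch.fx

def f(x):
    return x[..., None]    # recorded as getitem(x, (Ellipsis, None))

gm = torch.fx.symbolic_trace(f)
getitem_node = [n for n in gm.graph.nodes if n.op == "call_function"][0]
print(getitem_node.args[1:])    # ((Ellipsis, None),) -> flat_list -> [Ellipsis, None]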
@@ -447,7 +491,7 @@ class TraceIndice(object):
         origin_node = node.args[0]
         origin_shape = origin_node.meta["tensor_meta"].shape
         target_shape = []
-        unflated_args = unflat_list(node.args)
+        unflated_args = flat_list(node.args)
         for i in range(1, len(unflated_args)):
             if isinstance(unflated_args[i], int):
                 target_shape.append(unflated_args[i])
@@ -544,6 +588,8 @@ class TraceIndice(object):
                     self._assign_einsum_indice(node, idx)
                 elif "layer_norm" in node.name:
                     self._assign_layernorm_indice(node, idx)
+                elif "getitem" in node.name:
+                    self._assign_getitem_indice(node, idx)
                 elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
                     continue
                 else:
diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py
index 5f3ea3bf4..9c2363b54 100644
--- a/colossalai/autochunk/utils.py
+++ b/colossalai/autochunk/utils.py
@@ -3,14 +3,14 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple

 from torch.fx.node import Node


-def unflat_list(inputs):
+def flat_list(inputs):
     """
-    unflat a list by recursion
+    flatten a list by recursion
     """
     res = []
     for i in inputs:
         if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
-            res.extend(unflat_list(i))
+            res.extend(flat_list(i))
         else:
             res.append(i)
     return res
@@ -27,8 +27,13 @@ def find_first_tensor_arg(node):


 def is_non_compute_node(node):
-    if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
-        i in node.name for i in ["getitem", "getattr"]):
+    if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
         return True
+    if "getitem" in node.name:
+        node_args = flat_list(node.args[1:])
+        for node_arg in node_args:
+            if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
+                return False
+        return True
     return False
@@ -40,15 +45,15 @@ def get_node_shape(node):


 def is_non_compute_node_except_placeholder(node):
-    if any(i in node.op for i in ["get_attr", "output"]) or any(i in node.name for i in ["getitem", "getattr"]):
-        return True
-    return False
+    if "placeholder" in node.op:
+        return False
+    return is_non_compute_node(node)


 def is_non_compute_node_except_placeholder_output(node):
-    if any(i in node.op for i in ["get_attr"]) or any(i in node.name for i in ["getitem", "getattr"]):
-        return True
-    return False
+    if "output" in node.op:
+        return False
+    return is_non_compute_node_except_placeholder(node)


 def find_idx_by_name(name, nodes_list):
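A quick check of the utils changes above (the import assumes this patched module): `flat_list` flattens nested args, and a getitem whose args contain None/Ellipsis now counts as a compute node, because `_assign_getitem_indice` must trace it, while a plain tuple-unpacking getitem stays non-compute.

from colossalai.autochunk.utils import flat_list

assert flat_list([1, [2, (3, 4)], [5]]) == [1, 2, 3, 4, 5]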
diff --git a/tests/test_autochunk/test_evoformer_codegen.py b/tests/test_autochunk/test_evoformer_codegen.py
index 1273bf2fe..c5a893eda 100644
--- a/tests/test_autochunk/test_evoformer_codegen.py
+++ b/tests/test_autochunk/test_evoformer_codegen.py
@@ -27,18 +27,17 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():

 def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
     # for memory test
+    # model = model.cuda()
     # torch.cuda.reset_peak_memory_stats()
     # now_mem = torch.cuda.memory_allocated() / 1024**2
     # with torch.no_grad():
     #     node1 = node.clone()
     #     pair1 = pair.clone()
-    #     gm(node1, pair1)
-    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
+    #     node_mask1 = node_mask.clone()
+    #     pair_mask1 = pair_mask.clone()
+    #     gm(node1, pair1, node_mask1, pair_mask1)
     # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    # print(
-    #     "autochunk now mem:%.2f max mem:%.2f"
-    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
-    # )
+    # print("autochunk max mem:%.2f" % (new_max_mem - now_mem))

     # test forward
     model = model.cuda()
@@ -113,7 +112,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
         MetaTensor(node_mask, fake_device="cuda:0"),
         MetaTensor(pair_mask, fake_device="cuda:0"),
     )
-    # codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
+    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)

     # trace and recompile
     # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
@@ -130,14 +129,14 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
             "_mask_trans": True,
         },
     )
-    # graph.set_codegen(codegen)
+    graph.set_codegen(codegen)
     gm = ColoGraphModule(model, graph)
     gm.recompile()

     # assert we have inserted chunk
     code = graph.python_code("self").src
-    assert "chunk_size" in code
     # print(code)
+    assert "chunk_result = None; chunk_size = None;" in code

     _test_fwd(model, gm, node, pair, node_mask, pair_mask)
     gpc.destroy()
@@ -147,7 +146,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
     not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
     reason="torch version is lower than 1.12.0",
 )
-@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
+@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
 @pytest.mark.parametrize("msa_len", [32])
 @pytest.mark.parametrize("pair_len", [64])
 def test_evoformer_codegen(msa_len, pair_len, max_memory):
@@ -161,4 +160,4 @@ def test_evoformer_codegen(msa_len, pair_len, max_memory):


 if __name__ == "__main__":
-    _test_evoformer_codegen(0, 32, 64, 25)
+    _test_evoformer_codegen(0, 32, 64, 24)
diff --git a/tests/test_autochunk/test_simple_evoformer_codegen.py b/tests/test_autochunk/test_simple_evoformer_codegen.py
index f1272330f..8ab77024c 100644
--- a/tests/test_autochunk/test_simple_evoformer_codegen.py
+++ b/tests/test_autochunk/test_simple_evoformer_codegen.py
@@ -13,7 +13,7 @@ except:

 import colossalai
 from colossalai.core import global_context as gpc
-from colossalai.fx import ColoTracer
+from colossalai.fx import ColoTracer, symbolic_trace
 from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.graph_module import ColoGraphModule
@@ -26,21 +26,6 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():


 def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
-    # for memory test
-    # torch.cuda.reset_peak_memory_stats()
-    # now_mem = torch.cuda.memory_allocated() / 1024**2
-    # with torch.no_grad():
-    #     node1 = node.clone()
-    #     pair1 = pair.clone()
-    #     gm(node1, pair1)
-    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
-    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    # print(
-    #     "autochunk now mem:%.2f max mem:%.2f"
-    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
-    # )
-
-    # test forward
     with torch.no_grad():
         non_fx_out = model(node, pair)
         fx_out = gm(node, pair)
@@ -69,6 +54,16 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
     node = torch.randn(1, msa_len, pair_len, 256).cuda()
     pair = torch.randn(1, pair_len, pair_len, 128).cuda()

+    # meta info prop
+    meta_graph = symbolic_trace(model,
+                                meta_args={
+                                    "node": node.to(torch.device("meta")),
+                                    "pair": pair.to(torch.device("meta")),
+                                })    # must use symbolic_trace
+    interp = MetaInfoProp(meta_graph)
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
+    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
+
     # trace the module and replace codegen
     graph = ColoTracer().trace(
         model,
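The reordered test above pins down the pipeline order this PR settles on; in outline (same calls as the test, sketched as comments):

# meta_graph = symbolic_trace(model, meta_args={...})         # 1. trace with meta args
# MetaInfoProp(meta_graph).propagate(MetaTensor(...), ...)    # 2. propagate shape info
# codegen = AutoChunkCodeGen(meta_graph, max_memory=...)      # 3. build codegen from the meta graph
# graph = ColoTracer().trace(model, meta_args={...})          # 4. trace again for codegen
# graph.set_codegen(codegen); gm = ColoGraphModule(model, graph); gm.recompile()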
@@ -77,24 +72,14 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
             "pair": pair.to(torch.device("meta")),
         },
     )
-    gm_prop = torch.fx.symbolic_trace(model)  # must use symbolic_trace
-    interp = MetaInfoProp(gm_prop)
-    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
-
-    # now run it twice to get meta info in graph module, not necessary
-    gm = torch.fx.GraphModule(model, graph)
-    interp = MetaInfoProp(gm)
-    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
-
-    codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
     graph.set_codegen(codegen)
     gm = ColoGraphModule(model, graph)
     gm.recompile()

     # assert we have inserted chunk
     code = graph.python_code("self").src
-    assert "chunk_size" in code
     # print(code)
+    assert "chunk_result = None; chunk_size = None;" in code

     _test_fwd(model, gm, node, pair)
     gpc.destroy()
diff --git a/tests/test_autochunk/test_simple_evoformer_search.py b/tests/test_autochunk/test_simple_evoformer_search.py
index 04fb514fb..4c591c483 100644
--- a/tests/test_autochunk/test_simple_evoformer_search.py
+++ b/tests/test_autochunk/test_simple_evoformer_search.py
@@ -47,18 +47,18 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
         str(target_regions),
     )
     for region in target_regions:
-        assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (
+        assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%s" % (
             str(region),
             msa_len,
             pair_len,
-            max_memory,
+            str(max_memory),
         )
     for region in found_regions:
         assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%s" % (
             str(region),
             msa_len,
             pair_len,
-            max_memory,
+            str(max_memory),
         )
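Why the assert messages above move to %s/str(): max_memory is parametrized with None, and the %-formatting of the message itself would raise before the assertion could report anything. A minimal demonstration:

try:
    "maxmem:%d" % None    # %d cannot format None
except TypeError:
    pass
assert "maxmem:%s" % str(None) == "maxmem:None"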