From 63199c668792cf86f24e8583363e8625154ed9d5 Mon Sep 17 00:00:00 2001 From: oahzxl <43881818+oahzxl@users.noreply.github.com> Date: Tue, 31 Jan 2023 16:00:06 +0800 Subject: [PATCH] [autochunk] support transformer (#2526) --- colossalai/autochunk/autochunk_codegen.py | 7 +- colossalai/autochunk/search_chunk.py | 66 +--- colossalai/autochunk/select_chunk.py | 124 +++--- colossalai/autochunk/trace_flow.py | 112 +++--- colossalai/autochunk/trace_indice.py | 373 ++++++++++++------ colossalai/autochunk/utils.py | 50 ++- .../benchmark_simple_evoformer.py | 94 ----- .../test_alphafold/test_alphafold_utils.py | 122 ++++++ .../test_alphafold/test_evoformer_block.py | 95 +++++ .../test_alphafold/test_evoformer_stack.py | 90 +++++ .../test_alphafold/test_extramsa_block.py | 96 +++++ .../test_diffuser/test_diffuser_utils.py | 120 ++++++ .../test_autochunk/test_diffuser/test_unet.py | 70 ++++ .../test_autochunk/test_evoformer_codegen.py | 163 -------- .../test_evoformer_stack_codegen.py | 163 -------- tests/test_autochunk/test_extramsa_codegen.py | 164 -------- .../test_simple_evoformer_codegen.py | 104 ----- .../test_simple_evoformer_search.py | 97 ----- .../test_transformer/test_autochunk_gpt.py | 65 +++ .../test_transformer_utils.py | 123 ++++++ 20 files changed, 1214 insertions(+), 1084 deletions(-) delete mode 100644 tests/test_autochunk/benchmark_simple_evoformer.py create mode 100644 tests/test_autochunk/test_alphafold/test_alphafold_utils.py create mode 100644 tests/test_autochunk/test_alphafold/test_evoformer_block.py create mode 100644 tests/test_autochunk/test_alphafold/test_evoformer_stack.py create mode 100644 tests/test_autochunk/test_alphafold/test_extramsa_block.py create mode 100644 tests/test_autochunk/test_diffuser/test_diffuser_utils.py create mode 100644 tests/test_autochunk/test_diffuser/test_unet.py delete mode 100644 tests/test_autochunk/test_evoformer_codegen.py delete mode 100644 tests/test_autochunk/test_evoformer_stack_codegen.py delete mode 100644 tests/test_autochunk/test_extramsa_codegen.py delete mode 100644 tests/test_autochunk/test_simple_evoformer_codegen.py delete mode 100644 tests/test_autochunk/test_simple_evoformer_search.py create mode 100644 tests/test_autochunk/test_transformer/test_autochunk_gpt.py create mode 100644 tests/test_autochunk/test_transformer/test_transformer_utils.py diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 8c3155a60..ddf64dc8f 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -3,9 +3,12 @@ from typing import Any, Dict, Iterable, List, Tuple import torch import colossalai +from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE -if CODEGEN_AVAILABLE: +AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta() + +if AUTOCHUNK_AVAILABLE: from torch.fx.graph import ( CodeGen, PythonCode, @@ -272,7 +275,7 @@ def emit_code_with_chunk( node_idx += 1 -if CODEGEN_AVAILABLE: +if AUTOCHUNK_AVAILABLE: class AutoChunkCodeGen(CodeGen): diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index a86196712..720f3d925 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -8,7 +8,13 @@ from .reorder_graph import ReorderGraph from .select_chunk import SelectChunk from .trace_flow import TraceFlow from .trace_indice import TraceIndice -from .utils import get_logger, 
get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder +from .utils import ( + find_chunk_compute_input_and_output_nodes, + get_logger, + get_node_shape, + is_non_compute_node, + is_non_compute_node_except_placeholder, +) class SearchChunk(object): @@ -114,6 +120,12 @@ class SearchChunk(object): chunk_region_start (int) chunk_region_end (int) """ + # check if peak node already in chunkinfo + if chunk_regions is not None: + for i in chunk_regions: + if i["region"][0] < peak_node_idx <= i["region"][1]: + return None + free_vars = self._get_free_var_idx() free_var_num = len(free_vars) active_node_num = [len(i) for i in active_node] @@ -152,55 +164,6 @@ class SearchChunk(object): chunk_region_end = region[0] - 1 return chunk_region_start, chunk_region_end - def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List: - """ - Find chunk info for a region. - - We are given the region start and region end, and need to find out all chunk info for it. - We first loop every dim of start node and end node, to see if we can find dim pair, - which is linked in a flow and not computed. - If found, we then search flow in the whole region to find out all chunk infos. - - Args: - input_trace (List): node's input trace in region - output_trace (List): node's output trace in region - start_idx (int): region start node index - end_idx (int): region end node index - - Returns: - chunk_infos: possible regions found - """ - start_traces = input_trace[start_idx] - end_trace = output_trace[end_idx] - end_node = self.trace_indice.node_list[end_idx] - chunk_infos = [] - for end_dim, _ in enumerate(end_trace["indice"]): - if len(start_traces) > 1: - continue - for start_node, start_trace in start_traces.items(): - for start_dim, _ in enumerate(start_trace["indice"]): - # dim size cannot be 1 - if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1): - continue - # must have users - if len(end_node.users) == 0: - continue - # check index source align - if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node): - continue - # check index copmute - if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx): - continue - # flow search - chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim) - if chunk_info is None: - continue - # check index copmute - if not self.trace_flow.check_index_duplicate(chunk_info): - continue - chunk_infos.append(chunk_info) - return chunk_infos - def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List: """ Search every possible region within the max chunk region. 
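As a reviewer's aid for the guard added in the search_chunk.py hunk above: the new check simply skips a peak node whose index already lies inside a previously committed chunk region. A minimal, runnable sketch of that containment test, assuming each entry carries a "region" = (start_idx, end_idx) pair as elsewhere in this patch (the helper name below is illustrative, not part of the diff):

    # Sketch only: mirrors the early-exit containment check added above.
    def peak_already_chunked(peak_node_idx, chunk_regions):
        if chunk_regions is None:
            return False
        for info in chunk_regions:
            start, end = info["region"]
            # same convention as the patch: start < peak <= end means "already covered"
            if start < peak_node_idx <= end:
                return True
        return False

    assert peak_already_chunked(12, [{"region": (8, 20)}])       # peak inside an existing region -> skip
    assert not peak_already_chunked(7, [{"region": (8, 20)}])    # peak before the region -> keep searching
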
@@ -228,9 +191,8 @@ class SearchChunk(object): if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node( self.trace_indice.node_list[end_idx]): continue - # select free dim - chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx) + chunk_info = self.trace_flow.find_chunk_info(input_trace, output_trace, start_idx, end_idx) if len(chunk_info) > 0: possible_chunk_region.extend(chunk_info) return possible_chunk_region diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py index f0612e45a..1f3a95727 100644 --- a/colossalai/autochunk/select_chunk.py +++ b/colossalai/autochunk/select_chunk.py @@ -5,6 +5,7 @@ from .utils import is_non_compute_node class SelectChunk(object): + def __init__( self, trace_indice: TraceIndice, @@ -17,13 +18,11 @@ class SelectChunk(object): self.reorder_graph = reorder_graph if max_memory is not None: self.stratge = "fit_memory" - self.max_memory = max_memory # MB + self.max_memory = max_memory # MB else: self.stratge = "min_memory" - def _select_best_chunk_region( - self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak - ): + def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak): if self.stratge == "min_memory": best_region = self._select_min_memory_chunk_region( possible_chunk_regions, @@ -44,9 +43,8 @@ class SelectChunk(object): raise RuntimeError() return best_region - def _select_fit_memory_chunk_region( - self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak - ): + def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, + mem_peak): # stop chunk if max memory satisfy memory limit if max(mem_peak) < self.max_memory: return None @@ -63,33 +61,26 @@ class SelectChunk(object): if len(possible_chunk_regions) == 0: return None + max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]), + max([i["region"][1] for i in possible_chunk_regions])) + # get mem for chunk region regions_dict = [] for region in possible_chunk_regions: cur_region = region.copy() - cur_node_list, cur_region = self.reorder_graph.tmp_reorder( - self.trace_indice.node_list, cur_region - ) + cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region) cur_chunk_infos = chunk_infos + [cur_region] - cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( - cur_node_list, cur_chunk_infos - )[0] - cur_chunk_region_peak = cur_mem_peak[ - max_chunk_region[0] : max_chunk_region[1] + 1 - ] + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] + cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1] cur_chunk_region_max_peak = max(cur_chunk_region_peak) if cur_chunk_region_max_peak < self.max_memory: - regions_dict.append( - { - "chunk_info": region, - "chunk_max_mem": cur_chunk_region_max_peak, - "chunk_len": self._get_compute_node_num( - region["region"][0], region["region"][1] - ), - "reorder_chunk_info": cur_region, - "reorder_node_list": cur_node_list, - } - ) + regions_dict.append({ + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), + "reorder_chunk_info": cur_region, + "reorder_node_list": cur_node_list, + }) # no region found if len(regions_dict) == 0: raise RuntimeError("Search failed. 
Try a larger memory threshold.") @@ -113,20 +104,13 @@ class SelectChunk(object): chunk_size *= 2 reorder_chunk_info["chunk_size"] = chunk_size cur_chunk_infos = chunk_infos + [reorder_chunk_info] - cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( - chunk_region_dict["reorder_node_list"], cur_chunk_infos - )[0] - cur_chunk_max_mem = max( - cur_mem_peak[ - reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] - + 1 - ] - ) + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"], + cur_chunk_infos)[0] + cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1]) # search exact size chunk_info = chunk_region_dict["chunk_info"] - chunk_info["chunk_size"] = self._chunk_size_binary_search( - chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos - ) + chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict, + chunk_infos) return chunk_info def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos): @@ -139,12 +123,9 @@ class SelectChunk(object): mid = int((left + right) / 2 + 0.5) chunk_info["chunk_size"] = mid cur_chunk_infos = chunk_infos + [chunk_info] - cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( - chunk_region_dict["reorder_node_list"], cur_chunk_infos - )[0] - cur_chunk_max_mem = max( - cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] - ) + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"], + cur_chunk_infos)[0] + cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1]) if cur_chunk_max_mem >= self.max_memory: right = mid - gap else: @@ -153,14 +134,13 @@ class SelectChunk(object): def _get_compute_node_num(self, start, end): count = 0 - for i in self.trace_indice.node_list[start : end + 1]: + for i in self.trace_indice.node_list[start:end + 1]: if not is_non_compute_node(i): count += 1 return count - def _select_min_memory_chunk_region( - self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak - ): + def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, + mem_peak): # remove illegal regions illegal_regions = [] for i in possible_chunk_regions: @@ -173,37 +153,31 @@ class SelectChunk(object): if len(possible_chunk_regions) == 0: return None + # get max possible chunk region + max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]), + max([i["region"][1] for i in possible_chunk_regions])) + # get mem for chunk region - regions_dict = [] + regions_dict_list = [] for region in possible_chunk_regions: cur_region = region.copy() - cur_node_list, cur_region = self.reorder_graph.tmp_reorder( - self.trace_indice.node_list, cur_region - ) + cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region) cur_chunk_infos = chunk_infos + [cur_region] - cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( - cur_node_list, cur_chunk_infos - )[0] - cur_chunk_region_peak = cur_mem_peak[ - max_chunk_region[0] : max_chunk_region[1] + 1 - ] + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] + cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1] cur_chunk_region_max_peak = max(cur_chunk_region_peak) - regions_dict.append( - { - 
"chunk_info": region, - "chunk_max_mem": cur_chunk_region_max_peak, - "chunk_len": self._get_compute_node_num( - region["region"][0], region["region"][1] - ), - "reorder_chunk_info": cur_region, - "reorder_node_list": cur_node_list, - } - ) + regions_dict_list.append({ + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), + "reorder_chunk_info": cur_region, + "reorder_node_list": cur_node_list, + }) # select the min mem - chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict] + chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list] best_region_idx = chunk_max_mem.index(min(chunk_max_mem)) - best_region = regions_dict[best_region_idx]["chunk_info"] + best_region = regions_dict_list[best_region_idx]["chunk_info"] if best_region is not None: best_region["chunk_size"] = 1 return best_region @@ -216,9 +190,7 @@ class SelectChunk(object): return False for i in chunk_infos: region = i["region"] - if not ( - (chunk_region_start > region[1] and chunk_region_end > region[1]) - or (chunk_region_start < region[0] and chunk_region_end < region[0]) - ): + if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or + (chunk_region_start < region[0] and chunk_region_end < region[0])): return False return True diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py index 830b4629e..df7343764 100644 --- a/colossalai/autochunk/trace_flow.py +++ b/colossalai/autochunk/trace_flow.py @@ -8,9 +8,9 @@ from .utils import ( find_chunk_compute_input_and_output_nodes, find_idx_by_name, flat_list, + get_node_name, get_node_shape, is_non_compute_node, - is_non_compute_node_except_placeholder, ) @@ -79,43 +79,6 @@ class TraceFlow(object): return node_dim return None - def check_index_duplicate(self, chunk_infos, return_dim=False): - input_dim_after_node = {} - for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): - for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): - inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k]) - if inherit_dim: - input_dim_after_node[k] = inherit_dim - - for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]: - if is_non_compute_node_except_placeholder(node): - continue - count = 0 - duplicate_dims = [] - node_trace_source = self.trace_indice._find_source_trace_from_node(node) - for node_dim in range(len(get_node_shape(node))): - duplicate_dim = [] - duplicate_flag = False - dim_source = node_trace_source[node_dim] - for k, v in dim_source.items(): - if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]: - if k in input_dim_after_node and input_dim_after_node[k] in v: - duplicate_flag = True - duplicate_dim.append((k, v)) - duplicate_dims.append(duplicate_dim) - if duplicate_flag: - count += 1 - - if count > 1: - if return_dim: - return False, duplicate_dims - else: - return False - if return_dim: - return True, None - else: - return True - def _assgin_single_node_flow( self, arg_node: Node, @@ -225,9 +188,12 @@ class TraceFlow(object): if flow_flag == False: return None - if len(arg_list) == 2: - if any(i in cur_node.name for i in ["add", "mul", "truediv"]): + if len(arg_list) >= 2: + # need to mark fix dim + if any(i == get_node_name(cur_node) for i in ["add", "mul", "truediv", "sub", "where"]): for arg in arg_list: + if get_node_shape(arg) is None: + continue if not (start_idx <= find_idx_by_name(arg.name, 
self.trace_indice.node_list) < end_idx): continue arg_chunk_dim = all_node_info[arg]["chunk_dim"] @@ -240,9 +206,8 @@ class TraceFlow(object): return None if i not in arg_fix_dim: arg_fix_dim.append(i) - elif "einsum" in cur_node.name: - pass - elif "matmul" in cur_node.name: + elif any(i == get_node_name(cur_node) + for i in ["einsum", "matmul", "view", "to", "getitem", "tensor", "type"]): pass else: raise NotImplementedError() @@ -426,7 +391,7 @@ class TraceFlow(object): reshape_size = {} chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]] for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]: - if any(i in node.name for i in ["reshape", "view"]): + if any(i == get_node_name(node) for i in ["reshape", "view"]): reshape_args = flat_list(node.args[1:]) chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] new_shape = "" @@ -443,3 +408,62 @@ class TraceFlow(object): reshape_size[node.name] = [origin_shape, new_shape] chunk_info["reshape_size"] = reshape_size return chunk_info + + def find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List: + """ + Find chunk info for a region. + + We are given the region start and region end, and need to find out all chunk info for it. + We first loop every dim of start node and end node, to see if we can find dim pair, + which is linked in a flow and not computed. + If found, we then search flow in the whole region to find out all chunk infos. + + Args: + input_trace (List): node's input trace in region + output_trace (List): node's output trace in region + start_idx (int): region start node index + end_idx (int): region end node index + + Returns: + chunk_infos: possible regions found + """ + start_traces = input_trace[start_idx] + if len(start_traces) > 1: # TODO need to be removed + return [] + end_trace = output_trace[end_idx] + end_node = self.trace_indice.node_list[end_idx] + + chunk_infos = [] + for end_dim, _ in enumerate(end_trace["indice"]): + for start_node, start_trace in start_traces.items(): + for start_dim, _ in enumerate(start_trace["indice"]): + if not self._check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, end_idx): + continue + # flow search + chunk_info = self.flow_search(start_idx, start_dim, end_idx, end_dim) + if chunk_info is None: + continue + chunk_infos.append(chunk_info) + return chunk_infos + + def _check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, + end_idx: int) -> bool: + """ + check if region start and end is legal + """ + # dim cannot be None + if (get_node_shape(end_node) is None or get_node_shape(start_node) is None): + return False + # dim size cannot be 1 + if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1): + return False + # must have users + if len(end_node.users) == 0: + return False + # check index source align + if not self.check_index_source(start_dim, start_node, start_idx, end_dim, end_node): + return False + # check index copmute + if not self.check_index_compute(start_idx, end_dim, end_node, end_idx): + return False + return True diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index 827f60d8b..8f517cf2c 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -3,7 +3,14 @@ from typing import Dict, List, Tuple from torch.fx.node import Node -from .utils import find_first_tensor_arg, find_idx_by_name, flat_list, get_node_shape 
+from .utils import ( + find_first_tensor_arg, + find_idx_by_name, + flat_list, + get_module_node_name, + get_node_name, + get_node_shape, +) class TraceIndice(object): @@ -36,7 +43,7 @@ class TraceIndice(object): self.trace_range = [] self.active_node_list = [] - def _init_indice_trace_list(self): + def _init_indice_trace_list(self) -> List: indice_trace_list = [] for n in self.node_list: if get_node_shape(n) != None: @@ -54,7 +61,7 @@ class TraceIndice(object): self.trace_range = trace_range self.active_node_list = active_node_list - def _add_indice(self): + def _add_indice(self) -> int: """ Update the count and return it. To record the idx number. @@ -64,39 +71,30 @@ class TraceIndice(object): self.indice_count += 1 return self.indice_count - def _del_dim(self, idx, dim_idx): + def _del_dim(self, idx: int, dim_idx: int) -> None: + """ + delete a dim for indice, compute and source + """ self.indice_trace_list[idx]["indice"].pop(dim_idx) self.indice_trace_list[idx]["compute"].pop(dim_idx) self.indice_trace_list[idx]["source"].pop(dim_idx) - def _add_dim(self, node_idx, dim_idx): + def _add_dim(self, node_idx: int, dim_idx: int) -> None: + """ + add a dim for indice, compute and source + """ self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice()) self.indice_trace_list[node_idx]["compute"].insert(dim_idx, []) self.indice_trace_list[node_idx]["source"].insert(dim_idx, {}) - def _transform_indice(self, node, node_dim): - node_idx = self._find_indice_trace_from_node(node) - dims = list(range(len(node_idx))) - return dims[node_dim] - - def _inherit_indice(self, node_from, node_from_dim, node_to, node_to_dim): - node_from_dim = self._transform_indice(node_from, node_from_dim) - node_to_dim = self._transform_indice(node_to, node_to_dim) - node_from_trace = self._find_trace_from_node(node_from) - node_to_trace = self._find_trace_from_node(node_to) - node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim] - node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim]) - self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True) - - def _inherit_all_computation(self, node_from, node_to): - node_from_compute = self._find_compute_trace_from_node(node_from) - node_to_compute = self._find_compute_trace_from_node(node_to) - assert len(node_from_compute) == len(node_to_compute) - for i in range(len(node_from_compute)): - self._add_source(node_from, i, node_to, i) - node_to_compute[i] = copy.deepcopy(node_from_compute[i]) - - def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False): + def _add_source( + self, + node_from: Node, + node_from_dim: int, + node_to: Node, + node_to_dim: int, + init=False, + ) -> None: node_from_dim = self._transform_indice(node_from, node_from_dim) node_from_trace_source = self._find_source_trace_from_node(node_from) node_to_dim = self._transform_indice(node_to, node_to_dim) @@ -119,7 +117,50 @@ class TraceIndice(object): if d not in node_to_trace_source[node_to_dim][node_idx]: node_to_trace_source[node_to_dim][node_idx].append(d) - def _mark_computation_from_node(self, node_from, node_to, exclude=None): + def _transform_indice(self, node: Node, node_dim: int) -> int: + node_idx = self._find_indice_trace_from_node(node) + dims = list(range(len(node_idx))) + return dims[node_dim] + + def _inherit_indice( + self, + node_from: Node, + node_from_dim: int, + node_to: Node, + node_to_dim: int, + init: bool = True, + ) -> None: + """ + node_to's node_to_dim 
inherit node_from's node_from_dim by indice, compute and source + """ + node_from_dim = self._transform_indice(node_from, node_from_dim) + node_to_dim = self._transform_indice(node_to, node_to_dim) + node_from_trace = self._find_trace_from_node(node_from) + node_to_trace = self._find_trace_from_node(node_to) + if init: + node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim] + node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim]) + else: + for j in node_from_trace["compute"][node_from_dim]: + if j not in node_to_trace["compute"][node_to_dim]: + node_to_trace["compute"][node_to_dim].append(j) + self._add_source(node_from, node_from_dim, node_to, node_to_dim, init) + + def _inherit_all_indice(self, node_from: Node, node_to: Node) -> None: + """ + inherit all dims with init + """ + # find indice just for assert length + node_from_indice = self._find_indice_trace_from_node(node_from) + node_to_indice = self._find_indice_trace_from_node(node_to) + assert len(node_from_indice) == len(node_to_indice) + for i in range(len(node_from_indice)): + self._inherit_indice(node_from, i, node_to, i, init=True) + + def _inherit_more_indice_from_node(self, node_from: Node, node_to: Node, exclude: List = None) -> None: + """ + inheirt indice from node without init + """ if exclude == None: exclude = [] else: @@ -130,12 +171,9 @@ class TraceIndice(object): for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1): if self._transform_indice(node_to, i) in exclude: continue - self._add_source(node_from, i, node_to, i) - for j in node_from_compute[i]: - if j not in node_to_compute[i]: - node_to_compute[i].append(j) + self._inherit_indice(node_from, i, node_to, i, init=False) - def _mark_computation(self, node, idx, dim): + def _mark_computation(self, node: Node, idx: int, dim: int) -> None: """ Mark some dims of node as computed. @@ -152,7 +190,7 @@ class TraceIndice(object): if idx not in self.indice_trace_list[idx]["compute"][cur_dim]: self.indice_trace_list[idx]["compute"][cur_dim].append(idx) - def _find_trace_from_node(self, node): + def _find_trace_from_node(self, node: Node) -> Dict: """ Find node idx and compute trace by the node. @@ -166,7 +204,7 @@ class TraceIndice(object): node_dict = self.indice_trace_list[node_idx] return node_dict - def _find_source_trace_from_node(self, node): + def _find_source_trace_from_node(self, node: Node) -> List: """ Find node source trace by the node. @@ -180,7 +218,7 @@ class TraceIndice(object): node_dict = self.indice_trace_list[node_idx] return node_dict["source"] - def _find_indice_trace_from_node(self, node): + def _find_indice_trace_from_node(self, node) -> List: """ Find node idx trace by the node. @@ -192,7 +230,7 @@ class TraceIndice(object): node_idx = find_idx_by_name(node.name, self.node_list) return self.indice_trace_list[node_idx]["indice"] - def _find_compute_trace_from_node(self, node): + def _find_compute_trace_from_node(self, node: Node) -> List: """ Find node compute trace by the node. @@ -204,7 +242,7 @@ class TraceIndice(object): node_idx = find_idx_by_name(node.name, self.node_list) return self.indice_trace_list[node_idx]["compute"] - def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None): + def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None: """ Assign node's trace as its input node. 
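To make the _inherit_indice refactor in the hunks above easier to follow: with init=True the target dim takes over the source's indice id and a copy of its compute log, while with init=False the target keeps its own indice id and only merges in compute records it does not already have (this is what _inherit_more_indice_from_node relies on). A small standalone illustration using plain lists in place of the real trace structures; every name here is illustrative only:

    import copy

    def inherit_dim(from_indice, from_compute, to_indice, to_compute, init):
        # Stand-in for one dim of a trace: an indice id plus a list of compute records.
        if init:
            # init=True: overwrite the target dim with the source dim
            return from_indice, copy.deepcopy(from_compute)
        # init=False: keep the target's indice id, merge in missing compute records
        merged = list(to_compute)
        for record in from_compute:
            if record not in merged:
                merged.append(record)
        return to_indice, merged

    assert inherit_dim(3, [10, 12], 7, [12, 15], init=True) == (3, [10, 12])
    assert inherit_dim(3, [10, 12], 7, [12, 15], init=False) == (7, [12, 15, 10])
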
@@ -214,15 +252,9 @@ class TraceIndice(object): """ if input_node == None: input_node = find_first_tensor_arg(node) - input_node_idx = find_idx_by_name(input_node.name, self.node_list) - input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"] + self._inherit_all_indice(input_node, node) - new_idx_trace = copy.deepcopy(input_node_idx_trace) - self.indice_trace_list[node_idx]["indice"] = new_idx_trace - - self._inherit_all_computation(input_node, node) - - def _assign_all_indice(self, node: Node, node_idx: int): + def _assign_all_indice(self, node: Node, node_idx: int) -> None: """ Add new indice for all node's dims. @@ -238,7 +270,7 @@ class TraceIndice(object): new_trace.append(self._add_indice()) self.indice_trace_list[node_idx]["indice"] = new_trace - def _assign_transpose_indice(self, node: Node, node_idx: int): + def _assign_transpose_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for transpose op. 1. swap input's dim according to transpose args @@ -255,7 +287,7 @@ class TraceIndice(object): self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0]) self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1]) - def _assign_permute_indice(self, node: Node, node_idx: int): + def _assign_permute_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for permute op. 1. swap input's dim according to permute args @@ -272,7 +304,7 @@ class TraceIndice(object): for idx, d in enumerate(permute_dim): self._inherit_indice(input_node, d, node, idx) - def _assign_linear_indice(self, node: Node, node_idx: int): + def _assign_linear_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for linear op. 1. copy trace from input node and change last indice accroding to weight @@ -293,7 +325,23 @@ class TraceIndice(object): self._mark_computation(node, node_idx, [-1]) - def _assign_matmul_indice(self, node: Node, node_idx: int): + def _assign_addmm_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for addmm op. + + Args: + node (node) + node_idx (int) + """ + bias, input_node, weight = node.args + + self._assign_indice_as_input(node, node_idx, input_node) + self._inherit_indice(weight, 1, node, -1) + self._inherit_indice(bias, -1, node, -1) + + self._mark_computation(node, node_idx, [-1]) + + def _assign_matmul_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for matmul op. 1. copy trace from matmul_left and change last indice accroding to matmul_right. 
(assert they have same length) @@ -310,7 +358,7 @@ class TraceIndice(object): self._assign_indice_as_input(node, node_idx, matmul_left) self._inherit_indice(matmul_right, -1, node, -1) - self._mark_computation_from_node(matmul_right, node, [-1, -2]) + self._inherit_more_indice_from_node(matmul_right, node, [-1, -2]) self._mark_computation(node, node_idx, [-1]) def _assign_layernorm_indice(self, node, idx): @@ -341,14 +389,13 @@ class TraceIndice(object): for node_in in node.args: if type(node_in) == type(node): nodes_in.append(node_in) - self._mark_computation_from_node(node_in, node) - assert len(nodes_in) <= 2 + self._inherit_more_indice_from_node(node_in, node) def _assgin_no_change_indice(self, node, idx): self._assign_indice_as_input(node, idx) for node_in in node.args: if type(node_in) == type(node): - self._mark_computation_from_node(node_in, node) + self._inherit_more_indice_from_node(node_in, node) def _assign_einsum_indice(self, node, idx): """ @@ -365,7 +412,7 @@ class TraceIndice(object): left, right = patterns.split("->") left = left.split(",") - if '...' in right: + if "..." in right: replace_list = "!@#$%^&*" target_len = len(get_node_shape(node)) add_len = target_len - len(right) + 3 @@ -399,7 +446,22 @@ class TraceIndice(object): self._assign_indice_as_input(node, idx) self._mark_computation(node, idx, [node.kwargs["dim"]]) - def _assign_unsqueeze_indice(self, node: Node, node_idx: int): + def _assign_split_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for split op. + + Args: + node (node) + node_idx (int) + """ + for _ in range(len(get_node_shape(node.args[0]))): + self._add_dim(node_idx, 0) + self._assign_indice_as_input(node, node_idx) + dim_idx = node.kwargs["dim"] + self._del_dim(node_idx, dim_idx) + self._add_dim(node_idx, dim_idx) + + def _assign_unsqueeze_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for unsqueeze op. 1. assign new indice for unsqueeze dim @@ -416,18 +478,7 @@ class TraceIndice(object): dim_idx = list(range(len(get_node_shape(node))))[dim_idx] self._add_dim(node_idx, dim_idx) - def _assign_dropout_indice(self, node: Node, node_idx: int): - """ - Assign indice for unsqueeze op. - 1. assign new indice for unsqueeze dim - - Args: - node (node) - node_idx (int) - """ - self._assign_indice_as_input(node, node_idx) - - def _assign_ones_like_indice(self, node: Node, node_idx: int): + def _assign_ones_like_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for oneslike op. 1. assign new indice for all dim @@ -438,7 +489,7 @@ class TraceIndice(object): """ self._assign_all_indice(node, node_idx) - def _assign_cat_indice(self, node: Node, node_idx: int): + def _assign_cat_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for cat op. @@ -449,12 +500,12 @@ class TraceIndice(object): nodes_in = flat_list(node.args[0]) self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0]) for n in nodes_in[1:]: - self._mark_computation_from_node(n, node) + self._inherit_more_indice_from_node(n, node) cat_dim = node.kwargs["dim"] self._del_dim(node_idx, cat_dim) self._add_dim(node_idx, cat_dim) - def _assign_sum_indice(self, node: Node, node_idx: int): + def _assign_sum_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for sum op. 
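As a reading aid for the assigner methods in this file: for a matmul the output inherits its leading dims from the left operand, takes its last dim from the right operand, and then marks that last dim as computed because it is produced by the contraction. A toy, runnable illustration of the per-dim bookkeeping this produces; the shapes and ids below are invented for the example and are not taken from the patch:

    # out = left @ right, with left of shape (B, M, K) and right of shape (K, N).
    left_indice = [0, 1, 2]      # indice ids for dims B, M, K
    right_indice = [2, 3]        # indice ids for dims K, N (K shared with left)
    node_idx = 42                # hypothetical index of the matmul node in the traced node list

    out_indice = [left_indice[0], left_indice[1], right_indice[-1]]    # B and M from left, N from right
    out_compute = [[], [], [node_idx]]    # only the last dim is recorded as computed at this node

    assert out_indice == [0, 1, 3]
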
@@ -466,11 +517,46 @@ class TraceIndice(object): self._add_dim(node_idx, 0) self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0]) for n in nodes_in[1:]: - self._mark_computation_from_node(n, node) + self._inherit_more_indice_from_node(n, node) cat_dim = node.kwargs["dim"] self._del_dim(node_idx, cat_dim) - def _assign_getitem_indice(self, node: Node, node_idx: int): + def _assign_arange_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for arange op. + + Args: + node (node) + node_idx (int) + """ + self._assign_all_indice(node, node_idx) + + def _assign_tensor_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for tensor op. + + Args: + node (node) + node_idx (int) + """ + if len(get_node_shape(node)) == 0: + return + else: + raise NotImplementedError() + + def _assign_embedding_indice(self, node: Node, node_idx: int) -> None: + """ + Assign indice for embedding op. + + Args: + node (node) + node_idx (int) + """ + self._del_dim(node_idx, -1) + self._assign_indice_as_input(node, node_idx) + self._add_dim(node_idx, -1) + + def _assign_getitem_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for getitem. getitem can act like slice sometimes @@ -480,6 +566,19 @@ class TraceIndice(object): node_idx (int) """ node_args = flat_list(node.args[1:]) + + # deal with split + if get_node_name(node.args[0]) == "split": + self._assign_indice_as_input(node, node_idx) + self._del_dim(node_idx, node.args[0].kwargs["dim"]) + self._add_dim(node_idx, node.args[0].kwargs["dim"]) + return + + # skip non tensor + if get_node_shape(node) is None: + return + + # find if slice flag = False for node_arg in node_args: node_arg_str = str(node_arg) @@ -528,7 +627,7 @@ class TraceIndice(object): else: raise NotImplementedError() - def _assign_view_reshape_indice(self, node: Node, node_idx: int): + def _assign_view_reshape_indice(self, node: Node, node_idx: int) -> None: """ Assign indice for view and reshape op. 1. get origin shape and target shape by meta info. @@ -536,7 +635,7 @@ class TraceIndice(object): 3. determine changed dim, and assgin indice for generated dim. 4. log changed dim and generated dim for restore 5. inherit computation. - 6. TODO: look into view list to see whether the view is associated with other, + 6. look into view list to see whether the view is associated with other, if so assgin equal dim according to previous view. 
Args: @@ -552,7 +651,7 @@ class TraceIndice(object): if isinstance(unflated_args[i], int): target_shape.append(unflated_args[i]) else: - target_shape.append(unflated_args[i].meta["fwd_out"][0]) + target_shape.extend(unflated_args[i].meta["fwd_out"]) # compute the value of -1 if -1 in target_shape: @@ -579,17 +678,36 @@ class TraceIndice(object): dim_from = [dim_equal.index(False)] dim_to = [dim_equal.index(False), dim_equal.index(False) + 1] self._del_dim(node_idx, -1) + elif len_diff == 0: + # dim equal + dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])] + dim_from = [] + dim_to = [] else: raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented") # get new indice origin_trace = self._find_indice_trace_from_node(origin_node) self._assign_indice_as_input(node, node_idx, origin_node) + idx_from = [origin_trace[i] for i in dim_from] dim_from.reverse() for i in dim_from: self._del_dim(node_idx, i) for i in dim_to: self._add_dim(node_idx, i) + dim_from.reverse() + + # search view list + for view_node, view_dict in self.indice_view_list.items(): + if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from + and view_dict["dim_from"] == dim_to): + # inheirt indice from current node + for dim_to_i in dim_to: + for dim_from_i in dim_from: + self._inherit_indice(origin_node, dim_from_i, node, dim_to_i, init=False) + # inherid indice from input node of last view + for dim_to_i in dim_to: + self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False) # inherit computation compute_log = self._find_compute_trace_from_node(origin_node) @@ -630,7 +748,7 @@ class TraceIndice(object): # clear compute for dim_compute in trace["compute"]: for i in range(len(dim_compute) - 1, -1, -1): - if dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes: + if (dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes): dim_compute.pop(i) continue # clear source @@ -639,59 +757,82 @@ class TraceIndice(object): if k < trace_range[0] and k not in active_nodes: dim_source.pop(k) - def trace_indice(self): + def trace_indice(self) -> None: for idx, node in enumerate(self.node_list): + node_name = get_node_name(node) if node.op == "placeholder": self._assign_all_indice(node, idx) elif node.op == "call_method": - if "transpose" in node.name: + if "transpose" == node_name: self._assign_transpose_indice(node, idx) - elif "permute" in node.name: + elif "permute" == node_name: self._assign_permute_indice(node, idx) - elif "view" in node.name or "reshape" in node.name: + elif "view" == node_name or "reshape" == node_name: self._assign_view_reshape_indice(node, idx) - elif "unsqueeze" in node.name: + elif "unsqueeze" == node_name: self._assign_unsqueeze_indice(node, idx) - elif any(i in node.name for i in ["to", "contiguous", "clone"]): + elif "split" == node_name: + self._assign_split_indice(node, idx) + elif any(i == node_name for i in ["to", "contiguous", "clone", "type"]): self._assgin_no_change_indice(node, idx) - elif "new_ones" in node.name: + elif "new_ones" == node_name: self._assign_ones_like_indice(node, idx) - else: - raise NotImplementedError(node.name, "method not implemented yet!") - elif node.op == "call_function": - if "linear" in node.name: - self._assign_linear_indice(node, idx) - elif "cat" in node.name: - self._assign_cat_indice(node, idx) - elif "matmul" in node.name: - self._assign_matmul_indice(node, idx) - elif "softmax" in node.name: - self._assign_softmax_indice(node, idx) - elif 
any(n in node.name for n in ["mul", "add", "sigmoid", "relu", "sub", "truediv"]): - self._assign_elementwise_indice(node, idx) - elif "ones_like" in node.name: - self._assign_ones_like_indice(node, idx) - elif "dropout" in node.name: - self._assign_dropout_indice(node, idx) - elif "einsum" in node.name: - self._assign_einsum_indice(node, idx) - elif "sum" in node.name: - self._assign_sum_indice(node, idx) - elif "layer_norm" in node.name: - self._assign_layernorm_indice(node, idx) - elif "getitem" in node.name: - self._assign_getitem_indice(node, idx) - elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]): + elif any(i == node_name for i in ["size"]): continue else: - raise NotImplementedError(node.name, "function not implemented yet!") - elif node.op == "call_module": - if any(n in node.name for n in ["layernorm", "norm"]): + raise NotImplementedError(node_name, "method not implemented yet!") + elif node.op == "call_function": + if "linear" == node_name: + self._assign_linear_indice(node, idx) + elif "cat" == node_name: + self._assign_cat_indice(node, idx) + elif "matmul" == node_name: + self._assign_matmul_indice(node, idx) + elif "softmax" == node_name: + self._assign_softmax_indice(node, idx) + elif any(n == node_name for n in [ + "mul", + "add", + "sigmoid", + "relu", + "sub", + "truediv", + "pow", + "dropout", + "where", + "tanh", + ]): + self._assign_elementwise_indice(node, idx) + elif "ones_like" == node_name: + self._assign_ones_like_indice(node, idx) + elif "einsum" == node_name: + self._assign_einsum_indice(node, idx) + elif "sum" == node_name: + self._assign_sum_indice(node, idx) + elif "layer_norm" == node_name: self._assign_layernorm_indice(node, idx) - elif any(n in node.name for n in ["sigmoid", "dropout", "relu"]): + elif "getitem" == node_name: + self._assign_getitem_indice(node, idx) + elif "addmm" == node_name: + self._assign_addmm_indice(node, idx) + elif "arange" == node_name: + self._assign_arange_indice(node, idx) + elif "tensor" == node_name: + self._assign_arange_indice(node, idx) + elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]): + continue + else: + raise NotImplementedError(node_name, "function not implemented yet!") + elif node.op == "call_module": + node_name = get_module_node_name(node) + if "layernorm" == node_name: + self._assign_layernorm_indice(node, idx) + elif "embedding" == node_name: + self._assign_embedding_indice(node, idx) + elif any(n == node_name for n in ["sigmoid", "dropout", "relu"]): self._assign_elementwise_indice(node, idx) else: - raise NotImplementedError(node.name, "module not implemented yet!") + raise NotImplementedError(node_name, "module not implemented yet!") elif node.op == "get_attr": self._assign_all_indice(node, idx) # get param elif node.op == "output": diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py index e87068512..de081b41c 100644 --- a/colossalai/autochunk/utils.py +++ b/colossalai/autochunk/utils.py @@ -1,13 +1,15 @@ -from typing import Any, Callable, Dict, Iterable, List, Tuple +from typing import Any, Callable, Dict, Iterable, List, Tuple, Union from torch.fx.node import Node from colossalai.logging import get_dist_logger +NON_COMPUTE_OP = ["placeholder", "get_attr", "output"] +NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "size"] logger = get_dist_logger() -def get_logger(): +def get_logger() -> Any: return logger @@ -37,7 +39,7 @@ def find_first_tensor_arg(node: Node) -> Node: def 
is_non_compute_node(node: Node) -> bool: - if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]): + if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME): return True if "getitem" in node.name: node_args = flat_list(node.args[1:]) @@ -64,33 +66,33 @@ def is_non_memory_node(node: Node) -> bool: return is_non_compute_node(node) -def is_non_compute_node_except_placeholder(node): +def is_non_compute_node_except_placeholder(node: Node) -> bool: if "placeholder" in node.op: return False return is_non_compute_node(node) -def is_non_compute_node_except_placeholder_output(node): +def is_non_compute_node_except_placeholder_output(node: Node) -> bool: if "output" in node.op: return False return is_non_compute_node_except_placeholder(node) -def find_idx_by_name(name, nodes_list): +def find_idx_by_name(name: str, nodes_list: List) -> int: for idx, node in enumerate(nodes_list): if node.name == name: return idx raise RuntimeError("name %s not found in node list" % name) -def delete_free_var_from_last_use(user_to_last_uses): +def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None: for key, value in user_to_last_uses.items(): for n in value: if n.op == "placeholder": user_to_last_uses[key].remove(n) -def find_chunk_all_input_nodes(nodes: List[Node]): +def find_chunk_all_input_nodes(nodes: List[Node]) -> List: """ Find non-compute input and output node names. input nodes are nodes used in the list @@ -104,7 +106,7 @@ def find_chunk_all_input_nodes(nodes: List[Node]): return input_nodes -def find_chunk_compute_input_and_output_nodes(nodes: List[Node]): +def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List, List]: """ Find non-compute input and output node names. 
input nodes are nodes used in the list @@ -130,3 +132,33 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]): output_nodes.append(node) return input_nodes, output_nodes + + +def get_module_node_name(node: Node) -> str: + """ + get module class name + """ + node_targets = node.target.split(".") + module = node.graph.owning_module + for i in node_targets: + module = getattr(module, i) + module_name = str(module.__class__).split(".")[-1][:-2] + module_name = module_name.lower() + return module_name + + +def get_node_name(node: Node) -> str: + """ + get node name + """ + node_name = node.name + if "_" in node_name: + for i in range(len(node_name) - 1, -1, -1): + if node_name[i] == "_": + node_name = node_name[:i] + break + elif node_name[i] in ["1", "2", "3", "4", "5", "6", "7", "8", "9", "0"]: + continue + else: + break + return node_name diff --git a/tests/test_autochunk/benchmark_simple_evoformer.py b/tests/test_autochunk/benchmark_simple_evoformer.py deleted file mode 100644 index 8b5d8a8be..000000000 --- a/tests/test_autochunk/benchmark_simple_evoformer.py +++ /dev/null @@ -1,94 +0,0 @@ -import time - -import torch -import torch.fx -from simple_evoformer import base_evoformer, openfold_evoformer - -from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen -from colossalai.fx import ColoTracer -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.fx.profiler import MetaTensor - - -def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None): - torch.cuda.reset_peak_memory_stats() - now_mem = torch.cuda.memory_allocated() / 1024**2 - - loop = 3 - with torch.no_grad(): - for _ in range(loop // 2 + 1): - if chunk_size: - model(node, pair, chunk_size) - else: - model(node, pair) - torch.cuda.synchronize() - time1 = time.time() - for _ in range(loop): - if chunk_size: - model(node, pair, chunk_size) - else: - model(node, pair) - torch.cuda.synchronize() - time2 = time.time() - - new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 - print("%s: time %.4fs, mem %dMB" % (title, (time2 - time1) / loop, new_max_mem - now_mem)) - - -def _build_autochunk(model, max_memory, node, pair): - # trace the module and replace codegen - graph = ColoTracer().trace( - model, - meta_args={ - "node": node.to(torch.device("meta")), - "pair": pair.to(torch.device("meta")), - }, - ) - - gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace - interp = MetaInfoProp(gm_prop) - interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")) - - # now run it twice to get meta info in graph module, not necessary - gm = torch.fx.GraphModule(model, graph) - interp = MetaInfoProp(gm) - interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")) - - # set code_gen - codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False) - graph.set_codegen(codegen) - gm = ColoGraphModule(model, graph) - gm.recompile() - - # print - # code = graph.python_code("self").src - # print(code) - return gm - - -def benchmark_evoformer(): - # init data and model - msa_len = 128 - pair_len = 256 - node = torch.randn(1, msa_len, pair_len, 256).cuda() - pair = torch.randn(1, pair_len, pair_len, 128).cuda() - model = base_evoformer().cuda() - - # build autochunk model - # max_memory = 1000 # MB, fit memory mode - max_memory = None # min memory mode - autochunk = _build_autochunk(base_evoformer().cuda(), max_memory, node, pair) - - # 
build openfold - chunk_size = 64 - openfold = openfold_evoformer().cuda() - - # benchmark - _benchmark_evoformer(model, node, pair, "base") - _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size) - _benchmark_evoformer(autochunk, node, pair, "autochunk") - - -if __name__ == "__main__": - benchmark_evoformer() diff --git a/tests/test_autochunk/test_alphafold/test_alphafold_utils.py b/tests/test_autochunk/test_alphafold/test_alphafold_utils.py new file mode 100644 index 000000000..b05191d2b --- /dev/null +++ b/tests/test_autochunk/test_alphafold/test_alphafold_utils.py @@ -0,0 +1,122 @@ +from typing import Any, Dict, List + +import torch +import torch.fx + +import colossalai +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.autochunk.utils import flat_list +from colossalai.core import global_context as gpc +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.utils import free_port + +if AUTOCHUNK_AVAILABLE: + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + from colossalai.fx.profiler import MetaTensor + from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace + + +def assert_codegen_run( + model: Any, + meta_args: List, + concrete_args: List = None, + max_memory: int = None, + print_mem: bool = False, + print_progress: bool = False, + print_code: bool = False, +) -> List[Dict]: + if concrete_args is None: + concrete_args = [] + + # trace the meta graph and setup codegen + meta_graph = symbolic_trace( + model, + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args}, + concrete_args={k: v for k, v in concrete_args}, + ) + interp = MetaInfoProp(meta_graph) + meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args] + interp.propagate(*meta_tensors) + codegen = AutoChunkCodeGen( + meta_graph, + max_memory=max_memory, + print_mem=print_mem, + print_progress=print_progress, + ) + chunks = codegen.chunk_infos + + # trace and recompile + # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer + graph = ColoTracer().trace( + model, + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args}, + concrete_args={k: v for k, v in concrete_args}, + ) + graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph, ckpt_codegen=False) + gm.recompile() + + # assert chunk in code + code = graph.python_code("self").src + if print_code: + print(code) + assert "chunk_result = None; chunk_size = None;" in code + + # assert result + inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args] + model.cuda() + with torch.no_grad(): + out_gm = gm(*inputs) + out_model = model(*inputs) + out_gm = flat_list(out_gm) + out_model = flat_list(out_model) + for out_gm_i, out_model_i in zip(out_gm, out_model): + assert torch.allclose(out_gm_i, out_model_i, + atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(out_gm_i - out_model_i)) + + return chunks + + +def run_test( + rank: int, + data_args: tuple, + max_memory: int, + get_model: Any, + get_data: Any, + print_code: bool, + print_mem: bool, + print_progress: bool, + get_chunk_target: Any = None, +) -> None: + # launch colossalai + colossalai.launch( + config={}, + rank=rank, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) + + # build model and input + model = get_model() + meta_args, concrete_args = get_data(*data_args) + chunks = 
assert_codegen_run( + model, + meta_args=meta_args, + concrete_args=concrete_args, + max_memory=max_memory, + print_code=print_code, + print_mem=print_mem, + print_progress=print_progress, + ) + + if get_chunk_target is not None: + chunk_found = [i["region"] for i in chunks] + chunk_target = get_chunk_target()[max_memory] + assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % ( + str(chunk_found), + str(chunk_target), + ) diff --git a/tests/test_autochunk/test_alphafold/test_evoformer_block.py b/tests/test_autochunk/test_alphafold/test_evoformer_block.py new file mode 100644 index 000000000..787067daa --- /dev/null +++ b/tests/test_autochunk/test_alphafold/test_evoformer_block.py @@ -0,0 +1,95 @@ +from functools import partial +from typing import Dict, List, Tuple + +import pytest +import torch +import torch.fx +import torch.multiprocessing as mp + +try: + from fastfold.model.nn.evoformer import EvoformerBlock + HAS_REPO = True +except: + HAS_REPO = False + +from test_alphafold_utils import run_test + +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE + + +def get_model(): + model = EvoformerBlock( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + no_heads_msa=8, + no_heads_pair=4, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.15, + inf=1e4, + eps=1e-4, + is_multimer=False, + ).eval().cuda() + return model + + +def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: + node = torch.randn(1, msa_len, pair_len, 256).cuda() + node_mask = torch.randn(1, msa_len, pair_len).cuda() + pair = torch.randn(1, pair_len, pair_len, 128).cuda() + pair_mask = torch.randn(1, pair_len, pair_len).cuda() + + meta_args = [ + ("m", node), + ("z", pair), + ("msa_mask", node_mask), + ("pair_mask", pair_mask), + ] + concrete_args = [("chunk_size", None), ("_mask_trans", True)] + return meta_args, concrete_args + + +def get_chunk_target() -> Dict: + return { + None: [(118, 123), (219, 237), (264, 289), (302, 309), (97, 104), (144, 152), (185, 193), (241, 242), (21, 46)], + 20: [(118, 123), (230, 237), (275, 282), (305, 306), (100, 101), (32, 39), (73, 79)], + 24: [(118, 123)], + } + + +@pytest.mark.skipif( + not (AUTOCHUNK_AVAILABLE and HAS_REPO), + reason="torch version is lower than 1.12.0", +) +@pytest.mark.parametrize("max_memory", [None, 20, 24]) +@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len) +def test_evoformer_block(data_args, max_memory): + run_func = partial( + run_test, + data_args=data_args, + max_memory=max_memory, + get_model=get_model, + get_data=get_data, + get_chunk_target=get_chunk_target, + print_code=False, + print_mem=False, + print_progress=False, + ) + mp.spawn(run_func, nprocs=1) + + +if __name__ == "__main__": + run_test( + rank=0, + data_args=(32, 64), + max_memory=20, + get_model=get_model, + get_data=get_data, + print_code=False, + print_mem=False, + print_progress=False, + ) diff --git a/tests/test_autochunk/test_alphafold/test_evoformer_stack.py b/tests/test_autochunk/test_alphafold/test_evoformer_stack.py new file mode 100644 index 000000000..45d8e7ac8 --- /dev/null +++ b/tests/test_autochunk/test_alphafold/test_evoformer_stack.py @@ -0,0 +1,90 @@ +from functools import partial +from typing import List, Tuple + +import pytest +import torch +import torch.fx +import torch.multiprocessing as mp + +try: + from fastfold.model.nn.evoformer import EvoformerStack + HAS_REPO = True +except: + HAS_REPO = False + +from test_alphafold_utils 
import run_test + +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE + + +def get_model(): + model = EvoformerStack( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + c_s=384, + no_heads_msa=8, + no_heads_pair=4, + no_blocks=2, # 48 + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.25, + blocks_per_ckpt=None, + inf=1000000000.0, + eps=1e-08, + clear_cache_between_blocks=False, + is_multimer=False, + ).eval().cuda() + return model + + +def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: + node = torch.randn(1, msa_len, pair_len, 256).cuda() + node_mask = torch.randn(1, msa_len, pair_len).cuda() + pair = torch.randn(1, pair_len, pair_len, 128).cuda() + pair_mask = torch.randn(1, pair_len, pair_len).cuda() + + meta_args = [ + ("m", node), + ("z", pair), + ("msa_mask", node_mask), + ("pair_mask", pair_mask), + ] + concrete_args = [("chunk_size", None), ("_mask_trans", True)] + return meta_args, concrete_args + + +@pytest.mark.skipif( + not (AUTOCHUNK_AVAILABLE and HAS_REPO), + reason="torch version is lower than 1.12.0", +) +@pytest.mark.parametrize("max_memory", [None, 20, 24]) +@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len) +def test_evoformer_stack(data_args, max_memory): + run_func = partial( + run_test, + data_args=data_args, + max_memory=max_memory, + get_model=get_model, + get_data=get_data, + print_code=False, + print_mem=False, + print_progress=False, + ) + mp.spawn(run_func, nprocs=1) + + +if __name__ == "__main__": + run_test( + rank=0, + data_args=(32, 64), + max_memory=20, + get_model=get_model, + get_data=get_data, + print_code=False, + print_mem=False, + print_progress=False, + ) diff --git a/tests/test_autochunk/test_alphafold/test_extramsa_block.py b/tests/test_autochunk/test_alphafold/test_extramsa_block.py new file mode 100644 index 000000000..a2b72ed1a --- /dev/null +++ b/tests/test_autochunk/test_alphafold/test_extramsa_block.py @@ -0,0 +1,96 @@ +from functools import partial +from typing import Dict, List, Tuple + +import pytest +import torch +import torch.fx +import torch.multiprocessing as mp + +try: + from fastfold.model.nn.evoformer import ExtraMSABlock + HAS_REPO = True +except: + HAS_REPO = False +from test_alphafold_utils import run_test + +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE + + +def get_model(): + model = ExtraMSABlock( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + no_heads_msa=8, + no_heads_pair=4, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.15, + inf=1e4, + eps=1e-4, + ckpt=False, + is_multimer=False, + ).eval().cuda() + return model + + +def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: + node = torch.randn(1, msa_len, pair_len, 256).cuda() + node_mask = torch.randn(1, msa_len, pair_len).cuda() + pair = torch.randn(1, pair_len, pair_len, 128).cuda() + pair_mask = torch.randn(1, pair_len, pair_len).cuda() + + meta_args = [ + ("m", node), + ("z", pair), + ("msa_mask", node_mask), + ("pair_mask", pair_mask), + ] + concrete_args = [("chunk_size", None), ("_chunk_logits", 1024)] + return meta_args, concrete_args + + +def get_chunk_target() -> Dict: + return { + None: [(126, 131), (227, 245), (272, 297), (310, 317), (105, 112), (152, 160), (193, 201), (249, 250), + (33, 46)], + 20: [(126, 131), (238, 245), (283, 290), (313, 314), (108, 109), (35, 46)], + 24: [(126, 131)], + } + + +@pytest.mark.skipif( + not 
(AUTOCHUNK_AVAILABLE and HAS_REPO), + reason="torch version is lower than 1.12.0", +) +@pytest.mark.parametrize("max_memory", [None, 20, 24]) +@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len) +def test_extramsa_block(data_args, max_memory): + run_func = partial( + run_test, + data_args=data_args, + max_memory=max_memory, + get_model=get_model, + get_data=get_data, + print_code=False, + print_mem=False, + print_progress=False, + ) + mp.spawn(run_func, nprocs=1) + + +if __name__ == "__main__": + run_test( + rank=0, + data_args=(32, 64), + max_memory=20, + get_model=get_model, + get_data=get_data, + get_chunk_target=get_chunk_target, + print_code=False, + print_mem=False, + print_progress=False, + ) diff --git a/tests/test_autochunk/test_diffuser/test_diffuser_utils.py b/tests/test_autochunk/test_diffuser/test_diffuser_utils.py new file mode 100644 index 000000000..0f3d22dc5 --- /dev/null +++ b/tests/test_autochunk/test_diffuser/test_diffuser_utils.py @@ -0,0 +1,120 @@ +from typing import Any, Dict, List + +import torch +import torch.fx + +import colossalai +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.core import global_context as gpc +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.utils import free_port + +if AUTOCHUNK_AVAILABLE: + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + from colossalai.fx.profiler import MetaTensor + from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace + + +def assert_codegen_run( + model: Any, + meta_args: List, + concrete_args: List = None, + max_memory: int = None, + print_mem: bool = False, + print_progress: bool = False, + print_code: bool = False, +) -> List[Dict]: + if concrete_args is None: + concrete_args = [] + model = model() + + # trace the meta graph and setup codegen + meta_graph = symbolic_trace( + model, + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args}, + concrete_args={k: v for k, v in concrete_args}, + ) + interp = MetaInfoProp(meta_graph) + meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args] + interp.propagate(*meta_tensors) + codegen = AutoChunkCodeGen( + meta_graph, + max_memory=max_memory, + print_mem=print_mem, + print_progress=print_progress, + ) + chunks = codegen.chunk_infos + + # trace and recompile + # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer + graph = ColoTracer().trace( + model.cuda(), + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args}, + concrete_args={k: v for k, v in concrete_args}, + ) + graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph, ckpt_codegen=False) + gm.recompile() + + # assert chunk in code + code = graph.python_code("self").src + if print_code: + print(code) + assert "chunk_result = None; chunk_size = None;" in code + + # assert result + inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args] + model.cuda().eval() + gm.eval() + with torch.no_grad(): + out_gm = gm(*inputs) + out_model = model(*inputs) + assert torch.allclose(out_gm["sample"], out_model["sample"], + atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(out_gm["sample"] - out_model["sample"])) + + return chunks + + +def run_test( + rank: int, + model: Any, + data: tuple, + max_memory: int, + print_code: bool, + print_mem: bool, + print_progress: bool, + get_chunk_target: Any = None, +) 
-> None: + # launch colossalai + colossalai.launch( + config={}, + rank=rank, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) + + # build model and input + meta_args, concrete_args = data + chunks = assert_codegen_run( + model, + meta_args=meta_args, + concrete_args=concrete_args, + max_memory=max_memory, + print_code=print_code, + print_mem=print_mem, + print_progress=print_progress, + ) + + if get_chunk_target is not None: + chunk_found = [i["region"] for i in chunks] + chunk_target = get_chunk_target()[max_memory] + assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % ( + str(chunk_found), + str(chunk_target), + ) + + gpc.destroy() diff --git a/tests/test_autochunk/test_diffuser/test_unet.py b/tests/test_autochunk/test_diffuser/test_unet.py new file mode 100644 index 000000000..db154b4bb --- /dev/null +++ b/tests/test_autochunk/test_diffuser/test_unet.py @@ -0,0 +1,70 @@ +from functools import partial +from typing import List, Tuple + +import pytest +import torch +import torch.multiprocessing as mp + +try: + from diffusers import UNet2DModel + MODELS = [UNet2DModel] + HAS_REPO = True +except: + MODELS = [] + HAS_REPO = False + +from test_diffuser_utils import run_test + +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE + +BATCH_SIZE = 2 +SEQ_LENGTH = 5 +HEIGHT = 224 +WIDTH = 224 +IN_CHANNELS = 3 +LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7) + + +def get_data(shape: tuple) -> Tuple[List, List]: + sample = torch.randn(shape) + meta_args = [ + ("sample", sample), + ] + concrete_args = [("timestep", 50)] + return meta_args, concrete_args + + +@pytest.mark.skipif( + True, + reason="not implemented", +) +@pytest.mark.skipif( + not (AUTOCHUNK_AVAILABLE and HAS_REPO), + reason="torch version is lower than 1.12.0", +) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("shape", [LATENTS_SHAPE]) +@pytest.mark.parametrize("max_memory", [64]) +def test_evoformer_block(model, shape, max_memory): + run_func = partial( + run_test, + max_memory=max_memory, + model=model, + data=get_data(shape), + print_code=False, + print_mem=False, + print_progress=False, + ) + mp.spawn(run_func, nprocs=1) + + +if __name__ == "__main__": + run_test( + rank=0, + data=get_data(LATENTS_SHAPE), + max_memory=64, + model=UNet2DModel, + print_code=False, + print_mem=False, + print_progress=False, + ) diff --git a/tests/test_autochunk/test_evoformer_codegen.py b/tests/test_autochunk/test_evoformer_codegen.py deleted file mode 100644 index ba6a57a51..000000000 --- a/tests/test_autochunk/test_evoformer_codegen.py +++ /dev/null @@ -1,163 +0,0 @@ -from functools import partial - -import pytest -import torch -import torch.fx -import torch.multiprocessing as mp - -try: - from fastfold.model.nn.evoformer import EvoformerBlock - HAS_REPO = True -except: - HAS_REPO = False - -import colossalai -from colossalai.core import global_context as gpc -from colossalai.fx._compatibility import is_compatible_with_meta -from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port - -if CODEGEN_AVAILABLE and is_compatible_with_meta(): - from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen - from colossalai.fx.profiler import MetaTensor - from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace - - -def 
_test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask): - # for memory test - # model = model.cuda() - # torch.cuda.reset_peak_memory_stats() - # now_mem = torch.cuda.memory_allocated() / 1024**2 - # with torch.no_grad(): - # node1 = node.clone() - # pair1 = pair.clone() - # node_mask1 = node_mask.clone() - # pair_mask1 = pair_mask.clone() - # gm(node1, pair1, node_mask1, pair_mask1) - # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 - # print("autochunk max mem:%.2f"% (new_max_mem - now_mem)) - - # test forward - model = model.cuda() - with torch.no_grad(): - non_fx_out = model(node, pair, node_mask, pair_mask) - fx_out = gm(node, pair, node_mask, pair_mask) - - assert torch.allclose(non_fx_out[0], fx_out[0], - atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(non_fx_out[0] - fx_out[0])) - assert torch.allclose(non_fx_out[1], fx_out[1], - atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(non_fx_out[1] - fx_out[1])) - - -def _build_openfold(): - model = EvoformerBlock( - c_m=256, - c_z=128, - c_hidden_msa_att=32, - c_hidden_opm=32, - c_hidden_mul=128, - c_hidden_pair_att=32, - no_heads_msa=8, - no_heads_pair=4, - transition_n=4, - msa_dropout=0.15, - pair_dropout=0.15, - inf=1e4, - eps=1e-4, - is_multimer=False, - ).eval().cuda() - return model - - -def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory): - # launch colossalai - colossalai.launch( - config={}, - rank=rank, - world_size=1, - host="localhost", - port=free_port(), - backend="nccl", - ) - - # build model and input - model = _build_openfold() - node = torch.randn(1, msa_len, pair_len, 256).cuda() - node_mask = torch.randn(1, msa_len, pair_len).cuda() - pair = torch.randn(1, pair_len, pair_len, 128).cuda() - pair_mask = torch.randn(1, pair_len, pair_len).cuda() - - # trace the meta graph and setup codegen - meta_graph = symbolic_trace( - model, - meta_args={ - "m": node.to(torch.device("meta")), - "z": pair.to(torch.device("meta")), - "msa_mask": node_mask.to(torch.device("meta")), - "pair_mask": pair_mask.to(torch.device("meta")), - }, - concrete_args={ - "chunk_size": None, - "_mask_trans": True, - }, - ) - interp = MetaInfoProp(meta_graph) - interp.propagate( - MetaTensor(node, fake_device="cuda:0"), - MetaTensor(pair, fake_device="cuda:0"), - MetaTensor(node_mask, fake_device="cuda:0"), - MetaTensor(pair_mask, fake_device="cuda:0"), - ) - codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False) - - # trace and recompile - # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer - graph = ColoTracer().trace( - model, - meta_args={ - "m": node.to(torch.device("meta")), - "z": pair.to(torch.device("meta")), - "msa_mask": node_mask.to(torch.device("meta")), - "pair_mask": pair_mask.to(torch.device("meta")), - }, - concrete_args={ - "chunk_size": None, - "_mask_trans": True, - }, - ) - graph.set_codegen(codegen) - gm = ColoGraphModule(model, graph, ckpt_codegen=False) - gm.recompile() - - # assert we have inserted chunk - code = graph.python_code("self").src - # print(code) - assert "chunk_result = None; chunk_size = None;" in code - - _test_fwd(model, gm, node, pair, node_mask, pair_mask) - gpc.destroy() - - -@pytest.mark.skipif( - not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO), - reason="torch version is lower than 1.12.0", -) -@pytest.mark.parametrize("max_memory", [None, 24, 28, 32]) -@pytest.mark.parametrize("msa_len", [32]) 
-@pytest.mark.parametrize("pair_len", [64]) -def test_evoformer_codegen(msa_len, pair_len, max_memory): - run_func = partial( - _test_evoformer_codegen, - msa_len=msa_len, - pair_len=pair_len, - max_memory=max_memory, - ) - mp.spawn(run_func, nprocs=1) - - -if __name__ == "__main__": - _test_evoformer_codegen(0, 32, 64, 24) diff --git a/tests/test_autochunk/test_evoformer_stack_codegen.py b/tests/test_autochunk/test_evoformer_stack_codegen.py deleted file mode 100644 index 5fabb2702..000000000 --- a/tests/test_autochunk/test_evoformer_stack_codegen.py +++ /dev/null @@ -1,163 +0,0 @@ -from functools import partial - -import pytest -import torch -import torch.fx -import torch.multiprocessing as mp - -try: - from fastfold.model.nn.evoformer import EvoformerStack - HAS_REPO = True -except: - HAS_REPO = False - -import colossalai -from colossalai.core import global_context as gpc -from colossalai.fx._compatibility import is_compatible_with_meta -from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port - -if CODEGEN_AVAILABLE and is_compatible_with_meta(): - from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen - from colossalai.fx.profiler import MetaTensor - from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace - - -def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask): - # for memory test - # model = model.cuda() - # torch.cuda.reset_peak_memory_stats() - # now_mem = torch.cuda.memory_allocated() / 1024**2 - # with torch.no_grad(): - # node1 = node.clone() - # pair1 = pair.clone() - # node_mask1 = node_mask.clone() - # pair_mask1 = pair_mask.clone() - # gm(node1, pair1, node_mask1, pair_mask1, None) - # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 - # print("autochunk max mem:%.2f"% (new_max_mem - now_mem)) - - # test forward - model = model.cuda() - with torch.no_grad(): - non_fx_out = model(node, pair, node_mask, pair_mask, None) - fx_out = gm(node, pair, node_mask, pair_mask, None) - - assert torch.allclose(non_fx_out[0], fx_out[0], - atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(non_fx_out[0] - fx_out[0])) - assert torch.allclose(non_fx_out[1], fx_out[1], - atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(non_fx_out[1] - fx_out[1])) - - -def _build_openfold(): - model = EvoformerStack( - c_m=256, - c_z=128, - c_hidden_msa_att=32, - c_hidden_opm=32, - c_hidden_mul=128, - c_hidden_pair_att=32, - c_s=384, - no_heads_msa=8, - no_heads_pair=4, - no_blocks=2, # 48 - transition_n=4, - msa_dropout=0.15, - pair_dropout=0.25, - blocks_per_ckpt=None, - inf=1000000000.0, - eps=1e-08, - clear_cache_between_blocks=False, - is_multimer=False, - ).eval().cuda() - return model - - -def _test_evoformer_stack_codegen(rank, msa_len, pair_len, max_memory): - # launch colossalai - colossalai.launch( - config={}, - rank=rank, - world_size=1, - host="localhost", - port=free_port(), - backend="nccl", - ) - - # build model and input - model = _build_openfold() - node = torch.randn(1, msa_len, pair_len, 256).cuda() - node_mask = torch.randn(1, msa_len, pair_len).cuda() - pair = torch.randn(1, pair_len, pair_len, 128).cuda() - pair_mask = torch.randn(1, pair_len, pair_len).cuda() - - # trace the meta graph and setup codegen - meta_graph = symbolic_trace( - 
model, - meta_args={ - "m": node.to(torch.device("meta")), - "z": pair.to(torch.device("meta")), - "msa_mask": node_mask.to(torch.device("meta")), - "pair_mask": pair_mask.to(torch.device("meta")), - }, - concrete_args={ - "chunk_size": None, - "_mask_trans": True, - }, - ) - interp = MetaInfoProp(meta_graph) - interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"), - MetaTensor(node_mask, fake_device="cuda:0"), MetaTensor(pair_mask, fake_device="cuda:0"), None) - codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False, print_progress=False) - - # trace and recompile - # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer - graph = ColoTracer().trace( - model, - meta_args={ - "m": node.to(torch.device("meta")), - "z": pair.to(torch.device("meta")), - "msa_mask": node_mask.to(torch.device("meta")), - "pair_mask": pair_mask.to(torch.device("meta")), - }, - concrete_args={ - "chunk_size": None, - "_mask_trans": True, - }, - ) - graph.set_codegen(codegen) - gm = ColoGraphModule(model, graph, ckpt_codegen=False) - gm.recompile() - - # assert we have inserted chunk - code = graph.python_code("self").src - # print(code) - assert "chunk_result = None; chunk_size = None;" in code - - _test_fwd(model, gm, node, pair, node_mask, pair_mask) - gpc.destroy() - - -@pytest.mark.skipif( - not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO), - reason="torch version is lower than 1.12.0", -) -@pytest.mark.parametrize("max_memory", [None, 24, 28, 32]) -@pytest.mark.parametrize("msa_len", [32]) -@pytest.mark.parametrize("pair_len", [64]) -def test_evoformer_stack_codegen(msa_len, pair_len, max_memory): - run_func = partial( - _test_evoformer_stack_codegen, - msa_len=msa_len, - pair_len=pair_len, - max_memory=max_memory, - ) - mp.spawn(run_func, nprocs=1) - - -if __name__ == "__main__": - _test_evoformer_stack_codegen(0, 32, 64, None) diff --git a/tests/test_autochunk/test_extramsa_codegen.py b/tests/test_autochunk/test_extramsa_codegen.py deleted file mode 100644 index 2a41452a2..000000000 --- a/tests/test_autochunk/test_extramsa_codegen.py +++ /dev/null @@ -1,164 +0,0 @@ -from functools import partial - -import pytest -import torch -import torch.fx -import torch.multiprocessing as mp - -try: - from fastfold.model.nn.evoformer import ExtraMSABlock - HAS_REPO = True -except: - HAS_REPO = False - -import colossalai -from colossalai.core import global_context as gpc -from colossalai.fx._compatibility import is_compatible_with_meta -from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port - -if CODEGEN_AVAILABLE and is_compatible_with_meta(): - from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen - from colossalai.fx.profiler import MetaTensor - from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace - - -def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask): - # for memory test - # model = model.cuda() - # torch.cuda.reset_peak_memory_stats() - # now_mem = torch.cuda.memory_allocated() / 1024**2 - # with torch.no_grad(): - # node1 = node.clone() - # pair1 = pair.clone() - # node_mask1 = node_mask.clone() - # pair_mask1 = pair_mask.clone() - # gm(node1, pair1, node_mask1, pair_mask1) - # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 - # print("autochunk max 
mem:%.2f"% (new_max_mem - now_mem)) - - # test forward - model = model.cuda() - with torch.no_grad(): - non_fx_out = model(node, pair, node_mask, pair_mask) - fx_out = gm(node, pair, node_mask, pair_mask) - - assert torch.allclose(non_fx_out[0], fx_out[0], - atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(non_fx_out[0] - fx_out[0])) - assert torch.allclose(non_fx_out[1], fx_out[1], - atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(non_fx_out[1] - fx_out[1])) - - -def _build_openfold(): - model = ExtraMSABlock( - c_m=256, - c_z=128, - c_hidden_msa_att=32, - c_hidden_opm=32, - c_hidden_mul=128, - c_hidden_pair_att=32, - no_heads_msa=8, - no_heads_pair=4, - transition_n=4, - msa_dropout=0.15, - pair_dropout=0.15, - inf=1e4, - eps=1e-4, - ckpt=False, - is_multimer=False, - ).eval().cuda() - return model - - -def _test_extramsa_codegen(rank, msa_len, pair_len, max_memory): - # launch colossalai - colossalai.launch( - config={}, - rank=rank, - world_size=1, - host="localhost", - port=free_port(), - backend="nccl", - ) - - # build model and input - model = _build_openfold() - node = torch.randn(1, msa_len, pair_len, 256).cuda() - node_mask = torch.randn(1, msa_len, pair_len).cuda() - pair = torch.randn(1, pair_len, pair_len, 128).cuda() - pair_mask = torch.randn(1, pair_len, pair_len).cuda() - - # trace the meta graph and setup codegen - meta_graph = symbolic_trace( - model, - meta_args={ - "m": node.to(torch.device("meta")), - "z": pair.to(torch.device("meta")), - "msa_mask": node_mask.to(torch.device("meta")), - "pair_mask": pair_mask.to(torch.device("meta")), - }, - concrete_args={ - "chunk_size": None, - "_chunk_logits": 1024, - }, - ) - interp = MetaInfoProp(meta_graph) - interp.propagate( - MetaTensor(node, fake_device="cuda:0"), - MetaTensor(pair, fake_device="cuda:0"), - MetaTensor(node_mask, fake_device="cuda:0"), - MetaTensor(pair_mask, fake_device="cuda:0"), - ) - codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False) - - # trace and recompile - # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer - graph = ColoTracer().trace( - model, - meta_args={ - "m": node.to(torch.device("meta")), - "z": pair.to(torch.device("meta")), - "msa_mask": node_mask.to(torch.device("meta")), - "pair_mask": pair_mask.to(torch.device("meta")), - }, - concrete_args={ - "chunk_size": None, - "_chunk_logits": 1024, - }, - ) - graph.set_codegen(codegen) - gm = ColoGraphModule(model, graph, ckpt_codegen=False) - gm.recompile() - - # assert we have inserted chunk - code = graph.python_code("self").src - # print(code) - assert "chunk_result = None; chunk_size = None;" in code - - _test_fwd(model, gm, node, pair, node_mask, pair_mask) - gpc.destroy() - - -@pytest.mark.skipif( - not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO), - reason="torch version is lower than 1.12.0", -) -@pytest.mark.parametrize("max_memory", [None, 24, 28, 32]) -@pytest.mark.parametrize("msa_len", [32]) -@pytest.mark.parametrize("pair_len", [64]) -def test_extramsa_codegen(msa_len, pair_len, max_memory): - run_func = partial( - _test_extramsa_codegen, - msa_len=msa_len, - pair_len=pair_len, - max_memory=max_memory, - ) - mp.spawn(run_func, nprocs=1) - - -if __name__ == "__main__": - _test_extramsa_codegen(0, 32, 64, None) diff --git a/tests/test_autochunk/test_simple_evoformer_codegen.py b/tests/test_autochunk/test_simple_evoformer_codegen.py deleted file mode 100644 index 
7fe149c57..000000000 --- a/tests/test_autochunk/test_simple_evoformer_codegen.py +++ /dev/null @@ -1,104 +0,0 @@ -from functools import partial - -import pytest -import torch -import torch.fx -import torch.multiprocessing as mp - -try: - from simple_evoformer import base_evoformer - HAS_REPO = True -except: - HAS_REPO = False - -import colossalai -from colossalai.core import global_context as gpc -from colossalai.fx import ColoTracer, symbolic_trace -from colossalai.fx._compatibility import is_compatible_with_meta -from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port - -if CODEGEN_AVAILABLE and is_compatible_with_meta(): - from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen - from colossalai.fx.profiler import MetaTensor - - -def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): - with torch.no_grad(): - non_fx_out = model(node, pair) - fx_out = gm(node, pair) - - assert torch.allclose(non_fx_out[0], fx_out[0], - atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(non_fx_out[0] - fx_out[0])) - assert torch.allclose(non_fx_out[1], fx_out[1], - atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(non_fx_out[1] - fx_out[1])) - - -def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory): - # launch colossalai - colossalai.launch( - config={}, - rank=rank, - world_size=1, - host="localhost", - port=free_port(), - backend="nccl", - ) - - # build model and input - model = base_evoformer().cuda() - node = torch.randn(1, msa_len, pair_len, 256).cuda() - pair = torch.randn(1, pair_len, pair_len, 128).cuda() - - # meta info prop - meta_graph = symbolic_trace(model, - meta_args={ - "node": node.to(torch.device("meta")), - "pair": pair.to(torch.device("meta")), - }) # must use symbolic_trace - interp = MetaInfoProp(meta_graph) - interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")) - codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory) - - # trace the module and replace codegen - graph = ColoTracer().trace( - model, - meta_args={ - "node": node.to(torch.device("meta")), - "pair": pair.to(torch.device("meta")), - }, - ) - graph.set_codegen(codegen) - gm = ColoGraphModule(model, graph, ckpt_codegen=False) - gm.recompile() - - # assert we have inserted chunk - code = graph.python_code("self").src - # print(code) - assert "chunk_result = None; chunk_size = None;" in code - - _test_fwd(model, gm, node, pair) - gpc.destroy() - - -@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO), - reason='torch version is lower than 1.12.0') -@pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) -@pytest.mark.parametrize("msa_len", [32]) -@pytest.mark.parametrize("pair_len", [64]) -def test_simple_evoformer_codegen(msa_len, pair_len, max_memory): - run_func = partial( - _test_simple_evoformer_codegen, - msa_len=msa_len, - pair_len=pair_len, - max_memory=max_memory, - ) - mp.spawn(run_func, nprocs=1) - - -if __name__ == "__main__": - _test_simple_evoformer_codegen(0, 32, 64, 25) diff --git a/tests/test_autochunk/test_simple_evoformer_search.py b/tests/test_autochunk/test_simple_evoformer_search.py deleted file mode 100644 index 89f28d625..000000000 --- a/tests/test_autochunk/test_simple_evoformer_search.py +++ 
/dev/null @@ -1,97 +0,0 @@ -from functools import partial - -import pytest -import torch -import torch.fx -import torch.multiprocessing as mp - -try: - from simple_evoformer import base_evoformer - HAS_REPO = True -except: - HAS_REPO = False - -import colossalai -from colossalai.core import global_context as gpc -from colossalai.fx import symbolic_trace -from colossalai.fx._compatibility import is_compatible_with_meta -from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE -from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port - -if CODEGEN_AVAILABLE and is_compatible_with_meta(): - from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen - from colossalai.fx.profiler import MetaTensor - - -def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len): - found_regions = [i["region"] for i in chunk_infos] - - if msa_len == 32 and pair_len == 64: - if max_memory is None: - target_regions = [(142, 154), (366, 373), (234, 283), (302, 351), (127, 134), (211, 228), (174, 191), - (161, 166), (198, 203), (7, 57)] - elif max_memory == 20: - target_regions = [(142, 154), (369, 373), (235, 269), (303, 351), (130, 131)] - elif max_memory == 25: - target_regions = [(144, 154), (369, 370)] - elif max_memory == 30: - target_regions = [(144, 154)] - else: - raise NotImplementedError() - else: - raise NotImplementedError() - - assert found_regions == target_regions, "found regions %s doesn't equal target regions %s" % ( - str(found_regions), - str(target_regions), - ) - - -def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory): - # launch colossalai - colossalai.launch( - config={}, - rank=rank, - world_size=1, - host="localhost", - port=free_port(), - backend="nccl", - ) - - # build model and input - model = base_evoformer().cuda() - node = torch.randn(1, msa_len, pair_len, 256).cuda() - pair = torch.randn(1, pair_len, pair_len, 128).cuda() - - meta_graph = symbolic_trace(model, - meta_args={ - "node": node.to(torch.device("meta")), - "pair": pair.to(torch.device("meta")), - }) # must use symbolic_trace - interp = MetaInfoProp(meta_graph) - interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")) - codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory) - chunk_infos = codegen.chunk_infos - assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len) - - gpc.destroy() - - -@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO), - reason="torch version is lower than 1.12.0") -@pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) -@pytest.mark.parametrize("msa_len", [32]) -@pytest.mark.parametrize("pair_len", [64]) -def test_simple_evoformer_search(msa_len, pair_len, max_memory): - run_func = partial( - _test_simple_evoformer_search, - msa_len=msa_len, - pair_len=pair_len, - max_memory=max_memory, - ) - mp.spawn(run_func, nprocs=1) - - -if __name__ == "__main__": - _test_simple_evoformer_search(0, 32, 64, 20) diff --git a/tests/test_autochunk/test_transformer/test_autochunk_gpt.py b/tests/test_autochunk/test_transformer/test_autochunk_gpt.py new file mode 100644 index 000000000..0ba8f89c2 --- /dev/null +++ b/tests/test_autochunk/test_transformer/test_autochunk_gpt.py @@ -0,0 +1,65 @@ +from functools import partial +from typing import List, Tuple + +import pytest +import torch +import torch.multiprocessing as mp + +try: + from transformers import GPT2Config, GPT2Model + MODELS = [GPT2Model] + HAS_REPO = True +except: + 
MODELS = [] + HAS_REPO = False + +from test_transformer_utils import run_test + +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE + +BATCH_SIZE = 2 +SEQ_LENGTH = 256 + + +def get_data(shape: tuple) -> Tuple[List, List]: + input_ids = torch.zeros(shape, dtype=torch.int64) + token_type_ids = torch.zeros(shape, dtype=torch.int64) + attention_mask = torch.ones(shape, dtype=torch.int64) + meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + concrete_args = {"past_key_values": None} + sequence = ["input_ids", "past_key_values", "attention_mask", "token_type_ids"] + return meta_args, concrete_args, sequence + + +@pytest.mark.skipif( + not (AUTOCHUNK_AVAILABLE and HAS_REPO), + reason="torch version is lower than 1.12.0", +) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)]) +@pytest.mark.parametrize("max_memory", [None, 4.5, 5]) +def test_gpt(model, shape, max_memory): + run_func = partial( + run_test, + data=get_data(shape), + max_memory=max_memory, + model=model, + config=GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4), + print_code=False, + print_mem=False, + print_progress=False, + ) + mp.spawn(run_func, nprocs=1) + + +if __name__ == "__main__": + run_test( + rank=0, + data=get_data((BATCH_SIZE, SEQ_LENGTH)), + max_memory=None, + model=GPT2Model, + config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4), + print_code=True, + print_mem=True, + print_progress=False, + ) diff --git a/tests/test_autochunk/test_transformer/test_transformer_utils.py b/tests/test_autochunk/test_transformer/test_transformer_utils.py new file mode 100644 index 000000000..d33fc04c5 --- /dev/null +++ b/tests/test_autochunk/test_transformer/test_transformer_utils.py @@ -0,0 +1,123 @@ +from typing import Any, Dict, List + +import torch +import torch.fx + +import colossalai +from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.core import global_context as gpc +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.utils import free_port + +if AUTOCHUNK_AVAILABLE: + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + from colossalai.fx.profiler import MetaTensor + from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace + + +def assert_codegen_run( + model: Any, + data: tuple, + max_memory: int = None, + print_mem: bool = False, + print_progress: bool = False, + print_code: bool = False, +) -> List[Dict]: + meta_args, concrete_args, sequence = data + if concrete_args is None: + concrete_args = {} + + # trace the meta graph and setup codegen + meta_graph = symbolic_trace( + model, + meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()}, + concrete_args={k: v for k, v in concrete_args.items()}, + ) + interp = MetaInfoProp(meta_graph) + meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] + meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors] + interp.propagate(*meta_tensors) + codegen = AutoChunkCodeGen( + meta_graph, + max_memory=max_memory, + print_mem=print_mem, + print_progress=print_progress, + ) + chunks = codegen.chunk_infos + + # trace and recompile + # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer + graph = ColoTracer().trace( + model.cuda(), + meta_args={k: 
v.to(torch.device("meta")) for k, v in meta_args.items()}, + concrete_args={k: v for k, v in concrete_args.items()}, + ) + graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph, ckpt_codegen=False) + gm.recompile() + + # assert chunk in code + code = graph.python_code("self").src + if print_code: + print(code) + assert "chunk_result = None; chunk_size = None;" in code + + # assert result + inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] + inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] + model.cuda().eval() + gm.eval() + with torch.no_grad(): + out_gm = gm(*inputs) + out_model = model(*inputs) + for k in out_model.keys(): + if torch.is_tensor(out_gm[k]): + assert torch.equal( + out_model[k], out_gm[k] + ), f'{model.__class__.__name__} has incorrect output {k}, expect {out_model[k]}, but got {out_gm[k]}' + + return chunks + + +def run_test( + rank: int, + model: Any, + config: Any, + data: tuple, + max_memory: int, + print_code: bool, + print_mem: bool, + print_progress: bool, + get_chunk_target: Any = None, +) -> None: + model = model(config=config) + # launch colossalai + colossalai.launch( + config={}, + rank=rank, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) + + # build model and input + chunks = assert_codegen_run( + model, + data=data, + max_memory=max_memory, + print_code=print_code, + print_mem=print_mem, + print_progress=print_progress, + ) + + if get_chunk_target is not None: + chunk_found = [i["region"] for i in chunks] + chunk_target = get_chunk_target()[max_memory] + assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % ( + str(chunk_found), + str(chunk_target), + ) + + gpc.destroy()