diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py index 7f2aac42b..9d83f1057 100644 --- a/colossalai/auto_parallel/passes/runtime_apply_pass.py +++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py @@ -128,6 +128,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule): runtime_apply, args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index)) + if 'activation_checkpoint' in user_node.meta: + shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint'] new_args = list(user_node.args) new_kwargs = dict(user_node.kwargs) @@ -208,6 +210,37 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): # substitute the origin node with comm_spec_apply_node new_kwargs[str(node)] = comm_spec_apply_node user.kwargs = new_kwargs + + if 'activation_checkpoint' in node.meta: + comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint'] + + return gm + + +def _act_annotataion_pass(gm: torch.fx.GraphModule): + """ + This pass is used to add the act annotation to the new inserted nodes. + """ + mod_graph = gm.graph + nodes = tuple(mod_graph.nodes) + + for node in nodes: + if not hasattr(node.meta, 'activation_checkpoint'): + from .runtime_preparation_pass import size_processing + + user_act_annotation = -1 + input_act_annotation = -1 + for user_node in node.users.keys(): + if 'activation_checkpoint' in user_node.meta: + user_act_annotation = user_node.meta['activation_checkpoint'] + break + for input_node in node._input_nodes.keys(): + if 'activation_checkpoint' in input_node.meta: + input_act_annotation = input_node.meta['activation_checkpoint'] + break + if user_act_annotation == input_act_annotation and user_act_annotation != -1: + node.meta['activation_checkpoint'] = user_act_annotation + return gm diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index f9b890263..1c25e4c94 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -179,6 +179,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): # It will be used to replace the original node with processing node in slice object node_pairs[node] = size_processing_node size_processing_node._meta_data = node._meta_data + if 'activation_checkpoint' in node.meta: + size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint'] user_list = list(node.users.keys()) for user in user_list: diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py index 8c24c0d7b..387a682a1 100644 --- a/colossalai/auto_parallel/tensor_shard/initialize.py +++ b/colossalai/auto_parallel/tensor_shard/initialize.py @@ -18,6 +18,7 @@ from colossalai.auto_parallel.tensor_shard.solver import ( ) from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec @@ -28,7 +29,7 @@ class ModuleWrapper(nn.Module): into the forward function. 
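The runtime passes above share one pattern: whenever a helper node (shape-consistency apply, comm-spec apply, size processing) is inserted into the traced graph, the `activation_checkpoint` tag of its neighbours is copied onto it so that checkpoint regions survive the rewrite. A minimal sketch of that propagation idea on a plain `torch.fx` graph follows; the helper name `propagate_ckpt_meta` is made up for illustration, and only the inheritance rule (inherit the tag when a user node and an input node agree on it) is taken from `_act_annotataion_pass`.

import torch.fx


def propagate_ckpt_meta(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Sketch: a newly inserted node that sits between two nodes carrying the
    # same 'activation_checkpoint' tag inherits that tag.
    for node in gm.graph.nodes:
        if 'activation_checkpoint' in node.meta:
            continue
        user_tag = next((u.meta['activation_checkpoint']
                         for u in node.users if 'activation_checkpoint' in u.meta), None)
        input_tag = next((a.meta['activation_checkpoint']
                          for a in node.all_input_nodes if 'activation_checkpoint' in a.meta), None)
        if user_tag is not None and user_tag == input_tag:
            node.meta['activation_checkpoint'] = user_tag
    return gm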
''' - def __init__(self, module: GraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]], + def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]], origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]): ''' Args: @@ -81,7 +82,7 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh): return strategies_constructor -def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0): +def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0): ''' This method is used to solve the best solution for the given graph. The solution is a list of integers, each integer represents the best strategy index of the corresponding node. @@ -97,7 +98,7 @@ def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, return solution -def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh: DeviceMesh, +def transform_to_sharded_model(gm: ColoGraphModule, solution: List[int], device_mesh: DeviceMesh, strategies_constructor: StrategiesConstructor): ''' This method is used to transform the original graph to the sharded graph. @@ -197,10 +198,10 @@ def initialize_model(model: nn.Module, solution will be used to debug or help to analyze the sharding result. Therefore, we will not just return a series of integers, but return the best strategies. ''' - tracer = ColoTracer() + tracer = ColoTracer(trace_act_ckpt=True) graph = tracer.trace(root=model, meta_args=meta_args) - gm = GraphModule(model, graph, model.__class__.__name__) + gm = ColoGraphModule(model, graph, model.__class__.__name__) gm.recompile() strategies_constructor = build_strategy_constructor(graph, device_mesh) if load_solver_solution: diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index e8af9bde8..ceccb9a9f 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -48,9 +48,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> return new_shape -def _gen_loop_start( - chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2 -) -> str: +def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2) -> str: """ Generate chunk loop start @@ -72,9 +70,8 @@ def _gen_loop_start( out_shape = get_node_shape(chunk_output) out_str = str(list(out_shape)) context = ( - "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" - % (out_str, input_node.name, input_node.name, chunk_size) - ) + "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" % + (out_str, input_node.name, input_node.name, chunk_size)) context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim]) return context @@ -105,26 +102,17 @@ def _gen_loop_end( chunk_outputs_name = chunk_outputs.name chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list) chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape - chunk_slice = _gen_chunk_slice_dim( - chunk_outputs_dim, "chunk_idx", chunk_output_shape - ) + chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape) context = " chunk_result%s = %s; %s = None\n" % ( chunk_slice, chunk_outputs_name, chunk_outputs_name, ) - context += ( - chunk_outputs_name + " = chunk_result; 
chunk_result = None; chunk_size = None" - ) + context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None") # determine if its the last use for chunk input for chunk_input in chunk_inputs + chunk_non_compute_inputs: - if all( - [ - find_idx_by_name(user.name, node_list) <= chunk_outputs_idx - for user in chunk_input.users.keys() - ] - ): + if all([find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in chunk_input.users.keys()]): context += "; %s = None" % chunk_input.name context += "\n" @@ -171,17 +159,10 @@ def _replace_ones_like( chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"] if get_node_shape(meta_node)[chunk_dim] != 1: source_node = meta_node.args[0].args[0] - if ( - source_node not in chunk_infos[region_idx]["node_chunk_dim"] - or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] - is None - ): - chunk_slice = _gen_chunk_slice_dim( - chunk_dim, "chunk_idx", get_node_shape(node) - ) - body[-1] = _replace_name( - body[-1], node.args[0].name, node.args[0].name + chunk_slice - ) + if (source_node not in chunk_infos[region_idx]["node_chunk_dim"] + or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None): + chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node)) + body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice) return body @@ -198,12 +179,8 @@ def _replace_input_node( for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): if idx == node_idx: - chunk_slice = _gen_chunk_slice_dim( - dim[0], "chunk_idx", get_node_shape(input_node) - ) - body[-1] = _replace_name( - body[-1], input_node.name, input_node.name + chunk_slice - ) + chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(input_node)) + body[-1] = _replace_name(body[-1], input_node.name, input_node.name + chunk_slice) return body @@ -236,14 +213,10 @@ def emit_code_with_chunk( chunk_ends = [i["region"][1] for i in chunk_infos] # chunk inputs - chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk - chunk_inputs_non_chunk = [ - i["inputs_non_chunk"] for i in chunk_infos - ] # input without chunk - chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim - chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ - j.name for i in chunk_inputs_non_chunk for j in i - ] + chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk + chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim + chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i] # chunk outputs chunk_outputs = [i["outputs"][0] for i in chunk_infos] @@ -267,23 +240,16 @@ def emit_code_with_chunk( chunk_outputs[region_idx], chunk_outputs_dim[region_idx], chunk_infos[region_idx]["chunk_size"], - ) - ) + )) if within_chunk_region: emit_node_func(node, body) # replace input var with chunk var - body = _replace_input_node( - chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body - ) + body = _replace_input_node(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body) # ones like - body = _replace_ones_like( - search_chunk, chunk_infos, region_idx, node_idx, node, body - ) + body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body) # reassgin 
reshape size - body[-1] = _replace_reshape_size( - body[-1], node.name, chunk_infos[region_idx]["reshape_size"] - ) + body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"]) body[-1] = " " + body[-1] delete_unused_value_func(node, body, chunk_inputs_names) else: @@ -300,8 +266,7 @@ def emit_code_with_chunk( chunk_outputs[region_idx], chunk_outputs_dim[region_idx], node_list, - ) - ) + )) within_chunk_region = False node_idx += 1 @@ -310,18 +275,14 @@ def emit_code_with_chunk( if CODEGEN_AVAILABLE: class AutoChunkCodeGen(CodeGen): + def __init__(self, meta_graph, max_memory=None, print_mem=False): super().__init__() - self.meta_graph = meta_graph - self.max_memory = max_memory - self.meta_node = list(meta_graph.graph.nodes) # find the chunk regions self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem) self.chunk_infos = self.search_chunk.search_region() - def _gen_python_code( - self, nodes, root_module: str, namespace: _Namespace - ) -> PythonCode: + def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] globals_: Dict[str, Any] = {} @@ -338,9 +299,7 @@ if CODEGEN_AVAILABLE: Returns: the global name that should be used to reference 'obj' in generated source. """ - if ( - _is_from_torch(obj) and obj != torch.device - ): # to support registering torch.device + if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -356,9 +315,7 @@ if CODEGEN_AVAILABLE: return global_name # set _custom_builtins here so that we needn't import colossalai in forward - _custom_builtins["colossalai"] = _CustomBuiltin( - "import colossalai", colossalai - ) + _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) # Pre-fill the globals table with registered builtins. for name, (_, obj) in _custom_builtins.items(): @@ -394,9 +351,8 @@ if CODEGEN_AVAILABLE: # Common case: this is a regular module name like 'foo.bar.baz' return add_global(typename, o) - def _format_args( - args: Tuple[Argument, ...], kwargs: Dict[str, Argument] - ) -> str: + def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: + def _get_repr(arg): # Handle NamedTuples (if it has `_fields`) via add_global. 
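For reference, the strings assembled by `_gen_loop_start` and `_gen_loop_end` above expand into a loop of roughly the following shape. This hand-written sketch only illustrates the structure of the generated code; the tensor names, the matmul/ReLU body, and chunking along dim 0 are assumptions made for the example. The pattern is: preallocate the output, step over the chunked dimension in strides of `chunk_size`, run the region on a slice, write the partial result into `chunk_result`, and drop temporaries so they can be freed inside the loop.

import torch


def chunked_region(x: torch.Tensor, weight: torch.Tensor, chunk_size: int = 2) -> torch.Tensor:
    # preallocate the full output, as _gen_loop_start does
    chunk_result = torch.empty(x.shape[0], weight.shape[1], dtype=x.dtype, device=x.device)
    for chunk_idx in range(0, x.shape[0], chunk_size):
        x_slice = x[chunk_idx:chunk_idx + chunk_size]    # chunked input slice
        out_slice = torch.relu(x_slice @ weight)         # body of the chunked region (illustrative)
        chunk_result[chunk_idx:chunk_idx + chunk_size] = out_slice
        x_slice = out_slice = None                       # release temporaries, as _gen_loop_end emits
    return chunk_result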
if isinstance(arg, tuple) and hasattr(arg, "_fields"): @@ -444,26 +400,18 @@ if CODEGEN_AVAILABLE: nodes_to_delete = user_to_last_uses.get(user, []) nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep] if len(nodes_to_delete): - to_delete_str = " = ".join( - [repr(n) for n in nodes_to_delete] + ["None"] - ) + to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"]) body.append(f"; {to_delete_str}\n") else: body.append("\n") # NOTE: we add a variable to distinguish body and ckpt_func def emit_node(node: Node, body): - maybe_type_annotation = ( - "" if node.type is None else f" : {type_repr(node.type)}" - ) + maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}") if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = ( - "" if not node.args else f" = {repr(node.args[0])}" - ) - free_vars.append( - f"{node.target}{maybe_type_annotation}{maybe_default_arg}" - ) + maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}") + free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") raw_name = node.target.replace("*", "") if raw_name != repr(node): body.append(f"{repr(node)} = {raw_name}\n") @@ -472,68 +420,46 @@ if CODEGEN_AVAILABLE: assert isinstance(node.target, str) body.append( f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" - f"({_format_args(node.args[1:], node.kwargs)})" - ) + f"({_format_args(node.args[1:], node.kwargs)})") return elif node.op == "call_function": assert callable(node.target) # pretty print operators - if ( - node.target.__module__ == "_operator" - and node.target.__name__ in magic_methods - ): + if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods): assert isinstance(node.args, tuple) - body.append( - f"{repr(node)}{maybe_type_annotation} = " - f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" - ) + body.append(f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}") return # pretty print inplace operators; required for jit.script to work properly # not currently supported in normal FX graphs, but generated by torchdynamo - if ( - node.target.__module__ == "_operator" - and node.target.__name__ in inplace_methods - ): - body.append( - f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " - f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" - ) + if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods): + body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}") return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if ( - global_name == "getattr" - and isinstance(node.args, tuple) - and isinstance(node.args[1], str) - and node.args[1].isidentifier() - and len(node.args) == 2 - ): + if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str) + and node.args[1].isidentifier() and len(node.args) == 2): body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" - ) + 
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}") return body.append( - f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" - ) + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})") if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return elif node.op == "call_module": assert isinstance(node.target, str) - body.append( - f"{repr(node)}{maybe_type_annotation} = " - f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" - ) + body.append(f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})") return elif node.op == "get_attr": assert isinstance(node.target, str) - body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" - ) + body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}") return elif node.op == "output": if node.type is not None: @@ -564,9 +490,7 @@ if CODEGEN_AVAILABLE: if len(wrapped_fns) > 0: wrap_name = add_global("wrap", torch.fx.wrap) - wrap_stmts = "\n".join( - [f'{wrap_name}("{name}")' for name in wrapped_fns] - ) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) else: wrap_stmts = "" diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py index 1e2e6dc12..ec1e012be 100644 --- a/colossalai/autochunk/trace_flow.py +++ b/colossalai/autochunk/trace_flow.py @@ -10,6 +10,7 @@ from .utils import ( class TraceFlow(object): + def __init__(self, trace_indice: TraceIndice) -> None: self.trace_indice = trace_indice @@ -28,9 +29,7 @@ class TraceFlow(object): start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list) end_node_trace = self.trace_indice._find_trace_from_node(end_node) end_node_trace_source = end_node_trace["source"][end_dim] - sorted_source = sorted( - end_node_trace_source.items(), key=lambda d: d[0], reverse=True - ) + sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True) for node_idx, node_dim in sorted_source: if node_idx == start_node_idx and start_dim in node_dim: return True @@ -70,10 +69,8 @@ class TraceFlow(object): input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list) node_trace_source = self.trace_indice._find_source_trace_from_node(node) for node_dim in range(len(get_node_shape(node))): - if ( - input_node_idx in node_trace_source[node_dim] - and input_dim[0] in node_trace_source[node_dim][input_node_idx] - ): + if (input_node_idx in node_trace_source[node_dim] + and input_dim[0] in node_trace_source[node_dim][input_node_idx]): return node_dim return None @@ -81,15 +78,11 @@ class TraceFlow(object): input_dim_after_node = {} for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): - inherit_dim = self._find_inherit_dim( - input_node, v, self.trace_indice.node_list[k] - ) + inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k]) if inherit_dim: input_dim_after_node[k] = inherit_dim - for node in self.trace_indice.node_list[ - chunk_infos["region"][0] : chunk_infos["region"][1] + 1 - ]: + for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]: if is_non_compute_node_except_placeholder(node): continue count = 0 @@ -159,9 +152,7 @@ class 
TraceFlow(object): if arg_node in all_node_info: if all_node_info[arg_node]["chunk_dim"] != arg_dim: return False - all_node_info[arg_node]["fix_dim"] = list( - set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) - ) + all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)) # else add it to list else: all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim} @@ -170,9 +161,7 @@ class TraceFlow(object): return True def _get_all_node_info(self, end_dim, start_idx, end_idx): - cur_node_list = [ - self.trace_indice.node_list[end_idx] - ] # start from the last node + cur_node_list = [self.trace_indice.node_list[end_idx]] # start from the last node all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} while len(cur_node_list) > 0: @@ -183,12 +172,8 @@ class TraceFlow(object): cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] if cur_node_chunk_dim: - cur_node_compute = self.trace_indice._find_compute_trace_from_node( - cur_node - ) - cur_node_source = self.trace_indice._find_source_trace_from_node( - cur_node - ) + cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node) + cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node) else: cur_node_compute = cur_node_source = None @@ -215,15 +200,9 @@ class TraceFlow(object): return None if len(arg_list) == 2: - if any(i in cur_node.name for i in ["add", "mul"]): + if any(i in cur_node.name for i in ["add", "mul", "truediv"]): for arg in arg_list: - if not ( - start_idx - <= find_idx_by_name( - arg.name, self.trace_indice.node_list - ) - < end_idx - ): + if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx): continue arg_chunk_dim = all_node_info[arg]["chunk_dim"] arg_fix_dim = all_node_info[arg]["fix_dim"] @@ -249,9 +228,7 @@ class TraceFlow(object): remove_inputs = [] for input_node in inputs: input_dict = {} - input_node_idx = find_idx_by_name( - input_node.name, self.trace_indice.node_list - ) + input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list) for user in input_node.users.keys(): if is_non_compute_node(user): continue @@ -259,9 +236,7 @@ class TraceFlow(object): if start_idx <= user_idx <= end_idx: chunk_dim = all_node_info[user]["chunk_dim"] if chunk_dim is not None: - user_source = self.trace_indice._find_source_trace_from_node( - user - )[chunk_dim] + user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim] if input_node_idx in user_source: input_dict[user_idx] = user_source[input_node_idx] else: @@ -284,7 +259,7 @@ class TraceFlow(object): maybe_prepose_nodes.sort( key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list), reverse=True, - ) # from last node to first node + ) # from last node to first node prepose_nodes = [] # set every node as root, search its args, if all legal, turn root and args as prepose nodes while len(maybe_prepose_nodes) > 0: @@ -305,13 +280,8 @@ class TraceFlow(object): if type(cur_prepose_node_arg) != type(cur_prepose_node): continue # out of loop - if not ( - start_idx - <= find_idx_by_name( - cur_prepose_node_arg.name, self.trace_indice.node_list - ) - < end_idx - ): + if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) < + end_idx): continue # compute op in loop elif cur_prepose_node_arg in all_node_info: @@ -335,15 +305,13 @@ class TraceFlow(object): if n in maybe_prepose_nodes: 
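`_get_prepose_nodes` above selects nodes that can be hoisted ahead of the generated chunk loop: a candidate is preposed only if every argument it needs is defined outside the chunk region, is a non-compute node, or is itself preposed. A simplified stand-in for that selection is sketched below, under the assumption that plain Python containers replace the trace bookkeeping; `collect_prepose_nodes` is a hypothetical helper, not part of `TraceFlow`.

from typing import List, Set

from torch.fx import Node


def collect_prepose_nodes(region: List[Node], chunked: Set[Node]) -> List[Node]:
    # Simplified stand-in for TraceFlow._get_prepose_nodes: a node may move ahead
    # of the chunk loop if none of its in-region inputs depend on a chunked value.
    prepose: List[Node] = []
    for node in region:
        if node in chunked:
            continue
        in_region_deps = [a for a in node.all_input_nodes if a in region]
        if all(dep in prepose for dep in in_region_deps):
            prepose.append(node)
    return prepose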
maybe_prepose_nodes.remove(n) # sort by index - prepose_nodes.sort( - key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list) - ) + prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list)) return prepose_nodes def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): # we need to log input nodes to avoid deleteing them in the loop - chunk_node_list = self.trace_indice.node_list[start_idx : end_idx + 1] + chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1] # also need to get some prepose node's arg out of non_chunk_inputs for n in chunk_info["args"]["prepose_nodes"]: chunk_node_list.remove(n) @@ -354,9 +322,7 @@ class TraceFlow(object): return chunk_info def flow_search(self, start_idx, start_dim, end_idx, end_dim): - inputs, outputs = find_chunk_compute_input_and_output_nodes( - self.trace_indice.node_list[start_idx : end_idx + 1] - ) + inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1]) # only single ouput if len(outputs) > 1: return None @@ -367,9 +333,7 @@ class TraceFlow(object): return None # get input nodes' chunk dim - inputs, inputs_dim = self._get_input_nodes_dim( - inputs, start_idx, end_idx, all_node_info - ) + inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) if inputs is None: return None @@ -385,9 +349,7 @@ class TraceFlow(object): } # move useless nodes ahead of loop - chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes( - all_node_info, start_idx, end_idx - ) + chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx) # find non chunk inputs chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx) @@ -400,10 +362,8 @@ class TraceFlow(object): def _reassgin_reshape_size(self, chunk_info): chunk_region = chunk_info["region"] reshape_size = {} - chunk_shape = get_node_shape(chunk_info["outputs"][0])[ - chunk_info["outputs_dim"] - ] - for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]: + chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]] + for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]: if any(i in node.name for i in ["reshape", "view"]): reshape_args = node.args[1:] reshape_log = self.trace_indice.indice_view_list[node] @@ -413,8 +373,6 @@ class TraceFlow(object): if reshape_arg_dim in reshape_log["dim_to"]: continue if reshape_arg_dim == chunk_dim: - reshape_size[node.name][reshape_arg.name] = ( - "min(chunk_size, %d - chunk_idx)" % chunk_shape - ) + reshape_size[node.name][reshape_arg.name] = ("min(chunk_size, %d - chunk_idx)" % chunk_shape) chunk_info["reshape_size"] = reshape_size return chunk_info diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index 1e16ab9bd..5a5d15e0a 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -3,7 +3,7 @@ from typing import Dict, List, Tuple from torch.fx.node import Node -from .utils import find_idx_by_name, get_node_shape +from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list class TraceIndice(object): @@ -79,9 +79,7 @@ class TraceIndice(object): node_from_trace = self._find_trace_from_node(node_from) node_to_trace = self._find_trace_from_node(node_to) node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim] - node_to_trace["compute"][node_to_dim] = copy.deepcopy( - 
node_from_trace["compute"][node_from_dim] - ) + node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim]) self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True) def _inherit_all_computation(self, node_from, node_to): @@ -209,7 +207,7 @@ class TraceIndice(object): node_idx (int) """ if input_node == None: - input_node = node.args[0] + input_node = find_first_tensor_arg(node) input_node_idx = find_idx_by_name(input_node.name, self.node_list) input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"] @@ -227,6 +225,8 @@ class TraceIndice(object): node_idx (int) """ shape = node.meta["tensor_meta"].shape + if shape is None: + return new_trace = [] for _ in shape: new_trace.append(self._add_indice()) @@ -259,7 +259,7 @@ class TraceIndice(object): node (node) node_idx (int) """ - permute_dim = node.args[1:] + permute_dim = unflat_list(node.args[1:]) input_node = node.args[0] self._assign_indice_as_input(node, node_idx, input_node) @@ -359,6 +359,15 @@ class TraceIndice(object): left, right = patterns.split("->") left = left.split(",") + if '...' in right: + replace_list = "!@#$%^&*" + target_len = len(get_node_shape(node)) + add_len = target_len - len(right) + 3 + replace_str = replace_list[:add_len] + right = right.replace("...", replace_str) + for ll in range(len(left)): + left[ll] = left[ll].replace("...", replace_str) + all_index = [] for i in left: for c in i: @@ -369,9 +378,7 @@ class TraceIndice(object): for left_idx, left_str in enumerate(left): if right_indice in left_str: source_idx = left_str.index(right_indice) - self._inherit_indice( - input_nodes[left_idx], source_idx, node, right_idx - ) + self._inherit_indice(input_nodes[left_idx], source_idx, node, right_idx) def _assign_softmax_indice(self, node, idx): """ @@ -440,11 +447,12 @@ class TraceIndice(object): origin_node = node.args[0] origin_shape = origin_node.meta["tensor_meta"].shape target_shape = [] - for i in range(1, len(node.args)): - if isinstance(node.args[i], int): - target_shape.append(node.args[i]) + unflated_args = unflat_list(node.args) + for i in range(1, len(unflated_args)): + if isinstance(unflated_args[i], int): + target_shape.append(unflated_args[i]) else: - target_shape.append(node.args[i].meta["fwd_out"][0]) + target_shape.append(unflated_args[i].meta["fwd_out"][0]) # compute the value of -1 if -1 in target_shape: @@ -472,13 +480,7 @@ class TraceIndice(object): dim_to = [dim_equal.index(False), dim_equal.index(False) + 1] self._del_dim(node_idx, -1) else: - raise NotImplementedError( - "shape" - + str(origin_shape) - + "and" - + str(target_shape) - + "view not implemented" - ) + raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented") # get new indice origin_trace = self._find_indice_trace_from_node(origin_node) @@ -521,6 +523,8 @@ class TraceIndice(object): self._assign_unsqueeze_indice(node, idx) elif any(i in node.name for i in ["to", "contiguous"]): self._assgin_no_change_indice(node, idx) + elif "new_ones" in node.name: + self._assign_ones_like_indice(node, idx) else: raise NotImplementedError(node.name, "method not implemented yet!") elif node.op == "call_function": @@ -530,7 +534,7 @@ class TraceIndice(object): self._assign_matmul_indice(node, idx) elif "softmax" in node.name: self._assign_softmax_indice(node, idx) - elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]): + elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu", "sub", "truediv"]): 
self._assign_elementwise_indice(node, idx) elif "ones_like" in node.name: self._assign_ones_like_indice(node, idx) @@ -538,21 +542,21 @@ class TraceIndice(object): self._assign_dropout_indice(node, idx) elif "einsum" in node.name: self._assign_einsum_indice(node, idx) - elif "getattr" in node.name: - continue # get attr like shape - elif "getitem" in node.name: - continue # get item in list + elif "layer_norm" in node.name: + self._assign_layernorm_indice(node, idx) + elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]): + continue else: - raise NotImplementedError( - node.name, "function not implemented yet!" - ) + raise NotImplementedError(node.name, "function not implemented yet!") elif node.op == "call_module": if any(n in node.name for n in ["layernorm", "norm"]): self._assign_layernorm_indice(node, idx) + elif any(n in node.name for n in ["sigmoid", "dropout", "relu"]): + self._assign_elementwise_indice(node, idx) else: raise NotImplementedError(node.name, "module not implemented yet!") elif node.op == "get_attr": - self._assign_all_indice(node, idx) # get param + self._assign_all_indice(node, idx) # get param elif node.op == "output": continue else: diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py index b62a6600a..5f3ea3bf4 100644 --- a/colossalai/autochunk/utils.py +++ b/colossalai/autochunk/utils.py @@ -3,10 +3,32 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple from torch.fx.node import Node +def unflat_list(inputs): + """ + unflat a list by recursion + """ + res = [] + for i in inputs: + if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple): + res.extend(unflat_list(i)) + else: + res.append(i) + return res + + +def find_first_tensor_arg(node): + """ + Find the first input tensor arg for a node + """ + for arg in node.args: + if type(arg) == type(node): + return arg + raise RuntimeError() + + def is_non_compute_node(node): if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): + i in node.name for i in ["getitem", "getattr"]): return True return False @@ -18,17 +40,13 @@ def get_node_shape(node): def is_non_compute_node_except_placeholder(node): - if any(i in node.op for i in ["get_attr", "output"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): + if any(i in node.op for i in ["get_attr", "output"]) or any(i in node.name for i in ["getitem", "getattr"]): return True return False def is_non_compute_node_except_placeholder_output(node): - if any(i in node.op for i in ["get_attr"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): + if any(i in node.op for i in ["get_attr"]) or any(i in node.name for i in ["getitem", "getattr"]): return True return False @@ -74,22 +92,16 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]): # we treat that input node as the input of the checkpoint function for node in nodes: for input_node in node._input_nodes.keys(): - if ( - input_node not in nodes - and input_node not in input_nodes - and not is_non_compute_node_except_placeholder(input_node) - ): + if (input_node not in nodes and input_node not in input_nodes + and not is_non_compute_node_except_placeholder(input_node)): input_nodes.append(input_node) # if a node has a user node which is not in the node list # we treat that user node as the node receiving the current node output for node in nodes: for output_node in node.users.keys(): - if ( - output_node not in nodes - and node not in output_nodes - and 
not is_non_compute_node_except_placeholder_output(output_node) - ): + if (output_node not in nodes and node not in output_nodes + and not is_non_compute_node_except_placeholder_output(output_node)): output_nodes.append(node) return input_nodes, output_nodes diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py index 1c39dc247..6bd612ad2 100644 --- a/colossalai/fx/profiler/opcount.py +++ b/colossalai/fx/profiler/opcount.py @@ -249,6 +249,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): aten.sum.default, aten.sum.dim_IntList, aten.mean.dim, + aten.sub.Tensor, + aten.sub_.Tensor, # activation op aten.hardswish.default, @@ -313,7 +315,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): aten.where.self, aten.zero_.default, aten.zeros_like.default, - ] + aten.fill_.Scalar + ] # yapf: disable for op in zero_flop_aten: flop_mapping[op] = zero_flop_jit diff --git a/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py b/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py index 9e0c05d89..ec322a78b 100644 --- a/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py +++ b/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py @@ -7,7 +7,6 @@ class BucketStore(BaseStore): def __init__(self, torch_pg: ProcessGroup): super().__init__(torch_pg) - self._grads = dict() self._params = dict() self._num_elements_in_bucket = dict() @@ -19,25 +18,24 @@ class BucketStore(BaseStore): def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None): self._num_elements_in_bucket[reduce_rank] += num_elements - def add_grad(self, tensor, reduce_rank: int = None): - self._grads[reduce_rank].append(tensor) - def add_param(self, tensor, reduce_rank: int = None): self._params[reduce_rank].append(tensor) def reset(self): keys = [None] + list(range(self._world_size)) - self._grads = {rank: [] for rank in keys} self._params = {rank: [] for rank in keys} self._num_elements_in_bucket = {rank: 0 for rank in keys} def reset_by_rank(self, reduce_rank=None): - self._grads[reduce_rank] = [] self._params[reduce_rank] = [] self._num_elements_in_bucket[reduce_rank] = 0 def get_grad(self, reduce_rank: int = None): - return self._grads[reduce_rank] + param_list = self.get_param(reduce_rank) + for param in param_list: + # the param must have grad for reduction + assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced' + return [param.grad for param in param_list] def get_param(self, reduce_rank: int = None): return self._params[reduce_rank] diff --git a/colossalai/zero/sharded_optim/low_level_optim.py b/colossalai/zero/sharded_optim/low_level_optim.py index 38736d01a..f45b5e200 100644 --- a/colossalai/zero/sharded_optim/low_level_optim.py +++ b/colossalai/zero/sharded_optim/low_level_optim.py @@ -46,7 +46,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): reduce_bucket_size: int = 1024 * 1024, # communication communication_dtype: Optional[torch.dtype] = None, overlap_communication: bool = False, - partition_grad: bool = False, # stage 2 + partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload forced_dtype: Optional[torch.dtype] = None): @@ -248,9 +248,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0]) return params_per_rank - ########################################################### - # Backward Reduction Hook - ########################################################### + 
########################### + # Backward Reduction Hook # + ########################### + + def _grad_handler(self, param, grad, reduce_rank): + self._add_to_reduction_bucket(param, reduce_rank) + return grad def _attach_reduction_hook(self): # we iterate over the fp16 params @@ -268,53 +272,61 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): else: reduce_rank = None - def _define_and_attach(param, reduce_rank): - # get the AccumulateGrad object of the param itself - accum_grad_obj = get_grad_accumulate_object(param) - self._grad_store.add_accumulate_grad_object(accum_grad_obj) + param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank)) - reduction_func = partial(self._reduce_and_remove_grads_by_bucket, - param=param, - reduce_rank=reduce_rank) + def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank): + if self._overlap_communication: + torch.cuda.synchronize() + self._param_store.clear_grads_of_previous_reduced_params() + stream = self._comm_stream + else: + stream = torch.cuda.current_stream() - # define hook - # NOT IMPORTANT BUT GOOD TO KNOW: - # args here is not grad, but allow_unreacable and accumulate_grad - def reduce_grad_hook(*args): - reduction_func() + with torch.cuda.stream(stream): + flat = bucket.flatten() + reduce_global_rank = None + if reduce_rank is not None: + reduce_global_rank = self._dp_global_ranks[reduce_rank] + reduced_flat = reduce_tensor_dp_group(tensor=flat, + dtype=self._communication_dtype, + dst_local_rank=reduce_rank, + dst_global_rank=reduce_global_rank, + group=self._dp_torch_group) - accum_grad_obj.register_hook(reduce_grad_hook) + # update the reduced tensor + if reduce_rank is None or reduce_rank == self._local_rank: + bucket.unflatten_and_copy(reduced_flat) - _define_and_attach(param, reduce_rank) + def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank): + param_bucket = TensorBucket(size=bucket_size) - def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None): - param_size = param.numel() + for tensor in tensor_list: + param_bucket.add_to_bucket(tensor, allow_oversize=True) - # check if the bucket is full - # if full, will reduce the grads already in the bucket - # after reduction, the bucket will be empty - if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size: - self._reduce_grads_in_bucket(reduce_rank) + if param_bucket.is_full_or_oversized(): + self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank) + param_bucket.empty() - # the param must not be reduced to ensure correctness - is_param_reduced = self._param_store.is_param_reduced(param) - if is_param_reduced: - msg = f'Parameter of size ({param.size()}) has already been reduced, ' \ - + 'duplicate reduction will lead to arithmetic incorrectness' - raise RuntimeError(msg) + if not param_bucket.is_empty(): + self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank) - # the param must have grad for reduction - assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced' + def _reduce_grads(self, reduce_rank, grads, bucket_size): + grad_buckets_by_dtype = split_half_float_double(grads) - self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank) - self._bucket_store.add_grad(param.grad, reduce_rank) - self._bucket_store.add_param(param, reduce_rank) + for tensor_list in grad_buckets_by_dtype: + self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list, + bucket_size=bucket_size, + 
reduce_rank=reduce_rank) - def _reduce_grads_in_bucket(self, reduce_rank=None): + ####################### + # Reduction Functions # + ####################### + + def _run_reduction(self, reduce_rank=None): # reduce grads - self._reduce_grads_by_rank(reduce_rank=reduce_rank, - grads=self._bucket_store.get_grad(reduce_rank=reduce_rank), - bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank)) + self._reduce_grads(reduce_rank=reduce_rank, + grads=self._bucket_store.get_grad(reduce_rank=reduce_rank), + bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank)) # use communication stream if overlapping # communication with computation @@ -351,50 +363,24 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): self._bucket_store.reset_by_rank(reduce_rank) - def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size): - grad_buckets_by_dtype = split_half_float_double(grads) + def _add_to_reduction_bucket(self, param, reduce_rank=None): + param_size = param.numel() - for tensor_list in grad_buckets_by_dtype: - self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank) + # check if the bucket is full + # if full, will reduce the grads already in the bucket + # after reduction, the bucket will be empty + if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size: + self._run_reduction(reduce_rank) - ############################## - # Reduction Utility Function # - ############################## - def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank): - param_bucket = TensorBucket(size=bucket_size) + # the param must not be reduced to ensure correctness + is_param_reduced = self._param_store.is_param_reduced(param) + if is_param_reduced: + msg = f'Parameter of size ({param.size()}) has already been reduced, ' \ + + 'duplicate reduction will lead to arithmetic incorrectness' + raise RuntimeError(msg) - for tensor in tensor_list: - param_bucket.add_to_bucket(tensor, allow_oversize=True) - - if param_bucket.is_full_or_oversized(): - self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank) - param_bucket.empty() - - if not param_bucket.is_empty(): - self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank) - - def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank): - if self._overlap_communication: - torch.cuda.synchronize() - self._param_store.clear_grads_of_previous_reduced_params() - stream = self._comm_stream - else: - stream = torch.cuda.current_stream() - - with torch.cuda.stream(stream): - flat = bucket.flatten() - reduce_global_rank = None - if reduce_rank is not None: - reduce_global_rank = self._dp_global_ranks[reduce_rank] - reduced_flat = reduce_tensor_dp_group(tensor=flat, - dtype=self._communication_dtype, - dst_local_rank=reduce_rank, - dst_global_rank=reduce_global_rank, - group=self._dp_torch_group) - - # update the reduced tensor - if reduce_rank is None or reduce_rank == self._local_rank: - bucket.unflatten_and_copy(reduced_flat) + self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank) + self._bucket_store.add_param(param, reduce_rank) ################################ # torch.optim.Optimizer methods @@ -498,8 +484,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # broadcast the updated model weights handles = [] for group_id in range(self.num_param_groups): - for rank in range(self._world_size): - fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id) + for index in 
range(self._world_size): + rank = self._dp_global_ranks[index] + fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=index, group_id=group_id) handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True) handles.append(handle) @@ -585,11 +572,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): param_group = self._fp16_param_groups[group_id] for param in param_group: if param.grad is not None: - self._reduce_and_remove_grads_by_bucket(param) + self._add_to_reduction_bucket(param) # we need to reduce the gradients # left in the communication bucket - self._reduce_grads_in_bucket() + self._run_reduction() def _reduce_grad_stage2(self): # when partition_grads is True, reduction hooks @@ -597,4 +584,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # only need to reduce the gradients # left in the communication bucket for reduce_rank in range(self._world_size): - self._reduce_grads_in_bucket(reduce_rank) + self._run_reduction(reduce_rank) diff --git a/examples/language/gpt/experiments/pipeline_parallel/requirements.txt b/examples/language/gpt/experiments/pipeline_parallel/requirements.txt new file mode 100644 index 000000000..137a69e80 --- /dev/null +++ b/examples/language/gpt/experiments/pipeline_parallel/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 diff --git a/examples/language/gpt/gemini/requirements.txt b/examples/language/gpt/gemini/requirements.txt new file mode 100644 index 000000000..137a69e80 --- /dev/null +++ b/examples/language/gpt/gemini/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 diff --git a/examples/language/gpt/requirements.txt b/examples/language/gpt/requirements.txt index e1f131468..ef58bb76b 100644 --- a/examples/language/gpt/requirements.txt +++ b/examples/language/gpt/requirements.txt @@ -1 +1,2 @@ transformers >= 4.23 +colossalai diff --git a/examples/language/opt/requirements.txt b/examples/language/opt/requirements.txt new file mode 100644 index 000000000..137a69e80 --- /dev/null +++ b/examples/language/opt/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 diff --git a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py new file mode 100644 index 000000000..0b42722fe --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py @@ -0,0 +1,70 @@ +from functools import partial +from typing import Optional, Tuple, Union + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from transformers.pytorch_utils import Conv1D + +from colossalai.auto_parallel.tensor_shard.initialize import initialize_model +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.tracer import ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.utils import free_port + +HIDDEN_SIZE = 16 + + +class GPT2MLPWithCkpt(nn.Module): + + def __init__(self, intermediate_size, hidden_size): + super().__init__() + embed_dim = hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = 
torch.nn.ReLU() + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = checkpoint(self.c_proj, hidden_states) + hidden_states = self.act(hidden_states) + + return hidden_states + + +def check_act_ckpt(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE) + input_sample = { + 'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'), + } + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + gm = initialize_model(model, input_sample, device_mesh) + code = gm.module.graph.python_code('self').src + assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code + assert "view_3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, view_1, comm_actions_dict, use_reentrant=True)" in code + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_mlp_layer(): + world_size = 4 + run_func = partial(check_act_ckpt, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_mlp_layer() diff --git a/tests/test_autochunk/benchmark_autochunk.py b/tests/test_autochunk/benchmark_simple_evoformer.py similarity index 66% rename from tests/test_autochunk/benchmark_autochunk.py rename to tests/test_autochunk/benchmark_simple_evoformer.py index 6632ece61..8b5d8a8be 100644 --- a/tests/test_autochunk/benchmark_autochunk.py +++ b/tests/test_autochunk/benchmark_simple_evoformer.py @@ -2,14 +2,13 @@ import time import torch import torch.fx +from simple_evoformer import base_evoformer, openfold_evoformer from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import MetaTensor -from tests.test_autochunk.evoformer.evoformer import evoformer_base -from tests.test_autochunk.openfold.evoformer import EvoformerBlock def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None): @@ -34,10 +33,7 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=N time2 = time.time() new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 - print( - "%s: time %.4fs, mem %dMB" - % (title, (time2 - time1) / loop, new_max_mem - now_mem) - ) + print("%s: time %.4fs, mem %dMB" % (title, (time2 - time1) / loop, new_max_mem - now_mem)) def _build_autochunk(model, max_memory, node, pair): @@ -50,18 +46,14 @@ def _build_autochunk(model, max_memory, node, pair): }, ) - gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace + gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace interp = MetaInfoProp(gm_prop) - interp.propagate( - MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0") - ) + interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")) # now run it twice to get meta info in graph module, not necessary gm = torch.fx.GraphModule(model, graph) interp = MetaInfoProp(gm) - 
interp.propagate( - MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0") - ) + interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")) # set code_gen codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False) @@ -75,42 +67,22 @@ def _build_autochunk(model, max_memory, node, pair): return gm -def _build_openfold(): - model = EvoformerBlock( - c_m=256, - c_z=128, - c_hidden_msa_att=32, - c_hidden_opm=32, - c_hidden_mul=128, - c_hidden_pair_att=32, - no_heads_msa=8, - no_heads_pair=4, - transition_n=4, - msa_dropout=0.15, - pair_dropout=0.15, - inf=1e4, - eps=1e-4, - is_multimer=False, - ).cuda() - return model - - def benchmark_evoformer(): # init data and model - msa_len = 256 - pair_len = 512 + msa_len = 128 + pair_len = 256 node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() - model = evoformer_base().cuda() + model = base_evoformer().cuda() # build autochunk model # max_memory = 1000 # MB, fit memory mode - max_memory = None # min memory mode - autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair) + max_memory = None # min memory mode + autochunk = _build_autochunk(base_evoformer().cuda(), max_memory, node, pair) # build openfold chunk_size = 64 - openfold = _build_openfold() + openfold = openfold_evoformer().cuda() # benchmark _benchmark_evoformer(model, node, pair, "base") diff --git a/tests/test_autochunk/evoformer/evoformer.py b/tests/test_autochunk/evoformer/evoformer.py deleted file mode 100644 index cfd2bb2a2..000000000 --- a/tests/test_autochunk/evoformer/evoformer.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -import torch.nn as nn - -from .msa import MSAStack -from .ops import OutProductMean -from .triangle import PairStack - - -def print_memory(init_mem, text=None): - now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem - max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem - print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem)) - torch.cuda.reset_peak_memory_stats() - - -class EvoformerBlock(nn.Module): - - def __init__(self, d_node, d_pair): - super(EvoformerBlock, self).__init__() - - self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15) - self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32) - self.pair_stack = PairStack(d_pair=d_pair) - - def forward(self, node, pair): - node = self.msa_stack(node, pair) - pair = pair + self.communication(node) - pair = self.pair_stack(pair) - return node, pair - - -class Evoformer(nn.Module): - - def __init__(self, d_node, d_pair): - super(Evoformer, self).__init__() - - self.blocks = nn.ModuleList() - for _ in range(1): - self.blocks.append(EvoformerBlock(d_node, d_pair)) - - def forward(self, node, pair): - for b in self.blocks: - node, pair = b(node, pair) - return node, pair - - -def evoformer_tiny(): - return Evoformer(d_node=64, d_pair=32) - - -def evoformer_base(): - return Evoformer(d_node=256, d_pair=128) - - -def evoformer_large(): - return Evoformer(d_node=512, d_pair=256) - - -__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large'] diff --git a/tests/test_autochunk/evoformer/initializer.py b/tests/test_autochunk/evoformer/initializer.py deleted file mode 100755 index c6ce0659e..000000000 --- a/tests/test_autochunk/evoformer/initializer.py +++ /dev/null @@ -1,29 +0,0 @@ -import math - -import numpy as np -import torch.nn as nn - - -def glorot_uniform_af(x, gain=1.0): - """ - initialize 
tensors the same as xavier_initializer in PyTorch, but the dimensions are different: - In PyTorch: - [feature_out, feature_in, n_head ...] - In Jax: - [... n_head, feature_in, feature_out] - However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like: - [feature_in, n_head, feature_out] - - In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors - """ - fan_in, fan_out = x.shape[-2:] - if len(x.shape) > 2: - receptive_field_size = np.prod(x.shape[:-2]) - fan_in *= receptive_field_size - fan_out *= receptive_field_size - std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) - dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation - - nn.init.uniform_(x, -dev, dev) - - return x diff --git a/tests/test_autochunk/evoformer/kernel.py b/tests/test_autochunk/evoformer/kernel.py deleted file mode 100644 index 26ab5dc53..000000000 --- a/tests/test_autochunk/evoformer/kernel.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch -import torch.nn.functional as F - - -def bias_sigmod_ele(y, bias, z): - return torch.sigmoid(y + bias) * z - - -def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor, - residual: torch.Tensor, prob: float) -> torch.Tensor: - out = (x + bias) * F.dropout(dropmask, p=prob, training=False) - out = residual + out - return out - - -def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor, - dropout_mask: torch.Tensor, Z_raw: torch.Tensor, - prob: float) -> torch.Tensor: - return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b)) \ No newline at end of file diff --git a/tests/test_autochunk/evoformer/msa.py b/tests/test_autochunk/evoformer/msa.py deleted file mode 100644 index cac456638..000000000 --- a/tests/test_autochunk/evoformer/msa.py +++ /dev/null @@ -1,95 +0,0 @@ -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch.nn import LayerNorm - -from .kernel import bias_dropout_add -from .ops import SelfAttention, Transition - - -class MSARowAttentionWithPairBias(nn.Module): - - def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15): - super(MSARowAttentionWithPairBias, self).__init__() - self.d_node = d_node - self.d_pair = d_pair - self.c = c - self.n_head = n_head - self.p_drop = p_drop - - self.layernormM = LayerNorm(d_node) - self.layernormZ = LayerNorm(d_pair) - - _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]), - std=1.0 / math.sqrt(d_pair)) - self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True) - - self.attention = SelfAttention(qkv_dim=d_node, - c=c, - n_head=n_head, - out_dim=d_node, - gating=True, - last_bias_fuse=True) - - self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True) - - def forward(self, M_raw, Z): - ## Input projections - M = self.layernormM(M_raw) - Z = self.layernormZ(Z) - b = F.linear(Z, self.linear_b_weights) - b = b.permute(0, 3, 1, 2) - # b = rearrange(b, 'b q k h -> b h q k') - - M = self.attention(M, b) - dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype) - - return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop) - - -class MSAColumnAttention(nn.Module): - - def __init__(self, d_node, c=32, n_head=8): - super(MSAColumnAttention, self).__init__() - self.d_node = d_node - self.c = c - self.n_head = n_head - - self.layernormM = 
LayerNorm(d_node) - self.attention = SelfAttention(qkv_dim=d_node, - c=c, - n_head=n_head, - out_dim=d_node, - gating=True) - - def forward(self, M_raw): - M = M_raw.transpose(-2, -3) - M = self.layernormM(M) - - M = self.attention(M) - - M = M.transpose(-2, -3) - return M_raw + M - - -class MSAStack(nn.Module): - - def __init__(self, d_node, d_pair, p_drop=0.15): - super(MSAStack, self).__init__() - - self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node, - d_pair=d_pair, - p_drop=p_drop) - - self.MSAColumnAttention = MSAColumnAttention(d_node=d_node) - self.MSATransition = Transition(d=d_node) - - def forward(self, node, pair): - node = self.MSARowAttentionWithPairBias(node, pair) - node = self.MSAColumnAttention(node) - node = self.MSATransition(node) - - return node diff --git a/tests/test_autochunk/evoformer/ops.py b/tests/test_autochunk/evoformer/ops.py deleted file mode 100755 index a56057522..000000000 --- a/tests/test_autochunk/evoformer/ops.py +++ /dev/null @@ -1,176 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch.nn import LayerNorm - -from .initializer import glorot_uniform_af -from .kernel import bias_sigmod_ele - - -class DropoutRowwise(nn.Module): - - def __init__(self, p): - super(DropoutRowwise, self).__init__() - self.p = p - self.dropout = nn.Dropout(p=p) - - def forward(self, x): - dropout_mask = torch.ones_like(x[:, 0:1, :, :]) - dropout_mask = self.dropout(dropout_mask) - return dropout_mask * x - - -class DropoutColumnwise(nn.Module): - - def __init__(self, p): - super(DropoutColumnwise, self).__init__() - self.p = p - self.dropout = nn.Dropout(p=p) - - def forward(self, x): - dropout_mask = torch.ones_like(x[:, :, 0:1, :]) - dropout_mask = self.dropout(dropout_mask) - return dropout_mask * x - - -class Transition(nn.Module): - - def __init__(self, d, n=4): - super(Transition, self).__init__() - self.norm = LayerNorm(d) - self.linear1 = Linear(d, n * d, initializer='relu') - self.linear2 = Linear(n * d, d, initializer='zeros') - - def forward(self, src): - x = self.norm(src) - x = self.linear2(F.relu(self.linear1(x))) - return src + x - - -class OutProductMean(nn.Module): - - def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32): - super(OutProductMean, self).__init__() - - self.layernormM = LayerNorm(n_feat) - self.linear_a = Linear(n_feat, n_feat_proj) - self.linear_b = Linear(n_feat, n_feat_proj) - - self.o_linear = Linear(n_feat_proj * n_feat_proj, - n_feat_out, - initializer='zero', - use_bias=True) - - def forward(self, M): - M = self.layernormM(M) - left_act = self.linear_a(M) - right_act = self.linear_b(M) - - o = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous() - # O = rearrange(O, 'b i j d e -> b i j (d e)') - o = o.reshape(o.shape[0], o.shape[1], o.shape[2], -1) - Z = self.o_linear(o) - - return Z - - -class Linear(nn.Linear): - """ - A Linear layer with built-in nonstandard initializations. Called just - like torch.nn.Linear. - Implements the initializers in 1.11.4, plus some additional ones found - in the code. 
- """ - - def __init__( - self, - feature_in: int, - feature_out: int, - initializer: str = 'linear', - use_bias: bool = True, - bias_init: float = 0., - ): - super(Linear, self).__init__(feature_in, feature_out, bias=use_bias) - - self.use_bias = use_bias - if initializer == 'linear': - glorot_uniform_af(self.weight, gain=1.0) - elif initializer == 'relu': - glorot_uniform_af(self.weight, gain=2.0) - elif initializer == 'zeros': - nn.init.zeros_(self.weight) - if self.use_bias: - with torch.no_grad(): - self.bias.fill_(bias_init) - - -class SelfAttention(nn.Module): - """ - Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors - """ - - def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False): - super(SelfAttention, self).__init__() - self.qkv_dim = qkv_dim - self.c = c - self.n_head = n_head - self.out_dim = out_dim - self.gating = gating - self.last_bias_fuse = last_bias_fuse - - self.scaling = self.c**(-0.5) - - # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear') - self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) - - if gating: - self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,))) - self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False) - - self.o_linear = Linear(n_head * c, - out_dim, - initializer='zero', - use_bias=(not last_bias_fuse)) - - def forward(self, in_data, nonbatched_bias=None): - """ - :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim] - :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv] - :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv] - """ - - # qkv = self.to_qkv(in_data).chunk(3, dim=-1) - # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv) - - q = self.to_q(in_data) - k = self.to_k(in_data) - v = self.to_v(in_data) - - # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), - # [q, k, v]) - q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4), - [q, k, v]) - - q = q * self.scaling - - logits = torch.matmul(q, k.transpose(-1, -2)) - - if nonbatched_bias is not None: - logits += nonbatched_bias.unsqueeze(1) - weights = torch.softmax(logits, dim=-1) - # weights = softmax(logits) - - weighted_avg = torch.matmul(weights, v) - # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)') - weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4) - weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1) - - if self.gating: - gate_values = self.gating_linear(in_data) - weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg) - - output = self.o_linear(weighted_avg) - return output diff --git a/tests/test_autochunk/evoformer/triangle.py b/tests/test_autochunk/evoformer/triangle.py deleted file mode 100644 index f479469c3..000000000 --- a/tests/test_autochunk/evoformer/triangle.py +++ /dev/null @@ -1,192 +0,0 @@ -import math - -import torch -import torch.nn as nn -from torch.nn import LayerNorm - -from .kernel import bias_dropout_add, bias_ele_dropout_residual -from .ops import Linear, SelfAttention, Transition - - -def permute_final_dims(tensor, inds): - zero_index = -1 * len(inds) - first_inds = 
list(range(len(tensor.shape[:zero_index]))) - return tensor.permute(first_inds + [zero_index + i for i in inds]) - - -class TriangleMultiplicationOutgoing(nn.Module): - - def __init__(self, d_pair, p_drop, c=128): - super(TriangleMultiplicationOutgoing, self).__init__() - self.d_pair = d_pair - self.c = c - - self.layernorm1 = LayerNorm(d_pair) - self.left_projection = Linear(d_pair, c) - self.right_projection = Linear(d_pair, c) - self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - - self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) - self.layernorm2 = LayerNorm(c) - self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) - self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - self.p_drop = p_drop - - def forward(self, Z_raw): - Z = self.layernorm1(Z_raw) - left_proj_act = self.left_projection(Z) - right_proj_act = self.right_projection(Z) - - left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) - right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) - - g = torch.sigmoid(self.output_gate(Z)) - # p = torch.matmul( - # permute_final_dims(left_proj_act, (2, 0, 1)), - # permute_final_dims(right_proj_act, (2, 1, 0)), - # ) - # ab = permute_final_dims(p, (1, 2, 0)) - - ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act) - ab = self.output_projection(self.layernorm2(ab)) - dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) - return bias_ele_dropout_residual(ab, - self.output_bias, - g, - dropout_mask, - Z_raw, - prob=self.p_drop) - - -class TriangleMultiplicationIncoming(nn.Module): - - def __init__(self, d_pair, p_drop, c=128): - super(TriangleMultiplicationIncoming, self).__init__() - self.d_pair = d_pair - self.c = c - - self.layernorm1 = LayerNorm(d_pair) - self.left_projection = Linear(d_pair, c) - self.right_projection = Linear(d_pair, c) - self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.) - - self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) 
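
Side note (not part of the patch): the deleted TriangleMultiplicationOutgoing keeps the original permute-and-matmul formulation only as a comment and computes the update with a single einsum. A minimal sketch with toy tensor sizes, assuming only a local copy of the permute_final_dims helper defined in the deleted triangle.py, checks that the two formulations agree:

import torch


def permute_final_dims(tensor, inds):
    # local copy of the helper defined in the deleted triangle.py
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


b, i, j, k, d = 2, 4, 5, 3, 6    # toy sizes
left = torch.randn(b, i, k, d)   # stands in for left_proj_act
right = torch.randn(b, j, k, d)  # stands in for right_proj_act

# einsum form used in the deleted forward(): contract over the shared k index
ab_einsum = torch.einsum('bikd,bjkd->bijd', left, right)

# commented-out matmul form from the same file
p = torch.matmul(
    permute_final_dims(left, (2, 0, 1)),     # [b, d, i, k]
    permute_final_dims(right, (2, 1, 0)),    # [b, d, k, j]
)
ab_matmul = permute_final_dims(p, (1, 2, 0))     # [b, i, j, d]

assert torch.allclose(ab_einsum, ab_matmul, atol=1e-5)
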
- self.layernorm2 = LayerNorm(c) - self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False) - self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - self.p_drop = p_drop - - def forward(self, Z_raw): - Z = self.layernorm1(Z_raw) - left_proj_act = self.left_projection(Z) - right_proj_act = self.right_projection(Z) - - left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z)) - right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z)) - - g = torch.sigmoid(self.output_gate(Z)) - # p = torch.matmul( - # permute_final_dims(left_proj_act, (2, 1, 0)), - # permute_final_dims(right_proj_act, (2, 0, 1)), - # ) - # ab = permute_final_dims(p, (1, 2, 0)) - - ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act) - ab = self.output_projection(self.layernorm2(ab)) - dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) - return bias_ele_dropout_residual(ab, - self.output_bias, - g, - dropout_mask, - Z_raw, - prob=self.p_drop) - - -class TriangleAttentionStartingNode(nn.Module): - - def __init__(self, d_pair, p_drop, c=32, n_head=4): - super(TriangleAttentionStartingNode, self).__init__() - self.d_pair = d_pair - self.c = c - self.n_head = n_head - self.p_drop = p_drop - - self.layernorm1 = LayerNorm(d_pair) - _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), - std=1.0 / math.sqrt(d_pair)) - self.linear_b_weights = nn.parameter.Parameter(data=_init_weights) - self.attention = SelfAttention(qkv_dim=d_pair, - c=c, - n_head=n_head, - out_dim=d_pair, - gating=True, - last_bias_fuse=True) - - self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - def forward(self, Z_raw): - Z = self.layernorm1(Z_raw) - b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) - - Z = self.attention(Z, b) - - dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype) - return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) - - -class TriangleAttentionEndingNode(nn.Module): - - def __init__(self, d_pair, p_drop, c=32, n_head=4): - super(TriangleAttentionEndingNode, self).__init__() - self.d_pair = d_pair - self.c = c - self.n_head = n_head - self.p_drop = p_drop - - self.layernorm1 = LayerNorm(d_pair) - _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]), - std=1.0 / math.sqrt(d_pair)) - self.linear_b_weights = nn.parameter.Parameter(data=_init_weights) - self.attention = SelfAttention(qkv_dim=d_pair, - c=c, - n_head=n_head, - out_dim=d_pair, - gating=True, - last_bias_fuse=True) - - self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True) - - def forward(self, Z_raw): - Z = Z_raw.transpose(-2, -3) - Z = self.layernorm1(Z) - b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights) - - Z = self.attention(Z, b) - - Z = Z.transpose(-2, -3) - dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype) - return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop) - - -class PairStack(nn.Module): - - def __init__(self, d_pair, p_drop=0.25): - super(PairStack, self).__init__() - - self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop) - self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop) - self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop) - self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop) - 
self.PairTransition = Transition(d=d_pair) - - def forward(self, pair): - pair = self.TriangleMultiplicationOutgoing(pair) - pair = self.TriangleMultiplicationIncoming(pair) - pair = self.TriangleAttentionStartingNode(pair) - pair = self.TriangleAttentionEndingNode(pair) - pair = self.PairTransition(pair) - return pair diff --git a/tests/test_autochunk/openfold/checkpointing.py b/tests/test_autochunk/openfold/checkpointing.py deleted file mode 100644 index 83e77c638..000000000 --- a/tests/test_autochunk/openfold/checkpointing.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.utils.checkpoint -from typing import Any, Tuple, List, Callable, Optional - - -BLOCK_ARG = Any -BLOCK_ARGS = List[BLOCK_ARG] - - -def get_checkpoint_fn(): - checkpoint = torch.utils.checkpoint.checkpoint - - return checkpoint - - -@torch.jit.ignore -def checkpoint_blocks( - blocks: List[Callable], - args: BLOCK_ARGS, - blocks_per_ckpt: Optional[int], -) -> BLOCK_ARGS: - """ - Chunk a list of blocks and run each chunk with activation - checkpointing. We define a "block" as a callable whose only inputs are - the outputs of the previous block. - - Implements Subsection 1.11.8 - - Args: - blocks: - List of blocks - args: - Tuple of arguments for the first block. - blocks_per_ckpt: - Size of each chunk. A higher value corresponds to fewer - checkpoints, and trades memory for speed. If None, no checkpointing - is performed. - Returns: - The output of the final block - """ - def wrap(a): - return (a,) if type(a) is not tuple else a - - def exec(b, a): - for block in b: - a = wrap(block(*a)) - return a - - def chunker(s, e): - def exec_sliced(*a): - return exec(blocks[s:e], a) - - return exec_sliced - - # Avoids mishaps when the blocks take just one argument - args = wrap(args) - - if blocks_per_ckpt is None: - return exec(blocks, args) - elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks): - raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)") - - checkpoint = get_checkpoint_fn() - - for s in range(0, len(blocks), blocks_per_ckpt): - e = s + blocks_per_ckpt - args = checkpoint(chunker(s, e), *args) - args = wrap(args) - - return args diff --git a/tests/test_autochunk/openfold/dropout.py b/tests/test_autochunk/openfold/dropout.py deleted file mode 100644 index 651b9775e..000000000 --- a/tests/test_autochunk/openfold/dropout.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.nn as nn -from functools import partialmethod -from typing import Union, List - - -class Dropout(nn.Module): - """ - Implementation of dropout with the ability to share the dropout mask - along a particular dimension. - - If not in training mode, this module computes the identity function. - """ - - def __init__(self, r: float, batch_dim: Union[int, List[int]]): - """ - Args: - r: - Dropout rate - batch_dim: - Dimension(s) along which the dropout mask is shared - """ - super(Dropout, self).__init__() - - self.r = r - if type(batch_dim) == int: - batch_dim = [batch_dim] - self.batch_dim = batch_dim - self.dropout = nn.Dropout(self.r) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - Tensor to which dropout is applied. Can have any shape - compatible with self.batch_dim - """ - shape = list(x.shape) - if self.batch_dim is not None: - for bd in self.batch_dim: - shape[bd] = 1 - mask = x.new_ones(shape) - mask = self.dropout(mask) - x *= mask - return x - - -class DropoutRowwise(Dropout): - """ - Convenience class for rowwise dropout as described in subsection - 1.11.6. - """ - - __init__ = partialmethod(Dropout.__init__, batch_dim=-3) - - -class DropoutColumnwise(Dropout): - """ - Convenience class for columnwise dropout as described in subsection - 1.11.6. - """ - - __init__ = partialmethod(Dropout.__init__, batch_dim=-2) diff --git a/tests/test_autochunk/openfold/evoformer.py b/tests/test_autochunk/openfold/evoformer.py deleted file mode 100644 index b53ec1aa5..000000000 --- a/tests/test_autochunk/openfold/evoformer.py +++ /dev/null @@ -1,431 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import torch -import torch.nn as nn -from typing import Tuple, Optional -from functools import partial - -from .primitives import Linear, LayerNorm -from .dropout import DropoutRowwise, DropoutColumnwise -from .msa import ( - MSARowAttentionWithPairBias, - MSAColumnAttention, - MSAColumnGlobalAttention, -) -from .outer_product_mean import OuterProductMean -from .pair_transition import PairTransition -from .triangular_attention import ( - TriangleAttentionStartingNode, - TriangleAttentionEndingNode, -) -from .triangular_multiplicative_update import ( - TriangleMultiplicationOutgoing, - TriangleMultiplicationIncoming, -) -from .checkpointing import checkpoint_blocks, get_checkpoint_fn -from .tensor_utils import chunk_layer - - -class MSATransition(nn.Module): - """ - Feed-forward network applied to MSA activations after attention. 
- - Implements Algorithm 9 - """ - def __init__(self, c_m, n): - """ - Args: - c_m: - MSA channel dimension - n: - Factor multiplied to c_m to obtain the hidden channel - dimension - """ - super(MSATransition, self).__init__() - - self.c_m = c_m - self.n = n - - self.layer_norm = LayerNorm(self.c_m) - self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu") - self.relu = nn.ReLU() - self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final") - - def _transition(self, m, mask): - m = self.linear_1(m) - m = self.relu(m) - m = self.linear_2(m) * mask - return m - - @torch.jit.ignore - def _chunk(self, - m: torch.Tensor, - mask: torch.Tensor, - chunk_size: int, - ) -> torch.Tensor: - return chunk_layer( - self._transition, - {"m": m, "mask": mask}, - chunk_size=chunk_size, - no_batch_dims=len(m.shape[:-2]), - ) - - def forward( - self, - m: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None, - ) -> torch.Tensor: - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA activation - mask: - [*, N_seq, N_res, C_m] MSA mask - Returns: - m: - [*, N_seq, N_res, C_m] MSA activation update - """ - - # DISCREPANCY: DeepMind forgets to apply the MSA mask here. - if mask is None: - mask = m.new_ones(m.shape[:-1]) - - # [*, N_seq, N_res, 1] - mask = mask.unsqueeze(-1) - - m = self.layer_norm(m) - - if chunk_size is not None: - m = self._chunk(m, mask, chunk_size) - else: - m = self._transition(m, mask) - - return m - - -class EvoformerBlockCore(nn.Module): - def __init__( - self, - c_m: int, - c_z: int, - c_hidden_opm: int, - c_hidden_mul: int, - c_hidden_pair_att: int, - no_heads_msa: int, - no_heads_pair: int, - transition_n: int, - pair_dropout: float, - inf: float, - eps: float, - _is_extra_msa_stack: bool = False, - is_multimer: bool = False, - ): - super(EvoformerBlockCore, self).__init__() - self.is_multimer = is_multimer - self.msa_transition = MSATransition( - c_m=c_m, - n=transition_n, - ) - - self.outer_product_mean = OuterProductMean( - c_m, - c_z, - c_hidden_opm, - ) - - self.tri_mul_out = TriangleMultiplicationOutgoing( - c_z, - c_hidden_mul, - ) - self.tri_mul_in = TriangleMultiplicationIncoming( - c_z, - c_hidden_mul, - ) - - self.tri_att_start = TriangleAttentionStartingNode( - c_z, - c_hidden_pair_att, - no_heads_pair, - inf=inf, - ) - self.tri_att_end = TriangleAttentionEndingNode( - c_z, - c_hidden_pair_att, - no_heads_pair, - inf=inf, - ) - - self.pair_transition = PairTransition( - c_z, - transition_n, - ) - - self.ps_dropout_row_layer = DropoutRowwise(pair_dropout) - self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout) - - def forward( - self, - m: torch.Tensor, - z: torch.Tensor, - chunk_size: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # DeepMind doesn't mask these transitions in the source, so _mask_trans - # should be disabled to better approximate the exact activations of - # the original. 
- - m = m + self.msa_transition( - m, chunk_size=chunk_size - ) - z = z + self.outer_product_mean( - m, chunk_size=chunk_size - ) - z = z + self.ps_dropout_row_layer(self.tri_mul_out(z)) - z = z + self.ps_dropout_row_layer(self.tri_mul_in(z)) - z = z + self.ps_dropout_row_layer( - self.tri_att_start(z, chunk_size=chunk_size) - ) - z = z + self.ps_dropout_col_layer( - self.tri_att_end(z, chunk_size=chunk_size) - ) - z = z + self.pair_transition( - z, chunk_size=chunk_size - ) - - return m, z - - -class EvoformerBlock(nn.Module): - def __init__(self, - c_m: int, - c_z: int, - c_hidden_msa_att: int, - c_hidden_opm: int, - c_hidden_mul: int, - c_hidden_pair_att: int, - no_heads_msa: int, - no_heads_pair: int, - transition_n: int, - msa_dropout: float, - pair_dropout: float, - inf: float, - eps: float, - is_multimer: bool, - ): - super(EvoformerBlock, self).__init__() - - self.msa_att_row = MSARowAttentionWithPairBias( - c_m=c_m, - c_z=c_z, - c_hidden=c_hidden_msa_att, - no_heads=no_heads_msa, - inf=inf, - ) - - self.msa_att_col = MSAColumnAttention( - c_m, - c_hidden_msa_att, - no_heads_msa, - inf=inf, - ) - - self.msa_dropout_layer = DropoutRowwise(msa_dropout) - - self.core = EvoformerBlockCore( - c_m=c_m, - c_z=c_z, - c_hidden_opm=c_hidden_opm, - c_hidden_mul=c_hidden_mul, - c_hidden_pair_att=c_hidden_pair_att, - no_heads_msa=no_heads_msa, - no_heads_pair=no_heads_pair, - transition_n=transition_n, - pair_dropout=pair_dropout, - inf=inf, - eps=eps, - ) - - self.outer_product_mean = OuterProductMean( - c_m, - c_z, - c_hidden_opm, - ) - self.is_multimer = is_multimer - - def forward(self, - m: torch.Tensor, - z: torch.Tensor, - chunk_size: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - m = m + self.msa_dropout_layer( - self.msa_att_row(m, z=z, chunk_size=chunk_size) - ) - m = m + self.msa_att_col(m, chunk_size=chunk_size) - m, z = self.core( - m, - z, - chunk_size=chunk_size, - ) - - return m, z - - -class EvoformerStack(nn.Module): - """ - Main Evoformer trunk. - - Implements Algorithm 6. - """ - - def __init__( - self, - c_m: int, - c_z: int, - c_hidden_msa_att: int, - c_hidden_opm: int, - c_hidden_mul: int, - c_hidden_pair_att: int, - c_s: int, - no_heads_msa: int, - no_heads_pair: int, - no_blocks: int, - transition_n: int, - msa_dropout: float, - pair_dropout: float, - blocks_per_ckpt: int, - inf: float, - eps: float, - clear_cache_between_blocks: bool = False, - is_multimer: bool = False, - **kwargs, - ): - """ - Args: - c_m: - MSA channel dimension - c_z: - Pair channel dimension - c_hidden_msa_att: - Hidden dimension in MSA attention - c_hidden_opm: - Hidden dimension in outer product mean module - c_hidden_mul: - Hidden dimension in multiplicative updates - c_hidden_pair_att: - Hidden dimension in triangular attention - c_s: - Channel dimension of the output "single" embedding - no_heads_msa: - Number of heads used for MSA attention - no_heads_pair: - Number of heads used for pair attention - no_blocks: - Number of Evoformer blocks in the stack - transition_n: - Factor by which to multiply c_m to obtain the MSATransition - hidden dimension - msa_dropout: - Dropout rate for MSA activations - pair_dropout: - Dropout used for pair activations - blocks_per_ckpt: - Number of Evoformer blocks in each activation checkpoint - clear_cache_between_blocks: - Whether to clear CUDA's GPU memory cache between blocks of the - stack. 
Slows down each block but can reduce fragmentation - """ - super(EvoformerStack, self).__init__() - - self.blocks_per_ckpt = blocks_per_ckpt - self.clear_cache_between_blocks = clear_cache_between_blocks - - self.blocks = nn.ModuleList() - - for _ in range(no_blocks): - block = EvoformerBlock( - c_m=c_m, - c_z=c_z, - c_hidden_msa_att=c_hidden_msa_att, - c_hidden_opm=c_hidden_opm, - c_hidden_mul=c_hidden_mul, - c_hidden_pair_att=c_hidden_pair_att, - no_heads_msa=no_heads_msa, - no_heads_pair=no_heads_pair, - transition_n=transition_n, - msa_dropout=msa_dropout, - pair_dropout=pair_dropout, - inf=inf, - eps=eps, - is_multimer=is_multimer, - ) - self.blocks.append(block) - - self.linear = Linear(c_m, c_s) - - def forward(self, - m: torch.Tensor, - z: torch.Tensor, - msa_mask: torch.Tensor, - pair_mask: torch.Tensor, - chunk_size: int, - _mask_trans: bool = True, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA embedding - z: - [*, N_res, N_res, C_z] pair embedding - msa_mask: - [*, N_seq, N_res] MSA mask - pair_mask: - [*, N_res, N_res] pair mask - Returns: - m: - [*, N_seq, N_res, C_m] MSA embedding - z: - [*, N_res, N_res, C_z] pair embedding - s: - [*, N_res, C_s] single embedding (or None if extra MSA stack) - """ - blocks = [ - partial( - b, - msa_mask=msa_mask, - pair_mask=pair_mask, - chunk_size=chunk_size, - _mask_trans=_mask_trans, - ) - for b in self.blocks - ] - - if(self.clear_cache_between_blocks): - def block_with_cache_clear(block, *args): - torch.cuda.empty_cache() - return block(*args) - - blocks = [partial(block_with_cache_clear, b) for b in blocks] - - m, z = checkpoint_blocks( - blocks, - args=(m, z), - blocks_per_ckpt=self.blocks_per_ckpt if self.training else None, - ) - - s = self.linear(m[..., 0, :, :]) - - return m, z, s diff --git a/tests/test_autochunk/openfold/msa.py b/tests/test_autochunk/openfold/msa.py deleted file mode 100644 index 7c137286f..000000000 --- a/tests/test_autochunk/openfold/msa.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import torch -import torch.nn as nn -from typing import Optional, List, Tuple - -from .primitives import ( - Linear, - LayerNorm, - Attention, - GlobalAttention, - _attention_chunked_trainable, -) -from .checkpointing import get_checkpoint_fn -from .tensor_utils import ( - chunk_layer, - permute_final_dims, - flatten_final_dims, -) - - -class MSAAttention(nn.Module): - def __init__( - self, - c_in, - c_hidden, - no_heads, - pair_bias=False, - c_z=None, - inf=1e9, - ): - """ - Args: - c_in: - Input channel dimension - c_hidden: - Per-head hidden channel dimension - no_heads: - Number of attention heads - pair_bias: - Whether to use pair embedding bias - c_z: - Pair embedding channel dimension. 
Ignored unless pair_bias - is true - inf: - A large number to be used in computing the attention mask - """ - super(MSAAttention, self).__init__() - - self.c_in = c_in - self.c_hidden = c_hidden - self.no_heads = no_heads - self.pair_bias = pair_bias - self.c_z = c_z - self.inf = inf - - self.layer_norm_m = LayerNorm(self.c_in) - - self.layer_norm_z = None - self.linear_z = None - if self.pair_bias: - self.layer_norm_z = LayerNorm(self.c_z) - self.linear_z = Linear( - self.c_z, self.no_heads, bias=False, init="normal" - ) - - self.mha = Attention( - self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads - ) - - @torch.jit.ignore - def _chunk(self, - m: torch.Tensor, - biases: List[torch.Tensor], - chunk_size: int, - ) -> torch.Tensor: - return chunk_layer( - self.mha, - {"q_x": m, "kv_x": m, "biases": biases}, - chunk_size=chunk_size, - no_batch_dims=len(m.shape[:-2]), - ) - - def _prep_inputs(self, - m: torch.Tensor, - z: Optional[torch.Tensor], - mask: Optional[torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # [*, N_seq, N_res, C_m] - m = self.layer_norm_m(m) - - n_seq, n_res = m.shape[-3:-1] - if mask is None: - # [*, N_seq, N_res] - mask = m.new_ones( - m.shape[:-3] + (n_seq, n_res), - ) - - # [*, N_seq, 1, 1, N_res] - mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] - - # This step simply returns a larger view of the bias, and does not - # consume additional memory. - # [*, N_seq, no_heads, N_res, N_res] - #bias = bias.expand( - # ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1) - #) - - if (self.pair_bias and - z is not None and # For the - self.layer_norm_z is not None and # benefit of - self.linear_z is not None # TorchScript - ): - # [*, N_res, N_res, C_z] - z = self.layer_norm_z(z) - - # [*, N_res, N_res, no_heads] - z = self.linear_z(z) - - # [*, 1, no_heads, N_res, N_res] - z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4) - - return m, mask_bias, z - - - def forward(self, - m: torch.Tensor, - z: Optional[torch.Tensor] = None, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None, - _chunk_logits: Optional[int] = None, - _checkpoint_chunks: Optional[bool] = None, - ) -> torch.Tensor: - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA embedding - z: - [*, N_res, N_res, C_z] pair embedding. Required only if - pair_bias is True - mask: - [*, N_seq, N_res] MSA mask - chunk_size: - Size of chunks into which the inputs are split along their - batch dimensions. A low value decreases memory overhead at the - cost of slower execution. Chunking is not performed by default. - - """ - m, mask_bias, z = self._prep_inputs(m, z, mask) - - biases = [mask_bias] - if(z is not None): - biases.append(z) - - if chunk_size is not None: - m = self._chunk(m, biases, chunk_size) - else: - m = self.mha( - q_x=m, - kv_x=m, - biases=biases - ) - - return m - - -class MSARowAttentionWithPairBias(MSAAttention): - """ - Implements Algorithm 7. - """ - - def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9): - """ - Args: - c_m: - Input channel dimension - c_z: - Pair embedding channel dimension - c_hidden: - Per-head hidden channel dimension - no_heads: - Number of attention heads - inf: - Large number used to construct attention masks - """ - super(MSARowAttentionWithPairBias, self).__init__( - c_m, - c_hidden, - no_heads, - pair_bias=True, - c_z=c_z, - inf=inf, - ) - - -class MSAColumnAttention(nn.Module): - """ - Implements Algorithm 8. - - By rights, this should also be a subclass of MSAAttention. 
Alas, - most inheritance isn't supported by TorchScript. - """ - - def __init__(self, c_m, c_hidden, no_heads, inf=1e9): - """ - Args: - c_m: - MSA channel dimension - c_hidden: - Per-head hidden channel dimension - no_heads: - Number of attention heads - inf: - Large number used to construct attention masks - """ - super(MSAColumnAttention, self).__init__() - - self.c_m = c_m - self.c_hidden = c_hidden - self.no_heads = no_heads - self.inf = inf - - self._msa_att = MSAAttention( - c_in=c_m, - c_hidden=c_hidden, - no_heads=no_heads, - pair_bias=False, - c_z=None, - inf=inf, - ) - - def forward(self, - m: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None - ) -> torch.Tensor: - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA embedding - mask: - [*, N_seq, N_res] MSA mask - chunk_size: - Size of chunks into which the inputs are split along their - batch dimensions. A low value decreases memory overhead at the - cost of slower execution. Chunking is not performed by default. - """ - # [*, N_res, N_seq, C_in] - m = m.transpose(-2, -3) - - m = self._msa_att(m, chunk_size=chunk_size) - - # [*, N_seq, N_res, C_in] - m = m.transpose(-2, -3) - - return m - - -class MSAColumnGlobalAttention(nn.Module): - def __init__( - self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10, - ): - super(MSAColumnGlobalAttention, self).__init__() - - self.c_in = c_in - self.c_hidden = c_hidden - self.no_heads = no_heads - self.inf = inf - self.eps = eps - - self.layer_norm_m = nn.LayerNorm(c_in) - - self.global_attention = GlobalAttention( - c_in=c_in, - c_hidden=c_hidden, - no_heads=no_heads, - inf=inf, - eps=eps, - ) - - @torch.jit.ignore - def _chunk(self, - m: torch.Tensor, - chunk_size: int, - ) -> torch.Tensor: - mha_input = { - "m": m, - } - return chunk_layer( - self.global_attention, - mha_input, - chunk_size=chunk_size, - no_batch_dims=len(m.shape[:-2]), - ) - - def forward( - self, - m: torch.Tensor, - chunk_size: Optional[int] = None, - ) -> torch.Tensor: - n_seq, n_res, c_in = m.shape[-3:] - - # [*, N_res, N_seq, C_in] - m = m.transpose(-2, -3) - - # [*, N_res, N_seq, C_in] - m = self.layer_norm_m(m) - - if chunk_size is not None: - m = self._chunk(m, chunk_size) - else: - m = self.global_attention(m=m) - - # [*, N_seq, N_res, C_in] - m = m.transpose(-2, -3) - - return m diff --git a/tests/test_autochunk/openfold/outer_product_mean.py b/tests/test_autochunk/openfold/outer_product_mean.py deleted file mode 100644 index daadf1c27..000000000 --- a/tests/test_autochunk/openfold/outer_product_mean.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -from typing import Optional - -import torch -import torch.nn as nn - -from .primitives import Linear -from .tensor_utils import chunk_layer - - -class OuterProductMean(nn.Module): - """ - Implements Algorithm 10. 
- """ - - def __init__(self, c_m, c_z, c_hidden, eps=1e-3): - """ - Args: - c_m: - MSA embedding channel dimension - c_z: - Pair embedding channel dimension - c_hidden: - Hidden channel dimension - """ - super(OuterProductMean, self).__init__() - - self.c_m = c_m - self.c_z = c_z - self.c_hidden = c_hidden - self.eps = eps - - self.layer_norm = nn.LayerNorm(c_m) - self.linear_1 = Linear(c_m, c_hidden) - self.linear_2 = Linear(c_m, c_hidden) - self.linear_out = Linear(c_hidden ** 2, c_z, init="final") - - def _opm(self, a, b): - # [*, N_res, N_res, C, C] - outer = torch.einsum("...bac,...dae->...bdce", a, b) - - # [*, N_res, N_res, C * C] - outer = outer.reshape(outer.shape[:-2] + (-1,)) - - # [*, N_res, N_res, C_z] - outer = self.linear_out(outer) - - return outer - - @torch.jit.ignore - def _chunk(self, - a: torch.Tensor, - b: torch.Tensor, - chunk_size: int - ) -> torch.Tensor: - # Since the "batch dim" in this case is not a true batch dimension - # (in that the shape of the output depends on it), we need to - # iterate over it ourselves - a_reshape = a.reshape((-1,) + a.shape[-3:]) - b_reshape = b.reshape((-1,) + b.shape[-3:]) - out = [] - for a_prime, b_prime in zip(a_reshape, b_reshape): - outer = chunk_layer( - partial(self._opm, b=b_prime), - {"a": a_prime}, - chunk_size=chunk_size, - no_batch_dims=1, - ) - out.append(outer) - outer = torch.stack(out, dim=0) - outer = outer.reshape(a.shape[:-3] + outer.shape[1:]) - - return outer - - def forward(self, - m: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None - ) -> torch.Tensor: - """ - Args: - m: - [*, N_seq, N_res, C_m] MSA embedding - mask: - [*, N_seq, N_res] MSA mask - Returns: - [*, N_res, N_res, C_z] pair embedding update - """ - if mask is None: - mask = m.new_ones(m.shape[:-1]) - - # [*, N_seq, N_res, C_m] - m = self.layer_norm(m) - - # [*, N_seq, N_res, C] - mask = mask.unsqueeze(-1) - a = self.linear_1(m) * mask - b = self.linear_2(m) * mask - - a = a.transpose(-2, -3) - b = b.transpose(-2, -3) - - if chunk_size is not None: - outer = self._chunk(a, b, chunk_size) - else: - outer = self._opm(a, b) - - # [*, N_res, N_res, 1] - norm = torch.einsum("...abc,...adc->...bdc", mask, mask) - - # [*, N_res, N_res, C_z] - outer = outer / (self.eps + norm) - - return outer diff --git a/tests/test_autochunk/openfold/pair_transition.py b/tests/test_autochunk/openfold/pair_transition.py deleted file mode 100644 index 7d09914dc..000000000 --- a/tests/test_autochunk/openfold/pair_transition.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -import torch -import torch.nn as nn - -from .primitives import Linear, LayerNorm -from .tensor_utils import chunk_layer - - -class PairTransition(nn.Module): - """ - Implements Algorithm 15. 
- """ - - def __init__(self, c_z, n): - """ - Args: - c_z: - Pair transition channel dimension - n: - Factor by which c_z is multiplied to obtain hidden channel - dimension - """ - super(PairTransition, self).__init__() - - self.c_z = c_z - self.n = n - - self.layer_norm = LayerNorm(self.c_z) - self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu") - self.relu = nn.ReLU() - self.linear_2 = Linear(self.n * self.c_z, c_z, init="final") - - def _transition(self, z, mask): - # [*, N_res, N_res, C_hidden] - z = self.linear_1(z) - z = self.relu(z) - - # [*, N_res, N_res, C_z] - z = self.linear_2(z) * mask - - return z - - @torch.jit.ignore - def _chunk(self, - z: torch.Tensor, - mask: torch.Tensor, - chunk_size: int, - ) -> torch.Tensor: - return chunk_layer( - self._transition, - {"z": z, "mask": mask}, - chunk_size=chunk_size, - no_batch_dims=len(z.shape[:-2]), - ) - - - def forward(self, - z: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None, - ) -> torch.Tensor: - """ - Args: - z: - [*, N_res, N_res, C_z] pair embedding - Returns: - [*, N_res, N_res, C_z] pair embedding update - """ - # DISCREPANCY: DeepMind forgets to apply the mask in this module. - if mask is None: - mask = z.new_ones(z.shape[:-1]) - - # [*, N_res, N_res, 1] - mask = mask.unsqueeze(-1) - - # [*, N_res, N_res, C_z] - z = self.layer_norm(z) - - if chunk_size is not None: - z = self._chunk(z, mask, chunk_size) - else: - z = self._transition(z=z, mask=mask) - - return z diff --git a/tests/test_autochunk/openfold/primitives.py b/tests/test_autochunk/openfold/primitives.py deleted file mode 100644 index 32a9d487c..000000000 --- a/tests/test_autochunk/openfold/primitives.py +++ /dev/null @@ -1,529 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from functools import partial -import math -from typing import Optional, Callable, List, Tuple, Sequence -import numpy as np - -import torch -import torch.nn as nn - -from .checkpointing import get_checkpoint_fn -from .tensor_utils import ( - permute_final_dims, - flatten_final_dims, - _chunk_slice, -) - - -def _prod(nums): - out = 1 - for n in nums: - out = out * n - return out - - -def _calculate_fan(linear_weight_shape, fan="fan_in"): - fan_out, fan_in = linear_weight_shape - - if fan == "fan_in": - f = fan_in - elif fan == "fan_out": - f = fan_out - elif fan == "fan_avg": - f = (fan_in + fan_out) / 2 - else: - raise ValueError("Invalid fan option") - - return f - - -def glorot_uniform_init_(weights): - nn.init.xavier_uniform_(weights, gain=1) - - -def final_init_(weights): - with torch.no_grad(): - weights.fill_(0.0) - - -def gating_init_(weights): - with torch.no_grad(): - weights.fill_(0.0) - - -def normal_init_(weights): - torch.nn.init.kaiming_normal_(weights, nonlinearity="linear") - - -def ipa_point_weights_init_(weights): - with torch.no_grad(): - softplus_inverse_1 = 0.541324854612918 - weights.fill_(softplus_inverse_1) - - -class Linear(nn.Linear): - """ - A Linear layer with built-in nonstandard initializations. Called just - like torch.nn.Linear. - - Implements the initializers in 1.11.4, plus some additional ones found - in the code. - """ - - def __init__( - self, - in_dim: int, - out_dim: int, - bias: bool = True, - init: str = "default", - init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, - ): - """ - Args: - in_dim: - The final dimension of inputs to the layer - out_dim: - The final dimension of layer outputs - bias: - Whether to learn an additive bias. True by default - init: - The initializer to use. Choose from: - - "default": LeCun fan-in truncated normal initialization - "relu": He initialization w/ truncated normal distribution - "glorot": Fan-average Glorot uniform initialization - "gating": Weights=0, Bias=1 - "normal": Normal initialization with std=1/sqrt(fan_in) - "final": Weights=0, Bias=0 - - Overridden by init_fn if the latter is not None. - init_fn: - A custom initializer taking weight and bias as inputs. - Overrides init if not None. 
- """ - super(Linear, self).__init__(in_dim, out_dim, bias=bias) - - if bias: - with torch.no_grad(): - self.bias.fill_(0) - - if init_fn is not None: - init_fn(self.weight, self.bias) - else: - if init == "default": - normal_init_(self.weight) - elif init == "relu": - normal_init_(self.weight) - elif init == "glorot": - glorot_uniform_init_(self.weight) - elif init == "gating": - gating_init_(self.weight) - if bias: - with torch.no_grad(): - self.bias.fill_(1.0) - elif init == "normal": - normal_init_(self.weight) - elif init == "final": - final_init_(self.weight) - else: - raise ValueError("Invalid init string.") - - -class LayerNorm(nn.Module): - - def __init__(self, c_in, eps=1e-5): - super(LayerNorm, self).__init__() - - self.c_in = (c_in,) - self.eps = eps - - self.weight = nn.Parameter(torch.ones(c_in)) - self.bias = nn.Parameter(torch.zeros(c_in)) - - def forward(self, x): - out = nn.functional.layer_norm( - x, - self.c_in, - self.weight, - self.bias, - self.eps, - ) - - return out - - -@torch.jit.ignore -def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor: - """ - Softmax, but without automatic casting to fp32 when the input is of - type bfloat16 - """ - s = torch.nn.functional.softmax(t, dim=dim) - - return s - - -#@torch.jit.script -def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - biases: List[torch.Tensor]) -> torch.Tensor: - # [*, H, Q, C_hidden] - query = permute_final_dims(query, (1, 0, 2)) - - # [*, H, C_hidden, K] - key = permute_final_dims(key, (1, 2, 0)) - - # [*, H, V, C_hidden] - value = permute_final_dims(value, (1, 0, 2)) - - # [*, H, Q, K] - a = torch.matmul(query, key) - - for b in biases: - a += b - - a = softmax(a, -1) - - # [*, H, Q, C_hidden] - a = torch.matmul(a, value) - - # [*, Q, H, C_hidden] - a = a.transpose(-2, -3) - - return a - - -@torch.jit.ignore -def _attention_chunked_trainable( - query, - key, - value, - biases, - chunk_size, - chunk_dim, - checkpoint, -): - if (checkpoint and len(biases) > 2): - raise ValueError("Checkpointed version permits only permits two bias terms") - - def _checkpointable_attention(q, k, v, b1, b2): - bs = [b for b in [b1, b2] if b is not None] - return _attention(q, k, v, bs) - - o_chunks = [] - checkpoint_fn = get_checkpoint_fn() - count = query.shape[chunk_dim] - for start in range(0, count, chunk_size): - end = start + chunk_size - idx = [slice(None)] * len(query.shape) - idx[chunk_dim] = slice(start, end) - idx_tup = tuple(idx) - q_chunk = query[idx_tup] - k_chunk = key[idx_tup] - v_chunk = value[idx_tup] - - def _slice_bias(b): - idx[chunk_dim] = (slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)) - return b[tuple(idx)] - - if (checkpoint): - bias_1_chunk, bias_2_chunk = [ - _slice_bias(b) if b is not None else None for b in (biases + [None, None])[:2] - ] - - o_chunk = checkpoint_fn(_checkpointable_attention, q_chunk, k_chunk, v_chunk, - bias_1_chunk, bias_2_chunk) - else: - bias_chunks = [_slice_bias(b) for b in biases] - - o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks) - - o_chunks.append(o_chunk) - - o = torch.cat(o_chunks, dim=chunk_dim) - return o - - -class Attention(nn.Module): - """ - Standard multi-head attention using AlphaFold's default layer - initialization. Allows multiple bias vectors. 
- """ - - def __init__( - self, - c_q: int, - c_k: int, - c_v: int, - c_hidden: int, - no_heads: int, - gating: bool = True, - ): - """ - Args: - c_q: - Input dimension of query data - c_k: - Input dimension of key data - c_v: - Input dimension of value data - c_hidden: - Per-head hidden dimension - no_heads: - Number of attention heads - gating: - Whether the output should be gated using query data - """ - super(Attention, self).__init__() - - self.c_q = c_q - self.c_k = c_k - self.c_v = c_v - self.c_hidden = c_hidden - self.no_heads = no_heads - self.gating = gating - - # DISCREPANCY: c_hidden is not the per-head channel dimension, as - # stated in the supplement, but the overall channel dimension. - - self.linear_q = Linear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot") - self.linear_k = Linear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot") - self.linear_v = Linear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot") - self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, init="final") - - self.linear_g = None - if self.gating: - self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating") - - self.sigmoid = nn.Sigmoid() - - def _prep_qkv(self, q_x: torch.Tensor, - kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # [*, Q/K/V, H * C_hidden] - q = self.linear_q(q_x) - k = self.linear_k(kv_x) - v = self.linear_v(kv_x) - - # [*, Q/K, H, C_hidden] - q = q.view(q.shape[:-1] + (self.no_heads, -1)) - k = k.view(k.shape[:-1] + (self.no_heads, -1)) - v = v.view(v.shape[:-1] + (self.no_heads, -1)) - - q /= math.sqrt(self.c_hidden) - - return q, k, v - - def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor: - if (self.linear_g is not None): - g = self.sigmoid(self.linear_g(q_x)) - - # [*, Q, H, C_hidden] - g = g.view(g.shape[:-1] + (self.no_heads, -1)) - o = o * g - - # [*, Q, H * C_hidden] - o = flatten_final_dims(o, 2) - - # [*, Q, C_q] - o = self.linear_o(o) - - return o - - def forward( - self, - q_x: torch.Tensor, - kv_x: torch.Tensor, - biases: Optional[List[torch.Tensor]] = None, - use_lma: bool = False, - q_chunk_size: Optional[int] = None, - kv_chunk_size: Optional[int] = None, - ) -> torch.Tensor: - """ - Args: - q_x: - [*, Q, C_q] query data - kv_x: - [*, K, C_k] key data - biases: - List of biases that broadcast to [*, H, Q, K] - use_lma: - Whether to use low-memory attention - q_chunk_size: - Query chunk size (for LMA) - kv_chunk_size: - Key/Value chunk size (for LMA) - Returns - [*, Q, C_q] attention update - """ - if (biases is None): - biases = [] - if (use_lma and (q_chunk_size is None or kv_chunk_size is None)): - raise ValueError("If use_lma is specified, q_chunk_size and kv_chunk_size must " - "be provided") - - q, k, v = self._prep_qkv(q_x, kv_x) - - if (use_lma): - biases = [b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) for b in biases] - - o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size) - else: - o = _attention(q, k, v, biases) - - o = self._wrap_up(o, q_x) - - return o - - -class GlobalAttention(nn.Module): - - def __init__(self, c_in, c_hidden, no_heads, inf, eps): - super(GlobalAttention, self).__init__() - - self.c_in = c_in - self.c_hidden = c_hidden - self.no_heads = no_heads - self.inf = inf - self.eps = eps - - self.linear_q = Linear(c_in, c_hidden * no_heads, bias=False, init="glorot") - - self.linear_k = Linear( - c_in, - c_hidden, - bias=False, - init="glorot", - ) - self.linear_v = Linear( - c_in, - c_hidden, - 
bias=False, - init="glorot", - ) - self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating") - self.linear_o = Linear(c_hidden * no_heads, c_in, init="final") - - self.sigmoid = nn.Sigmoid() - - def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - # [*, N_res, C_in] - q = torch.sum(m * mask.unsqueeze(-1), - dim=-2) / (torch.sum(mask, dim=-1)[..., None] + self.eps) - - # [*, N_res, H * C_hidden] - q = self.linear_q(q) - q *= (self.c_hidden**(-0.5)) - - # [*, N_res, H, C_hidden] - q = q.view(q.shape[:-1] + (self.no_heads, -1)) - - # [*, N_res, N_seq, C_hidden] - k = self.linear_k(m) - v = self.linear_v(m) - - # [*, N_res, H, N_seq] - a = torch.matmul( - q, - k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq] - ) - bias = (self.inf * (mask - 1))[..., :, None, :] - a += bias - a = softmax(a) - - # [*, N_res, H, C_hidden] - o = torch.matmul( - a, - v, - ) - - # [*, N_res, N_seq, C_hidden] - g = self.sigmoid(self.linear_g(m)) - - # [*, N_res, N_seq, H, C_hidden] - g = g.view(g.shape[:-1] + (self.no_heads, -1)) - - # [*, N_res, N_seq, H, C_hidden] - o = o.unsqueeze(-3) * g - - # [*, N_res, N_seq, H * C_hidden] - o = o.reshape(o.shape[:-2] + (-1,)) - - # [*, N_res, N_seq, C_in] - m = self.linear_o(o) - - return m - - -def _lma( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - biases: List[torch.Tensor], - q_chunk_size: int, - kv_chunk_size: int, -): - no_q, no_kv = q.shape[-3], k.shape[-3] - - # [*, Q, H, C_hidden] - o = q.new_zeros(q.shape) - for q_s in range(0, no_q, q_chunk_size): - q_chunk = q[..., q_s:q_s + q_chunk_size, :, :] - large_bias_chunks = [b[..., q_s:q_s + q_chunk_size, :] for b in biases] - - maxes = [] - weights = [] - values = [] - for kv_s in range(0, no_kv, kv_chunk_size): - k_chunk = k[..., kv_s:kv_s + kv_chunk_size, :, :] - v_chunk = v[..., kv_s:kv_s + kv_chunk_size, :, :] - small_bias_chunks = [b[..., kv_s:kv_s + kv_chunk_size] for b in large_bias_chunks] - - a = torch.einsum( - "...qhd,...khd->...hqk", - q_chunk, - k_chunk, - ) - - for b in small_bias_chunks: - a += b - - a = a.transpose(-2, -3) - - max_a = torch.max(a, dim=-1, keepdim=True)[0] - exp_a = torch.exp(a - max_a) - exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a) - - maxes.append(max_a.detach().squeeze(-1)) - weights.append(torch.sum(exp_a, dim=-1)) - values.append(exp_v) - - chunk_max = torch.stack(maxes, dim=-3) - chunk_weights = torch.stack(weights, dim=-3) - chunk_values = torch.stack(values, dim=-4) - - global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0] - max_diffs = torch.exp(chunk_max - global_max) - chunk_values *= max_diffs.unsqueeze(-1) - chunk_weights *= max_diffs - - all_values = torch.sum(chunk_values, dim=-4) - all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4) - - q_chunk_out = all_values / all_weights - - o[..., q_s:q_s + q_chunk_size, :, :] = q_chunk_out - - return o diff --git a/tests/test_autochunk/openfold/tensor_utils.py b/tests/test_autochunk/openfold/tensor_utils.py deleted file mode 100644 index 384a71fb5..000000000 --- a/tests/test_autochunk/openfold/tensor_utils.py +++ /dev/null @@ -1,408 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -import torch -import torch.nn as nn -from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional - - -def permute_final_dims(tensor: torch.Tensor, inds: List[int]): - zero_index = -1 * len(inds) - first_inds = list(range(len(tensor.shape[:zero_index]))) - return tensor.permute(first_inds + [zero_index + i for i in inds]) - - -def flatten_final_dims(t: torch.Tensor, no_dims: int): - return t.reshape(t.shape[:-no_dims] + (-1,)) - - -def masked_mean(mask, value, dim, eps=1e-4): - mask = mask.expand(*value.shape) - return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) - - -def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64): - boundaries = torch.linspace( - min_bin, max_bin, no_bins - 1, device=pts.device - ) - dists = torch.sqrt( - torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1) - ) - return torch.bucketize(dists, boundaries) - - -def dict_multimap(fn, dicts): - first = dicts[0] - new_dict = {} - for k, v in first.items(): - all_v = [d[k] for d in dicts] - if type(v) is dict: - new_dict[k] = dict_multimap(fn, all_v) - else: - new_dict[k] = fn(all_v) - - return new_dict - - -def one_hot(x, v_bins): - reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) - diffs = x[..., None] - reshaped_bins - am = torch.argmin(torch.abs(diffs), dim=-1) - return nn.functional.one_hot(am, num_classes=len(v_bins)).float() - - -def batched_gather(data, inds, dim=0, no_batch_dims=0): - ranges = [] - for i, s in enumerate(data.shape[:no_batch_dims]): - r = torch.arange(s) - r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) - ranges.append(r) - - remaining_dims = [ - slice(None) for _ in range(len(data.shape) - no_batch_dims) - ] - remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds - ranges.extend(remaining_dims) - return data[ranges] - - -# With tree_map, a poor man's JAX tree_map -def dict_map(fn, dic, leaf_type): - new_dict = {} - for k, v in dic.items(): - if type(v) is dict: - new_dict[k] = dict_map(fn, v, leaf_type) - else: - new_dict[k] = tree_map(fn, v, leaf_type) - - return new_dict - - -def tree_map(fn, tree, leaf_type): - if isinstance(tree, dict): - return dict_map(fn, tree, leaf_type) - elif isinstance(tree, list): - return [tree_map(fn, x, leaf_type) for x in tree] - elif isinstance(tree, tuple): - return tuple([tree_map(fn, x, leaf_type) for x in tree]) - elif isinstance(tree, leaf_type): - return fn(tree) - else: - print(type(tree)) - raise ValueError("Not supported") - - -tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) - -def _fetch_dims(tree): - shapes = [] - tree_type = type(tree) - if tree_type is dict: - for v in tree.values(): - shapes.extend(_fetch_dims(v)) - elif tree_type is list or tree_type is tuple: - for t in tree: - shapes.extend(_fetch_dims(t)) - elif tree_type is torch.Tensor: - shapes.append(tree.shape) - else: - raise ValueError("Not supported") - - return shapes - - -@torch.jit.ignore -def _flat_idx_to_idx( - flat_idx: int, - dims: Tuple[int], -) -> Tuple[int]: - idx = [] - for d in reversed(dims): - 
idx.append(flat_idx % d) - flat_idx = flat_idx // d - - return tuple(reversed(idx)) - - -@torch.jit.ignore -def _get_minimal_slice_set( - start: Sequence[int], - end: Sequence[int], - dims: int, - start_edges: Optional[Sequence[bool]] = None, - end_edges: Optional[Sequence[bool]] = None, -) -> Sequence[Tuple[int]]: - """ - Produces an ordered sequence of tensor slices that, when used in - sequence on a tensor with shape dims, yields tensors that contain every - leaf in the contiguous range [start, end]. Care is taken to yield a - short sequence of slices, and perhaps even the shortest possible (I'm - pretty sure it's the latter). - - end is INCLUSIVE. - """ - # start_edges and end_edges both indicate whether, starting from any given - # dimension, the start/end index is at the top/bottom edge of the - # corresponding tensor, modeled as a tree - def reduce_edge_list(ll): - tally = 1 - for i in range(len(ll)): - reversed_idx = -1 * (i + 1) - ll[reversed_idx] *= tally - tally = ll[reversed_idx] - - if(start_edges is None): - start_edges = [s == 0 for s in start] - reduce_edge_list(start_edges) - if(end_edges is None): - end_edges = [e == (d - 1) for e,d in zip(end, dims)] - reduce_edge_list(end_edges) - - # Base cases. Either start/end are empty and we're done, or the final, - # one-dimensional tensor can be simply sliced - if(len(start) == 0): - return [tuple()] - elif(len(start) == 1): - return [(slice(start[0], end[0] + 1),)] - - slices = [] - path = [] - - # Dimensions common to start and end can be selected directly - for s,e in zip(start, end): - if(s == e): - path.append(slice(s, s + 1)) - else: - break - - path = tuple(path) - divergence_idx = len(path) - - # start == end, and we're done - if(divergence_idx == len(dims)): - return [tuple(path)] - - def upper(): - sdi = start[divergence_idx] - return [ - path + (slice(sdi, sdi + 1),) + s for s in - _get_minimal_slice_set( - start[divergence_idx + 1:], - [d - 1 for d in dims[divergence_idx + 1:]], - dims[divergence_idx + 1:], - start_edges=start_edges[divergence_idx + 1:], - end_edges=[1 for _ in end_edges[divergence_idx + 1:]] - ) - ] - - def lower(): - edi = end[divergence_idx] - return [ - path + (slice(edi, edi + 1),) + s for s in - _get_minimal_slice_set( - [0 for _ in start[divergence_idx + 1:]], - end[divergence_idx + 1:], - dims[divergence_idx + 1:], - start_edges=[1 for _ in start_edges[divergence_idx + 1:]], - end_edges=end_edges[divergence_idx + 1:], - ) - ] - - # If both start and end are at the edges of the subtree rooted at - # divergence_idx, we can just select the whole subtree at once - if(start_edges[divergence_idx] and end_edges[divergence_idx]): - slices.append( - path + (slice(start[divergence_idx], end[divergence_idx] + 1),) - ) - # If just start is at the edge, we can grab almost all of the subtree, - # treating only the ragged bottom edge as an edge case - elif(start_edges[divergence_idx]): - slices.append( - path + (slice(start[divergence_idx], end[divergence_idx]),) - ) - slices.extend(lower()) - # Analogous to the previous case, but the top is ragged this time - elif(end_edges[divergence_idx]): - slices.extend(upper()) - slices.append( - path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),) - ) - # If both sides of the range are ragged, we need to handle both sides - # separately. 
If there's contiguous meat in between them, we can index it - # in one big chunk - else: - slices.extend(upper()) - middle_ground = end[divergence_idx] - start[divergence_idx] - if(middle_ground > 1): - slices.append( - path + (slice(start[divergence_idx] + 1, end[divergence_idx]),) - ) - slices.extend(lower()) - - return [tuple(s) for s in slices] - - -@torch.jit.ignore -def _chunk_slice( - t: torch.Tensor, - flat_start: int, - flat_end: int, - no_batch_dims: int, -) -> torch.Tensor: - """ - Equivalent to - - t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end] - - but without the need for the initial reshape call, which can be - memory-intensive in certain situations. The only reshape operations - in this function are performed on sub-tensors that scale with - (flat_end - flat_start), the chunk size. - """ - - batch_dims = t.shape[:no_batch_dims] - start_idx = list(_flat_idx_to_idx(flat_start, batch_dims)) - # _get_minimal_slice_set is inclusive - end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims)) - - # Get an ordered list of slices to perform - slices = _get_minimal_slice_set( - start_idx, - end_idx, - batch_dims, - ) - - sliced_tensors = [t[s] for s in slices] - - return torch.cat( - [s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors] - ) - - -def chunk_layer( - layer: Callable, - inputs: Dict[str, Any], - chunk_size: int, - no_batch_dims: int, - low_mem: bool = False, -) -> Any: - """ - Implements the "chunking" procedure described in section 1.11.8. - - Layer outputs and inputs are assumed to be simple "pytrees," - consisting only of (arbitrarily nested) lists, tuples, and dicts with - torch.Tensor leaves. - - Args: - layer: - The layer to be applied chunk-wise - inputs: - A (non-nested) dictionary of keyworded inputs. All leaves must - be tensors and must share the same batch dimensions. - chunk_size: - The number of sub-batches per chunk. If multiple batch - dimensions are specified, a "sub-batch" is defined as a single - indexing of all batch dimensions simultaneously (s.t. the - number of sub-batches is the product of the batch dimensions). - no_batch_dims: - How many of the initial dimensions of each input tensor can - be considered batch dimensions. - low_mem: - Avoids flattening potentially large input tensors. Unnecessary - in most cases, and is ever so slightly slower than the default - setting. - Returns: - The reassembled output of the layer on the inputs. - """ - if not (len(inputs) > 0): - raise ValueError("Must provide at least one input") - - initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)] - orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) - - def _prep_inputs(t): - # TODO: make this more memory efficient. 
This sucks - if(not low_mem): - if not sum(t.shape[:no_batch_dims]) == no_batch_dims: - t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) - t = t.reshape(-1, *t.shape[no_batch_dims:]) - else: - t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) - return t - - prepped_inputs = tensor_tree_map(_prep_inputs, inputs) - - flat_batch_dim = 1 - for d in orig_batch_dims: - flat_batch_dim *= d - - no_chunks = flat_batch_dim // chunk_size + ( - flat_batch_dim % chunk_size != 0 - ) - - i = 0 - out = None - for _ in range(no_chunks): - # Chunk the input - if(not low_mem): - select_chunk = ( - lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t - ) - else: - select_chunk = ( - partial( - _chunk_slice, - flat_start=i, - flat_end=min(flat_batch_dim, i + chunk_size), - no_batch_dims=len(orig_batch_dims) - ) - ) - - chunks = tensor_tree_map(select_chunk, prepped_inputs) - - # Run the layer on the chunk - output_chunk = layer(**chunks) - - # Allocate space for the output - if out is None: - allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]) - out = tensor_tree_map(allocate, output_chunk) - - # Put the chunk in its pre-allocated space - out_type = type(output_chunk) - if out_type is dict: - def assign(d1, d2): - for k, v in d1.items(): - if type(v) is dict: - assign(v, d2[k]) - else: - v[i : i + chunk_size] = d2[k] - - assign(out, output_chunk) - elif out_type is tuple: - for x1, x2 in zip(out, output_chunk): - x1[i : i + chunk_size] = x2 - elif out_type is torch.Tensor: - out[i : i + chunk_size] = output_chunk - else: - raise ValueError("Not supported") - - i += chunk_size - - reshape = lambda t: t.view(orig_batch_dims + t.shape[1:]) - out = tensor_tree_map(reshape, out) - - return out diff --git a/tests/test_autochunk/openfold/triangular_attention.py b/tests/test_autochunk/openfold/triangular_attention.py deleted file mode 100644 index 12d09c502..000000000 --- a/tests/test_autochunk/openfold/triangular_attention.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
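The deleted `chunk_layer` helper above implements the sub-batching described in its docstring: flatten the shared leading batch dimensions, run the wrapped layer on `chunk_size`-sized slices, and reassemble the output. The following standalone sketch (an editor's illustration, not part of this patch; `naive_chunk_layer` is a hypothetical name) shows the same idea without the low-memory slicing machinery:

```python
# Editor's illustration (not part of the patch): the core chunking idea behind chunk_layer.
import torch

def naive_chunk_layer(layer, inputs, chunk_size, no_batch_dims):
    # Flatten the shared leading batch dims, apply the layer per sub-batch, restore the shape.
    batch_shape = next(iter(inputs.values())).shape[:no_batch_dims]
    flat = {k: v.reshape(-1, *v.shape[no_batch_dims:]) for k, v in inputs.items()}
    total = next(iter(flat.values())).shape[0]
    chunks = [
        layer(**{k: v[i:i + chunk_size] for k, v in flat.items()})
        for i in range(0, total, chunk_size)
    ]
    out = torch.cat(chunks, dim=0)
    return out.view(*batch_shape, *out.shape[1:])

q = torch.randn(2, 8, 16)
kv = torch.randn(2, 8, 16)
ref = q + kv
chunked = naive_chunk_layer(lambda q_x, kv_x: q_x + kv_x,
                            {"q_x": q, "kv_x": kv},
                            chunk_size=4, no_batch_dims=2)
assert torch.allclose(ref, chunked)
```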
- -from functools import partialmethod, partial -import math -from typing import Optional, List - -import torch -import torch.nn as nn - -from .primitives import Linear, LayerNorm, Attention -from .tensor_utils import ( - chunk_layer, - permute_final_dims, - flatten_final_dims, -) - - -class TriangleAttention(nn.Module): - def __init__( - self, c_in, c_hidden, no_heads, starting, inf=1e9 - ): - """ - Args: - c_in: - Input channel dimension - c_hidden: - Overall hidden channel dimension (not per-head) - no_heads: - Number of attention heads - """ - super(TriangleAttention, self).__init__() - - self.c_in = c_in - self.c_hidden = c_hidden - self.no_heads = no_heads - self.starting = starting - self.inf = inf - - self.layer_norm = LayerNorm(self.c_in) - - self.linear = Linear(c_in, self.no_heads, bias=False, init="normal") - - self.mha = Attention( - self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads - ) - - @torch.jit.ignore - def _chunk(self, - x: torch.Tensor, - biases: List[torch.Tensor], - chunk_size: int, - ) -> torch.Tensor: - mha_inputs = { - "q_x": x, - "kv_x": x, - "biases": biases, - } - return chunk_layer( - partial(self.mha), - mha_inputs, - chunk_size=chunk_size, - no_batch_dims=len(x.shape[:-2]), - ) - - def forward(self, - x: torch.Tensor, - mask: Optional[torch.Tensor] = None, - chunk_size: Optional[int] = None - ) -> torch.Tensor: - """ - Args: - x: - [*, I, J, C_in] input tensor (e.g. the pair representation) - Returns: - [*, I, J, C_in] output tensor - """ - if mask is None: - # [*, I, J] - mask = x.new_ones( - x.shape[:-1], - ) - - # Shape annotations assume self.starting. Else, I and J are flipped - if not self.starting: - x = x.transpose(-2, -3) - mask = mask.transpose(-1, -2) - - # [*, I, J, C_in] - x = self.layer_norm(x) - - # [*, I, 1, 1, J] - mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] - - # [*, H, I, J] - triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)) - - # [*, 1, H, I, J] - triangle_bias = triangle_bias.unsqueeze(-4) - - biases = [mask_bias, triangle_bias] - - if chunk_size is not None: - x = self._chunk(x, biases, chunk_size) - else: - x = self.mha(q_x=x, kv_x=x, biases=biases) - - if not self.starting: - x = x.transpose(-2, -3) - - return x - - -class TriangleAttentionStartingNode(TriangleAttention): - """ - Implements Algorithm 13. - """ - - __init__ = partialmethod(TriangleAttention.__init__, starting=True) - - -class TriangleAttentionEndingNode(TriangleAttention): - """ - Implements Algorithm 14. - """ - - __init__ = partialmethod(TriangleAttention.__init__, starting=False) diff --git a/tests/test_autochunk/openfold/triangular_multiplicative_update.py b/tests/test_autochunk/openfold/triangular_multiplicative_update.py deleted file mode 100644 index 29f7062c3..000000000 --- a/tests/test_autochunk/openfold/triangular_multiplicative_update.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2021 AlQuraishi Laboratory -# Copyright 2021 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
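The shape comments in the deleted `TriangleAttention.forward` above (e.g. `[*, H, I, J]` for the triangle bias) rely on `permute_final_dims`, which this patch also removes from tensor_utils.py. A small self-contained check of that behaviour (editor's illustration only; the helper body is copied verbatim from the deleted file):

```python
# Editor's illustration (not part of the patch): what permute_final_dims does to the bias shape.
import torch

def permute_final_dims(tensor, inds):
    # Identical to the helper deleted from tensor_utils.py above.
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])

x = torch.randn(1, 5, 7, 4)            # [*, I, J, H], e.g. the output of self.linear(x)
bias = permute_final_dims(x, (2, 0, 1))
assert bias.shape == (1, 4, 5, 7)      # [*, H, I, J], matching the shape comment in forward()
```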
- -from functools import partialmethod -from typing import Optional - -import torch -import torch.nn as nn - -from .primitives import Linear, LayerNorm -from .tensor_utils import permute_final_dims - - -class TriangleMultiplicativeUpdate(nn.Module): - """ - Implements Algorithms 11 and 12. - """ - def __init__(self, c_z, c_hidden, _outgoing=True): - """ - Args: - c_z: - Input channel dimension - c: - Hidden channel dimension - """ - super(TriangleMultiplicativeUpdate, self).__init__() - self.c_z = c_z - self.c_hidden = c_hidden - self._outgoing = _outgoing - - self.linear_a_p = Linear(self.c_z, self.c_hidden) - self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating") - self.linear_b_p = Linear(self.c_z, self.c_hidden) - self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating") - self.linear_g = Linear(self.c_z, self.c_z, init="gating") - self.linear_z = Linear(self.c_hidden, self.c_z, init="final") - - self.layer_norm_in = LayerNorm(self.c_z) - self.layer_norm_out = LayerNorm(self.c_hidden) - - self.sigmoid = nn.Sigmoid() - - def _combine_projections(self, - a: torch.Tensor, - b: torch.Tensor, - ) -> torch.Tensor: - raise NotImplementedError("This method needs to be overridden") - - def forward(self, - z: torch.Tensor, - mask: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """ - Args: - x: - [*, N_res, N_res, C_z] input tensor - mask: - [*, N_res, N_res] input mask - Returns: - [*, N_res, N_res, C_z] output tensor - """ - if mask is None: - mask = z.new_ones(z.shape[:-1]) - - mask = mask.unsqueeze(-1) - - z = self.layer_norm_in(z) - a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z)) - a = a * mask - b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z)) - b = b * mask - x = self._combine_projections(a, b) - x = self.layer_norm_out(x) - x = self.linear_z(x) - g = self.sigmoid(self.linear_g(z)) - z = x * g - - return z - - -class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate): - """ - Implements Algorithm 11. - """ - def _combine_projections(self, - a: torch.Tensor, # [*, N_i, N_k, C] - b: torch.Tensor, # [*, N_j, N_k, C] - ): - # [*, C, N_i, N_j] - p = torch.matmul( - permute_final_dims(a, (2, 0, 1)), - permute_final_dims(b, (2, 1, 0)), - ) - - # [*, N_i, N_j, C] - return permute_final_dims(p, (1, 2, 0)) - - -class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate): - """ - Implements Algorithm 12. 
- """ - def _combine_projections(self, - a: torch.Tensor, # [*, N_k, N_i, C] - b: torch.Tensor, # [*, N_k, N_j, C] - ): - # [*, C, N_i, N_j] - p = torch.matmul( - permute_final_dims(a, (2, 1, 0)), - permute_final_dims(b, (2, 0, 1)), - ) - - # [*, N_i, N_j, C] - return permute_final_dims(p, (1, 2, 0)) - diff --git a/tests/test_autochunk/test_evoformer_codegen.py b/tests/test_autochunk/test_evoformer_codegen.py new file mode 100644 index 000000000..1273bf2fe --- /dev/null +++ b/tests/test_autochunk/test_evoformer_codegen.py @@ -0,0 +1,164 @@ +from functools import partial + +import pytest +import torch +import torch.fx +import torch.multiprocessing as mp + +try: + from fastfold.model.nn.evoformer import EvoformerBlock + HAS_REPO = True +except: + HAS_REPO = False + +import colossalai +from colossalai.core import global_context as gpc +from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.utils import free_port + +if CODEGEN_AVAILABLE and is_compatible_with_meta(): + from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen + from colossalai.fx.profiler import MetaTensor + from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace + + +def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask): + # for memory test + # torch.cuda.reset_peak_memory_stats() + # now_mem = torch.cuda.memory_allocated() / 1024**2 + # with torch.no_grad(): + # node1 = node.clone() + # pair1 = pair.clone() + # gm(node1, pair1) + # new_now_mem = torch.cuda.memory_allocated() / 1024**2 + # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 + # print( + # "autochunk now mem:%.2f max mem:%.2f" + # % (new_now_mem - now_mem, new_max_mem - now_mem) + # ) + + # test forward + model = model.cuda() + with torch.no_grad(): + non_fx_out = model(node, pair, node_mask, pair_mask) + fx_out = gm(node, pair, node_mask, pair_mask) + + assert torch.allclose(non_fx_out[0], fx_out[0], + atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(non_fx_out[0] - fx_out[0])) + assert torch.allclose(non_fx_out[1], fx_out[1], + atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(non_fx_out[1] - fx_out[1])) + + +def _build_openfold(): + model = EvoformerBlock( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + no_heads_msa=8, + no_heads_pair=4, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.15, + inf=1e4, + eps=1e-4, + is_multimer=False, + ).eval().cuda() + return model + + +def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory): + # launch colossalai + colossalai.launch( + config={}, + rank=rank, + world_size=1, + host="localhost", + port=free_port(), + backend="nccl", + ) + + # build model and input + model = _build_openfold() + node = torch.randn(1, msa_len, pair_len, 256).cuda() + node_mask = torch.randn(1, msa_len, pair_len).cuda() + pair = torch.randn(1, pair_len, pair_len, 128).cuda() + pair_mask = torch.randn(1, pair_len, pair_len).cuda() + + # trace the meta graph and setup codegen + meta_graph = symbolic_trace( + model, + meta_args={ + "m": node.to(torch.device("meta")), + "z": pair.to(torch.device("meta")), + "msa_mask": node_mask.to(torch.device("meta")), + "pair_mask": 
pair_mask.to(torch.device("meta")), + }, + concrete_args={ + "chunk_size": None, + "_mask_trans": True, + }, + ) + interp = MetaInfoProp(meta_graph) + interp.propagate( + MetaTensor(node, fake_device="cuda:0"), + MetaTensor(pair, fake_device="cuda:0"), + MetaTensor(node_mask, fake_device="cuda:0"), + MetaTensor(pair_mask, fake_device="cuda:0"), + ) + # codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory) + + # trace and recompile + # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer + graph = ColoTracer().trace( + model, + meta_args={ + "m": node.to(torch.device("meta")), + "z": pair.to(torch.device("meta")), + "msa_mask": node_mask.to(torch.device("meta")), + "pair_mask": pair_mask.to(torch.device("meta")), + }, + concrete_args={ + "chunk_size": None, + "_mask_trans": True, + }, + ) + # graph.set_codegen(codegen) + gm = ColoGraphModule(model, graph) + gm.recompile() + + # assert we have inserted chunk + code = graph.python_code("self").src + assert "chunk_size" in code + # print(code) + + _test_fwd(model, gm, node, pair, node_mask, pair_mask) + gpc.destroy() + + +@pytest.mark.skipif( + not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO), + reason="torch version is lower than 1.12.0", +) +@pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) +@pytest.mark.parametrize("msa_len", [32]) +@pytest.mark.parametrize("pair_len", [64]) +def test_evoformer_codegen(msa_len, pair_len, max_memory): + run_func = partial( + _test_evoformer_codegen, + msa_len=msa_len, + pair_len=pair_len, + max_memory=max_memory, + ) + mp.spawn(run_func, nprocs=1) + + +if __name__ == "__main__": + _test_evoformer_codegen(0, 32, 64, 25) diff --git a/tests/test_autochunk/test_autochunk_codegen.py b/tests/test_autochunk/test_simple_evoformer_codegen.py similarity index 88% rename from tests/test_autochunk/test_autochunk_codegen.py rename to tests/test_autochunk/test_simple_evoformer_codegen.py index 02fa07e2c..f1272330f 100644 --- a/tests/test_autochunk/test_autochunk_codegen.py +++ b/tests/test_autochunk/test_simple_evoformer_codegen.py @@ -5,6 +5,12 @@ import torch import torch.fx import torch.multiprocessing as mp +try: + from simple_evoformer import base_evoformer + HAS_REPO = True +except: + HAS_REPO = False + import colossalai from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer @@ -13,7 +19,6 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABL from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.utils import free_port -from tests.test_autochunk.evoformer.evoformer import evoformer_base if CODEGEN_AVAILABLE and is_compatible_with_meta(): from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen @@ -48,7 +53,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): torch.abs(non_fx_out[1] - fx_out[1])) -def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): +def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory): # launch colossalai colossalai.launch( config={}, @@ -60,7 +65,7 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): ) # build model and input - model = evoformer_base().cuda() + model = base_evoformer().cuda() node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() @@ -95,13 +100,14 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): gpc.destroy() -@pytest.mark.skipif(not 
(CODEGEN_AVAILABLE and is_compatible_with_meta()), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO), + reason='torch version is lower than 1.12.0') @pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) @pytest.mark.parametrize("msa_len", [32]) @pytest.mark.parametrize("pair_len", [64]) -def test_autochunk_codegen(msa_len, pair_len, max_memory): +def test_simple_evoformer_codegen(msa_len, pair_len, max_memory): run_func = partial( - _test_autochunk_codegen, + _test_simple_evoformer_codegen, msa_len=msa_len, pair_len=pair_len, max_memory=max_memory, @@ -110,4 +116,4 @@ def test_autochunk_codegen(msa_len, pair_len, max_memory): if __name__ == "__main__": - _test_autochunk_codegen(0, 32, 64, 25) + _test_simple_evoformer_codegen(0, 32, 64, 25) diff --git a/tests/test_autochunk/test_autochunk_search.py b/tests/test_autochunk/test_simple_evoformer_search.py similarity index 87% rename from tests/test_autochunk/test_autochunk_search.py rename to tests/test_autochunk/test_simple_evoformer_search.py index 371fce64f..04fb514fb 100644 --- a/tests/test_autochunk/test_autochunk_search.py +++ b/tests/test_autochunk/test_simple_evoformer_search.py @@ -5,13 +5,18 @@ import torch import torch.fx import torch.multiprocessing as mp +try: + from simple_evoformer import base_evoformer + HAS_REPO = True +except: + HAS_REPO = False + import colossalai from colossalai.core import global_context as gpc from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.utils import free_port -from tests.test_autochunk.evoformer.evoformer import evoformer_base if CODEGEN_AVAILABLE and is_compatible_with_meta(): from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen @@ -57,7 +62,7 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len): ) -def _test_autochunk_search(rank, msa_len, pair_len, max_memory): +def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory): # launch colossalai colossalai.launch( config={}, @@ -69,7 +74,7 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory): ) # build model and input - model = evoformer_base().cuda() + model = base_evoformer().cuda() node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() @@ -84,13 +89,14 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory): gpc.destroy() -@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason="torch version is lower than 1.12.0") +@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO), + reason="torch version is lower than 1.12.0") @pytest.mark.parametrize("max_memory", [None, 20, 25, 30]) @pytest.mark.parametrize("msa_len", [32]) @pytest.mark.parametrize("pair_len", [64]) -def test_autochunk_search(msa_len, pair_len, max_memory): +def test_simple_evoformer_search(msa_len, pair_len, max_memory): run_func = partial( - _test_autochunk_search, + _test_simple_evoformer_search, msa_len=msa_len, pair_len=pair_len, max_memory=max_memory, @@ -99,4 +105,4 @@ def test_autochunk_search(msa_len, pair_len, max_memory): if __name__ == "__main__": - _test_autochunk_search(0, 32, 64, 20) + _test_simple_evoformer_search(0, 32, 64, 20) diff --git a/tests/test_tensor/common_utils/_utils.py b/tests/test_tensor/common_utils/_utils.py index 
6b58aa801..b405f8cd2 100644 --- a/tests/test_tensor/common_utils/_utils.py +++ b/tests/test_tensor/common_utils/_utils.py @@ -4,6 +4,7 @@ import random import numpy as np import torch import torch.distributed as dist +from torch.testing import assert_close from colossalai.context import ParallelMode from colossalai.core import global_context as gpc @@ -41,14 +42,20 @@ def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0): return tensor_chunk.clone() -def tensor_equal(A, B): - return torch.allclose(A, B, rtol=1e-3, atol=1e-1) +def tensor_equal(t_a: torch.Tensor, t_b: torch.Tensor, rtol: float = 1e-3, atol: float = 1e-1): + assert_close(t_a, t_b, rtol=rtol, atol=atol) + return True -def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size): +def tensor_shard_equal(tensor: torch.Tensor, + shard: torch.Tensor, + rank: int, + world_size: int, + rtol: float = 1e-3, + atol: float = 1e-1): assert tensor.ndim == shard.ndim if tensor.shape == shard.shape: - return tensor_equal(tensor, shard) + return tensor_equal(tensor, shard, rtol, atol) else: dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape)) if dims_not_eq.numel() == 1: @@ -58,7 +65,7 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_si world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) if rank is None: rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - return tensor_equal(tensor.chunk(world_size, dim)[rank], shard) + return tensor_equal(tensor.chunk(world_size, dim)[rank], shard, rtol, atol) else: raise NotImplementedError diff --git a/tests/test_zero/low_level_zero/test_zero_tp.py b/tests/test_zero/low_level_zero/test_zero_tp.py new file mode 100644 index 000000000..8ba6e3cb6 --- /dev/null +++ b/tests/test_zero/low_level_zero/test_zero_tp.py @@ -0,0 +1,98 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.tensor import ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import LowLevelZeroOptimizer +from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal + + +def strict_shard_equal(tensor, shard, tp_pg, rtol=1e-3, atol=1e-4): + return tensor_shard_equal(tensor, shard, tp_pg.tp_local_rank(), tp_pg.tp_world_size(), rtol, atol) + + +class TestModel(nn.Module): + + def __init__(self): + super(TestModel, self).__init__() + self.linear1 = nn.Linear(32, 128) + self.act = nn.GELU() + self.linear2 = nn.Linear(128, 32) + + def forward(self, x): + y = self.linear1(x) + y = self.act(y) + y = self.linear2(y) + return x + y + + +@parameterize("overlap_flag", [False, True]) +@parameterize("partition_flag", [False, True]) +def exam_zero_with_tp(overlap_flag, partition_flag): + set_seed(233010) + tp_pg = ProcessGroup(tp_degree=2) + + with ColoInitContext(device=get_current_device(), default_pg=tp_pg): + hybrid_model = TestModel() + torch_model = TestModel().cuda() + for pt, ph in zip(torch_model.parameters(), hybrid_model.parameters()): + pt.data.copy_(ph.data) + + for name, param in hybrid_model.named_parameters(): + if 'linear1' in name: + split_param_row_tp1d(param, tp_pg) + 
param.compute_spec.set_output_replicate(False) + if 'linear2.weight' in name: + split_param_col_tp1d(param, tp_pg) + + torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group()) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1) + hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1) + hybrid_optim = LowLevelZeroOptimizer(hybrid_optim, + initial_scale=1, + overlap_communication=overlap_flag, + partition_grad=partition_flag) + + dp_local_rank = tp_pg.dp_local_rank() + set_seed(255 + dp_local_rank) + + data = torch.randn(8, 32, device=get_current_device()) + torch_loss = torch_model(data).sum() + hybrid_loss = hybrid_model(data).sum() + assert_close(torch_loss, hybrid_loss) + + torch_loss.backward() + hybrid_optim.backward(hybrid_loss) + hybrid_optim.sync_grad() + + torch_optim.step() + hybrid_optim.step() + + for (name, pt), ph in zip(torch_model.named_parameters(), hybrid_model.parameters()): + assert strict_shard_equal(pt.data, ph.data, tp_pg) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + exam_zero_with_tp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_zero_with_tp(): + world_size = 4 + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_zero_with_tp()
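The updated `tensor_shard_equal` in tests/test_tensor/common_utils/_utils.py, which the new ZeRO+TP test reaches through `strict_shard_equal`, compares a full parameter with its 1D shard by locating the single split dimension, chunking along it, and checking the rank-th chunk with `assert_close` under the given tolerances. A minimal standalone sketch of that comparison (editor's illustration, assuming a row split along dim 0 with world size 2):

```python
# Editor's illustration (not part of the patch): the shard comparison performed by tensor_shard_equal.
import torch
from torch.testing import assert_close

full = torch.randn(8, 32)                     # full (unsharded) parameter
world_size, rank = 2, 1
shard = full.chunk(world_size, dim=0)[rank]   # this rank's row shard

# Locate the one dimension whose size differs, then compare the rank-th chunk.
dims_not_eq = torch.nonzero(torch.tensor(full.shape) != torch.tensor(shard.shape))
assert dims_not_eq.numel() == 1
dim = dims_not_eq.item()
assert_close(full.chunk(world_size, dim)[rank], shard, rtol=1e-3, atol=1e-1)
```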