diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index 11e194fc0..4a4bbef4c 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -4,13 +4,13 @@ from typing import List, Callable, Any, Tuple, Dict try: from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods - codegen_available = True + CODEGEN_AVAILABLE = True except: from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name - codegen_available = False + CODEGEN_AVAILABLE = False -if codegen_available: +if CODEGEN_AVAILABLE: __all__ = ['ActivationCheckpointCodeGen'] else: __all__ = ['python_code_with_activation_checkpoint'] @@ -169,7 +169,7 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unu within_ckpt_region = False -if codegen_available: +if CODEGEN_AVAILABLE: class ActivationCheckpointCodeGen(CodeGen):