diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py
index 14f17b1d3..e8af9bde8 100644
--- a/colossalai/autochunk/autochunk_codegen.py
+++ b/colossalai/autochunk/autochunk_codegen.py
@@ -585,9 +585,9 @@ if CODEGEN_AVAILABLE:
             code = "".join(body)
             code = "\n".join("    " + line for line in code.split("\n"))
             fn_code = f"""
-    {wrap_stmts}
+{wrap_stmts}
 
-    {prologue}
-    {code}"""
+{prologue}
+{code}"""
             # print(fn_code)
             return PythonCode(fn_code, globals_)
diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py
index 62b23cf9f..e001423f1 100644
--- a/colossalai/autochunk/estimate_memory.py
+++ b/colossalai/autochunk/estimate_memory.py
@@ -28,12 +28,7 @@ class EstimateMemory(object):
         return x
 
     def _get_output_node(self, n):
-        fwd_out = {
-            x.uuid: x
-            for x in n.meta["fwd_out"]
-            if isinstance(x, torch.Tensor) and hasattr(x, "uuid")
-        }
-        out_size = activation_size(fwd_out)
+        out_size = activation_size(n.meta["fwd_out"])
         out_node = [n.name] if out_size > 0 else []
         return out_size, out_node
 
diff --git a/tests/test_autochunk/test_autochunk_codegen.py b/tests/test_autochunk/test_autochunk_codegen.py
index 28999706b..fe1916884 100644
--- a/tests/test_autochunk/test_autochunk_codegen.py
+++ b/tests/test_autochunk/test_autochunk_codegen.py
@@ -8,6 +8,7 @@ import torch.multiprocessing as mp
 import colossalai
 from colossalai.core import global_context as gpc
 from colossalai.fx import ColoTracer
+from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
@@ -15,8 +16,9 @@ from colossalai.fx.profiler import MetaTensor
 from colossalai.utils import free_port
 from tests.test_autochunk.evoformer.evoformer import evoformer_base
 
-if CODEGEN_AVAILABLE:
+if CODEGEN_AVAILABLE and is_compatible_with_meta():
     from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+    from colossalai.fx.profiler import MetaTensor
 
 
 def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
@@ -102,7 +104,7 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
     gpc.destroy()
 
 
-@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason='torch version is lower than 1.12.0')
 @pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
 @pytest.mark.parametrize("msa_len", [32])
 @pytest.mark.parametrize("pair_len", [64])
diff --git a/tests/test_autochunk/test_autochunk_search.py b/tests/test_autochunk/test_autochunk_search.py
index eb2bf4560..537bf4f41 100644
--- a/tests/test_autochunk/test_autochunk_search.py
+++ b/tests/test_autochunk/test_autochunk_search.py
@@ -7,14 +7,15 @@ import torch.multiprocessing as mp
 
 import colossalai
 from colossalai.core import global_context as gpc
+from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.profiler import MetaTensor
 from colossalai.utils import free_port
 from tests.test_autochunk.evoformer.evoformer import evoformer_base
 
-if CODEGEN_AVAILABLE:
+if CODEGEN_AVAILABLE and is_compatible_with_meta():
     from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+    from colossalai.fx.profiler import MetaTensor
 
 
 def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
@@ -89,7 +90,7 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
     gpc.destroy()
 
 
-@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason="torch version is lower than 1.12.0")
+@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason="torch version is lower than 1.12.0")
 @pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
 @pytest.mark.parametrize("msa_len", [32])
 @pytest.mark.parametrize("pair_len", [64])
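
Note on the autochunk_codegen.py hunk: the generated function body is assembled into an f-string, and any indentation written literally inside that template is emitted verbatim into the generated source, which then fails to compile. Below is a minimal, self-contained sketch of that failure mode; the values assigned to `prologue` and `body` here are placeholders for illustration, not the real variables produced by `_gen_python_code`.

# Minimal sketch (hypothetical values) of why the template indent had to go.
prologue = "def forward(self, x):"
body = "    return x + 1"    # codegen already prefixes body lines with 4 spaces

# Before the fix: the template's own 4-space indent leaks into the output.
broken = f"""
    {prologue}
    {body}"""

# After the fix: placeholders start at column 0, so only the indentation
# carried inside the values themselves survives.
fixed = f"""
{prologue}
{body}"""

compile(fixed, "<generated>", "exec")       # compiles cleanly
try:
    compile(broken, "<generated>", "exec")  # IndentationError: unexpected indent
except IndentationError as err:
    print(err)

On the test changes: both files use the same guard, `CODEGEN_AVAILABLE and is_compatible_with_meta()`, for the conditional imports and for the `pytest.mark.skipif` markers, so on an incompatible torch version the tests neither import `AutoChunkCodeGen`/`MetaTensor` nor run.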