adapt new fx

pull/2364/head
oahzxl 2023-01-10 11:56:00 +08:00
parent e532679c95
commit 7ab2db206f
4 changed files with 12 additions and 14 deletions

View File

@ -585,9 +585,9 @@ if CODEGEN_AVAILABLE:
code = "".join(body)
code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f"""
{wrap_stmts}
{wrap_stmts}
{prologue}
{code}"""
{prologue}
{code}"""
# print(fn_code)
return PythonCode(fn_code, globals_)

View File

@ -28,12 +28,7 @@ class EstimateMemory(object):
return x
def _get_output_node(self, n):
fwd_out = {
x.uuid: x
for x in n.meta["fwd_out"]
if isinstance(x, torch.Tensor) and hasattr(x, "uuid")
}
out_size = activation_size(fwd_out)
out_size = activation_size(n.meta["fwd_out"])
out_node = [n.name] if out_size > 0 else []
return out_size, out_node

View File

@ -8,6 +8,7 @@ import torch.multiprocessing as mp
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
@ -15,8 +16,9 @@ from colossalai.fx.profiler import MetaTensor
from colossalai.utils import free_port
from tests.test_autochunk.evoformer.evoformer import evoformer_base
if CODEGEN_AVAILABLE:
if CODEGEN_AVAILABLE and is_compatible_with_meta():
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
@ -102,7 +104,7 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
gpc.destroy()
@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason='torch version is lower than 1.12.0')
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])

View File

@ -7,14 +7,15 @@ import torch.multiprocessing as mp
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor
from colossalai.utils import free_port
from tests.test_autochunk.evoformer.evoformer import evoformer_base
if CODEGEN_AVAILABLE:
if CODEGEN_AVAILABLE and is_compatible_with_meta():
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
@ -89,7 +90,7 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
gpc.destroy()
@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason="torch version is lower than 1.12.0")
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason="torch version is lower than 1.12.0")
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])