mirror of https://github.com/hpcaitech/ColossalAI
adapt new fx
parent e532679c95
commit 7ab2db206f

@@ -585,9 +585,9 @@ if CODEGEN_AVAILABLE:
             code = "".join(body)
             code = "\n".join("    " + line for line in code.split("\n"))
             fn_code = f"""
 {wrap_stmts}

 {prologue}
 {code}"""
             # print(fn_code)
             return PythonCode(fn_code, globals_)

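Context for this hunk: the codegen joins the emitted statement list into one body, indents it, and splices it between {wrap_stmts} and {prologue} before wrapping it in PythonCode. A minimal sketch of the same mechanism using only stock torch.fx — TinyModel is a hypothetical module used purely for illustration:

    import torch
    import torch.fx

    class TinyModel(torch.nn.Module):    # hypothetical example model
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return torch.relu(self.linear(x))

    gm = torch.fx.symbolic_trace(TinyModel())
    # gm.code is the generated Python source, i.e. the role fn_code plays
    # above: a prologue ("def forward(self, x):") plus the indented body.
    print(gm.code)
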
@@ -28,12 +28,7 @@ class EstimateMemory(object):
         return x

     def _get_output_node(self, n):
-        fwd_out = {
-            x.uuid: x
-            for x in n.meta["fwd_out"]
-            if isinstance(x, torch.Tensor) and hasattr(x, "uuid")
-        }
-        out_size = activation_size(fwd_out)
+        out_size = activation_size(n.meta["fwd_out"])
         out_node = [n.name] if out_size > 0 else []
         return out_size, out_node

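activation_size itself is not shown in this diff. A rough sketch of what a byte count over a (possibly nested) fwd_out structure plausibly looks like — an assumption for illustration, not ColossalAI's exact implementation:

    from typing import Any
    import torch

    def activation_size_sketch(out: Any) -> int:
        # assumed behavior: sum byte sizes of all tensors in the structure
        if isinstance(out, torch.Tensor):
            return out.numel() * out.element_size()
        if isinstance(out, dict):
            return sum(activation_size_sketch(v) for v in out.values())
        if isinstance(out, (list, tuple)):
            return sum(activation_size_sketch(v) for v in out)
        return 0

    # e.g. a float32 tensor of shape (32, 64) counts 32 * 64 * 4 = 8192 bytes
    assert activation_size_sketch(torch.zeros(32, 64)) == 8192
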
@@ -8,6 +8,7 @@ import torch.multiprocessing as mp
 import colossalai
 from colossalai.core import global_context as gpc
 from colossalai.fx import ColoTracer
+from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp

@@ -15,8 +16,9 @@ from colossalai.fx.profiler import MetaTensor
 from colossalai.utils import free_port
 from tests.test_autochunk.evoformer.evoformer import evoformer_base

-if CODEGEN_AVAILABLE:
+if CODEGEN_AVAILABLE and is_compatible_with_meta():
     from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+    from colossalai.fx.profiler import MetaTensor


 def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):

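Both CODEGEN_AVAILABLE and is_compatible_with_meta() gate features on the installed torch; their internals are not part of this diff. A minimal sketch of that kind of version gate, assuming a plain comparison against torch.__version__:

    from packaging import version
    import torch

    def meets_min_torch(min_version: str = "1.12.0") -> bool:
        # assumed pattern, not ColossalAI's exact check:
        # strip any local build suffix such as "+cu117" before comparing
        bare = torch.__version__.split("+")[0]
        return version.parse(bare) >= version.parse(min_version)

    CODEGEN_AVAILABLE_SKETCH = meets_min_torch("1.12.0")
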
@@ -102,7 +104,7 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
     gpc.destroy()


-@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason='torch version is lower than 1.12.0')
 @pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
 @pytest.mark.parametrize("msa_len", [32])
 @pytest.mark.parametrize("pair_len", [64])

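How the decorator stack above composes: parametrize multiplies the collected test instances, and skipif is applied to every instance. A self-contained miniature, runnable with pytest directly — the AVAILABLE flag below stands in for the real CODEGEN_AVAILABLE and is_compatible_with_meta() gate:

    import pytest

    AVAILABLE = True    # stand-in for the torch-version gate

    @pytest.mark.skipif(not AVAILABLE, reason='torch version is lower than 1.12.0')
    @pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
    @pytest.mark.parametrize("msa_len", [32])
    @pytest.mark.parametrize("pair_len", [64])
    def test_combination(max_memory, msa_len, pair_len):
        # 4 x 1 x 1 = 4 test instances are generated from this one function
        assert msa_len == 32 and pair_len == 64
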
@@ -7,14 +7,15 @@ import torch.multiprocessing as mp

 import colossalai
 from colossalai.core import global_context as gpc
+from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.profiler import MetaTensor
 from colossalai.utils import free_port
 from tests.test_autochunk.evoformer.evoformer import evoformer_base

-if CODEGEN_AVAILABLE:
+if CODEGEN_AVAILABLE and is_compatible_with_meta():
     from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+    from colossalai.fx.profiler import MetaTensor


 def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):

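The MetaTensor import moves under the guard because meta-tensor support only exists on meta-compatible torch builds. The idea, shown with plain torch (>= 1.12): a meta-device tensor carries full shape and dtype metadata but allocates no storage, so memory needs can be estimated without touching real memory.

    import torch

    # no real allocation happens, even for a 4 GiB tensor
    big = torch.empty(1024, 1024, 1024, device="meta")
    print(big.shape, big.device)               # torch.Size([1024, 1024, 1024]) meta
    print(big.numel() * big.element_size())    # 4294967296 bytes it *would* occupy
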
@@ -89,7 +90,7 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
     gpc.destroy()


-@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason="torch version is lower than 1.12.0")
+@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason="torch version is lower than 1.12.0")
 @pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
 @pytest.mark.parametrize("msa_len", [32])
 @pytest.mark.parametrize("pair_len", [64])

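The body of _test_autochunk_search is elided in this hunk. A reconstructed sketch of the spawn-and-launch pattern ColossalAI tests of this era typically use around gpc.destroy() — everything except the imports is hypothetical:

    import torch.multiprocessing as mp

    import colossalai
    from colossalai.core import global_context as gpc
    from colossalai.utils import free_port

    def _run(rank, world_size, port):
        # each spawned worker joins the process group, runs its checks,
        # then tears the group down, as in the gpc.destroy() line above
        colossalai.launch(config={}, rank=rank, world_size=world_size,
                          host="localhost", port=port, backend="nccl")
        gpc.destroy()

    def test_entry():
        mp.spawn(_run, args=(1, free_port()), nprocs=1)
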