diff --git a/colossalai/fx/_compatibility.py b/colossalai/fx/_compatibility.py index 6caad920d..0444a4816 100644 --- a/colossalai/fx/_compatibility.py +++ b/colossalai/fx/_compatibility.py @@ -14,9 +14,7 @@ elif TORCH_MAJOR == 1 and TORCH_MINOR == 13: from . import _meta_regist_13 META_COMPATIBILITY = True elif TORCH_MAJOR == 2: - from . import _meta_regist_13 META_COMPATIBILITY = True - raise UserWarning("Colossalai is not tested with torch2.0 yet!!!") def compatibility(is_backward_compatible: bool = False) -> Callable: diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py index 407a6bed5..ba090a2ec 100644 --- a/colossalai/fx/profiler/opcount.py +++ b/colossalai/fx/profiler/opcount.py @@ -223,7 +223,8 @@ def zero_flop_jit(*args): return 0 -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse( + torch.__version__) < version.parse('2.0.0'): flop_mapping = { # gemm, gemv and dot aten.mm.default: matmul_flop_jit,