[hotfix] meta_tensor_compatibility_with_torch2

pull/3338/head
YuliangLiu0306 2023-03-30 11:22:20 +08:00 committed by アマデウス
parent 15a74da79c
commit fbd2a9e05b
2 changed files with 2 additions and 3 deletions

View File

@ -14,9 +14,7 @@ elif TORCH_MAJOR == 1 and TORCH_MINOR == 13:
from . import _meta_regist_13
META_COMPATIBILITY = True
elif TORCH_MAJOR == 2:
from . import _meta_regist_13
META_COMPATIBILITY = True
raise UserWarning("Colossalai is not tested with torch2.0 yet!!!")
def compatibility(is_backward_compatible: bool = False) -> Callable:

View File

@ -223,7 +223,8 @@ def zero_flop_jit(*args):
return 0
if version.parse(torch.__version__) >= version.parse('1.12.0'):
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
flop_mapping = {
# gemm, gemv and dot
aten.mm.default: matmul_flop_jit,