mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] meta_tensor_compatibility_with_torch2
parent
15a74da79c
commit
fbd2a9e05b
|
@ -14,9 +14,7 @@ elif TORCH_MAJOR == 1 and TORCH_MINOR == 13:
|
||||||
from . import _meta_regist_13
|
from . import _meta_regist_13
|
||||||
META_COMPATIBILITY = True
|
META_COMPATIBILITY = True
|
||||||
elif TORCH_MAJOR == 2:
|
elif TORCH_MAJOR == 2:
|
||||||
from . import _meta_regist_13
|
|
||||||
META_COMPATIBILITY = True
|
META_COMPATIBILITY = True
|
||||||
raise UserWarning("Colossalai is not tested with torch2.0 yet!!!")
|
|
||||||
|
|
||||||
|
|
||||||
def compatibility(is_backward_compatible: bool = False) -> Callable:
|
def compatibility(is_backward_compatible: bool = False) -> Callable:
|
||||||
|
|
|
@ -223,7 +223,8 @@ def zero_flop_jit(*args):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
|
||||||
|
torch.__version__) < version.parse('2.0.0'):
|
||||||
flop_mapping = {
|
flop_mapping = {
|
||||||
# gemm, gemv and dot
|
# gemm, gemv and dot
|
||||||
aten.mm.default: matmul_flop_jit,
|
aten.mm.default: matmul_flop_jit,
|
||||||
|
|
Loading…
Reference in New Issue