mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
55 lines
1.6 KiB
55 lines
1.6 KiB
from typing import Callable
|
|
|
|
import torch
|
|
|
|
TORCH_MAJOR = int(torch.__version__.split('.')[0])
|
|
TORCH_MINOR = int(torch.__version__.split('.')[1])
|
|
|
|
if TORCH_MAJOR == 1 and TORCH_MINOR < 12:
|
|
META_COMPATIBILITY = False
|
|
elif TORCH_MAJOR == 1 and TORCH_MINOR == 12:
|
|
from . import _meta_regist_12
|
|
META_COMPATIBILITY = True
|
|
elif TORCH_MAJOR == 1 and TORCH_MINOR == 13:
|
|
from . import _meta_regist_13
|
|
META_COMPATIBILITY = True
|
|
elif TORCH_MAJOR == 2:
|
|
META_COMPATIBILITY = True
|
|
|
|
|
|
def compatibility(is_backward_compatible: bool = False) -> Callable:
|
|
"""A decorator to make a function compatible with different versions of PyTorch.
|
|
|
|
Args:
|
|
is_backward_compatible (bool, optional): Whether the function is backward compatible. Defaults to False.
|
|
|
|
Returns:
|
|
Callable: The decorated function
|
|
"""
|
|
|
|
def decorator(func):
|
|
if META_COMPATIBILITY:
|
|
return func
|
|
else:
|
|
if is_backward_compatible:
|
|
return func
|
|
else:
|
|
|
|
def wrapper(*args, **kwargs):
|
|
raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}')
|
|
|
|
return wrapper
|
|
|
|
return decorator
|
|
|
|
|
|
def is_compatible_with_meta() -> bool:
|
|
"""Check the meta compatibility. Normally it should be called before importing some of the `colossalai.fx`
|
|
modules. If the meta compatibility is not satisfied, the `colossalai.fx` modules will be replaced by its
|
|
experimental counterparts.
|
|
|
|
Returns:
|
|
bool: The meta compatibility
|
|
"""
|
|
return META_COMPATIBILITY
|