import functools
from typing import Callable

import torch

# Major / minor components of the installed PyTorch version, e.g. "2.1.0+cu118" -> (2, 1).
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])

# The meta-compatible releases are 1.12, 1.13 and 2.x. A tuple comparison keeps
# exactly that table for those versions while also giving a defined value for
# every other version (0.x -> False, 3.x and later -> True), instead of leaving
# the flag unbound and failing later with a NameError.
META_COMPATIBILITY = (TORCH_MAJOR, TORCH_MINOR) >= (1, 12)


def compatibility(is_backward_compatible: bool = False) -> Callable:
    """Build a decorator that gates a function on the meta compatibility flag.

    The decorated function is returned unchanged when ``META_COMPATIBILITY`` is
    true, or when ``is_backward_compatible`` is true. Otherwise it is replaced
    by a stub that raises as soon as it is called.

    Args:
        is_backward_compatible (bool, optional): Whether the function is
            backward compatible, i.e. may still be used when
            ``META_COMPATIBILITY`` is false. Defaults to False.

    Returns:
        Callable: A decorator taking the function to gate.

    Raises:
        RuntimeError: Raised by the stub (at call time, not at decoration
            time) when the function is not usable with the installed PyTorch.
    """

    def decorator(func):
        # Guard clause: nothing to gate, hand the original function back.
        if META_COMPATIBILITY or is_backward_compatible:
            return func

        # Keep the original name/docstring on the stub so introspection and
        # error traces still point at the gated function.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            raise RuntimeError(f"Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}")

        return wrapper

    return decorator


def is_compatible_with_meta() -> bool:
    """Check the meta compatibility.

    Normally it should be called before importing some of the `colossalai.fx`
    modules. If the meta compatibility is not satisfied, the `colossalai.fx`
    modules will be replaced by their experimental counterparts.

    Returns:
        bool: The meta compatibility.
    """
    return META_COMPATIBILITY