mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
47 lines
1.3 KiB
47 lines
1.3 KiB
2 years ago
|
from typing import Callable
|
||
|
|
||
|
import torch
|
||
|
|
||
|
# Probe for meta support: importing the sibling `_meta_registrations` module
# either succeeds (flag is True) or fails (flag is False).
try:
    from . import _meta_registrations
    META_COMPATIBILITY = True
except Exception:
    # Deliberately broad: any ordinary failure while importing the module is
    # treated as "not compatible". `Exception` (rather than a bare `except:`)
    # lets KeyboardInterrupt and SystemExit propagate instead of being
    # silently turned into a False flag.
    META_COMPATIBILITY = False
|
||
|
|
||
|
|
||
|
def compatibility(is_backward_compatible: bool = False) -> Callable:
    """Build a decorator that gates a function on meta compatibility.

    If ``META_COMPATIBILITY`` is true, or the function is declared backward
    compatible, the decorated function is returned unchanged. Otherwise it is
    replaced by a stand-in that raises ``RuntimeError`` whenever it is called.

    Args:
        is_backward_compatible (bool, optional): Whether the function is backward compatible. Defaults to False.

    Returns:
        Callable: A decorator that returns either the original function or a
            stand-in that raises ``RuntimeError`` on every call.
    """
    # Imported locally so this function does not depend on a new file-level import.
    import functools

    def decorator(func):
        # The flag is checked once, at decoration time, not on every call.
        if META_COMPATIBILITY or is_backward_compatible:
            return func

        # Preserve the original name/docstring/`__wrapped__` so the stand-in
        # remains introspectable and identifiable in tracebacks.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}')

        return wrapper

    return decorator
|
||
|
|
||
|
|
||
|
def is_compatible_with_meta() -> bool:
    """Report whether meta compatibility is available.

    This is normally called before importing some of the `colossalai.fx`
    modules. When meta compatibility is not satisfied, those modules are
    replaced by their experimental counterparts.

    Returns:
        bool: The value of the module-level ``META_COMPATIBILITY`` flag.
    """
    return META_COMPATIBILITY
|