from typing import Callable

import torch

# Major/minor fields of the installed PyTorch version, e.g. "1.13.1" parses to (1, 13).
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])

if TORCH_MAJOR == 1 and TORCH_MINOR < 12:
    META_COMPATIBILITY = False
elif TORCH_MAJOR == 1 and TORCH_MINOR == 12:
    META_COMPATIBILITY = True
elif TORCH_MAJOR == 1 and TORCH_MINOR == 13:
    META_COMPATIBILITY = True
elif TORCH_MAJOR == 2:
    META_COMPATIBILITY = True
else:
    # Default to False so that META_COMPATIBILITY is always defined, even on PyTorch
    # versions not matched by the branches above.
    META_COMPATIBILITY = False


def compatibility(is_backward_compatible: bool = False) -> Callable:
    """A decorator that declares a function's compatibility across PyTorch versions.

    On a meta-compatible PyTorch (1.12 or newer) the decorated function is returned
    unchanged. On older versions it is returned unchanged only if it is marked as
    backward compatible; otherwise it is replaced by a wrapper that raises a
    ``RuntimeError`` when called.

    Args:
        is_backward_compatible (bool, optional): Whether the function is backward compatible. Defaults to False.

    Returns:
        Callable: The decorated function.
    """

    def decorator(func):
        if META_COMPATIBILITY:
            return func
        else:
            if is_backward_compatible:
                return func
            else:

                def wrapper(*args, **kwargs):
                    raise RuntimeError(f"Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}")

                return wrapper

    return decorator
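
# Illustrative usage (a sketch, not part of the original file; the function names below are
# hypothetical): functions that depend on meta-tensor features can be guarded so that calling
# them on an unsupported PyTorch raises a clear error up front instead of failing later.
#
#     @compatibility(is_backward_compatible=False)
#     def trace_with_meta_tensors(graph_module):
#         ...  # relies on meta-tensor support available from PyTorch 1.12 onwards
#
#     @compatibility(is_backward_compatible=True)
#     def trace_without_meta_tensors(graph_module):
#         ...  # works on older PyTorch too, so the decorator returns it unchanged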
|
|
|
|
|
def is_compatible_with_meta() -> bool:
    """Check meta compatibility. Normally it should be called before importing some of the
    `colossalai.fx` modules. If meta compatibility is not satisfied, those modules will be
    replaced by their experimental counterparts.

    Returns:
        bool: The meta compatibility.
    """
    return META_COMPATIBILITY
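
# Illustrative call-site sketch (assumed import path; not part of the original file): choose
# between the regular and the experimental `colossalai.fx` modules at import time.
#
#     from colossalai.fx import is_compatible_with_meta
#
#     if is_compatible_with_meta():
#         ...  # import the meta-tensor based implementations
#     else:
#         ...  # fall back to the experimental counterparts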