mirror of https://github.com/hpcaitech/ColossalAI
ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
21 lines
589 B
21 lines
589 B
from .base_extension import _Extension |
|
|
|
__all__ = ["_TritonExtension"] |
|
|
|
|
|
class _TritonExtension(_Extension):
    """Extension base class that is registered as JIT-only.

    The constructor always passes ``support_aot=False`` and
    ``support_jit=True`` to ``_Extension``, so ``load`` goes straight to
    the JIT build path.
    """

    def __init__(self, name: str, priority: int = 1):
        """Initialise the extension.

        Args:
            name: Extension name, forwarded unchanged to ``_Extension``.
            priority: Forwarded unchanged to ``_Extension``; defaults to 1.
        """
        super().__init__(name, support_aot=False, support_jit=True, priority=priority)

    def is_hardware_compatible(self) -> bool:
        """Report whether CUDA is usable through PyTorch.

        Returns:
            The result of ``torch.cuda.is_available()``, or ``False`` if
            importing ``torch`` or querying CUDA raises an exception.
        """
        # The extension is only considered usable when CUDA is available.
        # torch is imported lazily so that a missing or broken install is
        # reported as "not compatible" instead of failing at import time.
        try:
            import torch

            cuda_available = torch.cuda.is_available()
        except Exception:
            # Deliberately best-effort: any ordinary failure means "no CUDA".
            # ``Exception`` (rather than a bare ``except``) lets
            # KeyboardInterrupt and SystemExit propagate.
            cuda_available = False
        return cuda_available

    def load(self):
        """Build the extension via ``self.build_jit()`` and return its result."""
        return self.build_jit()
|
|
|