ColossalAI/extensions/triton_extension.py

22 lines
589 B
Python

from .base_extension import _Extension
__all__ = ["_TritonExtension"]
class _TritonExtension(_Extension):
def __init__(self, name: str, priority: int = 1):
super().__init__(name, support_aot=False, support_jit=True, priority=priority)
def is_hardware_compatible(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
cuda_available = torch.cuda.is_available()
except:
cuda_available = False
return cuda_available
def load(self):
return self.build_jit()