from .base_extension import _Extension __all__ = ["_TritonExtension"] class _TritonExtension(_Extension): def __init__(self, name: str, priority: int = 1): super().__init__(name, support_aot=False, support_jit=True, priority=priority) def is_hardware_compatible(self) -> bool: # cuda extension can only be built if cuda is available try: import torch cuda_available = torch.cuda.is_available() except: cuda_available = False return cuda_available def load(self): return self.build_jit()