import warnings
from typing import List, Optional, Type

from .extensions import (
    CpuAdamArmExtension,
    CpuAdamX86Extension,
    FlashAttentionDaoCudaExtension,
    FlashAttentionNpuExtension,
    FlashAttentionSdpaCudaExtension,
    FusedOptimizerCudaExtension,
    InferenceOpsCudaExtension,
    LayerNormCudaExtension,
    MoeCudaExtension,
    ScaledMaskedSoftmaxCudaExtension,
    ScaledUpperTriangleMaskedSoftmaxCudaExtension,
)
from .extensions.base_extension import _Extension

__all__ = [
    "KernelLoader",
    "CPUAdamLoader",
    "LayerNormLoader",
    "MoeLoader",
    "FusedOptimizerLoader",
    "InferenceOpsLoader",
    "ScaledMaskedSoftmaxLoader",
    "ScaledUpperTriangleMaskedSoftmaxLoader",
    # The FlashAttention loaders below are public classes too; they were
    # previously missing here and therefore dropped by `import *`.
    "FlashAttentionLoader",
    "FlashAttentionWithCustomMaskLoader",
    "FlashAttentionForFloatAndCustomMaskLoader",
]


class KernelLoader:
    """
    An abstract class which offers encapsulation to the kernel loading process.

    Concrete loaders subclass it and set ``REGISTRY`` to the extension classes
    that can provide their kernel. The base class's own ``REGISTRY`` is empty,
    so it is not meant to be used directly.

    Usage:
        kernel_loader = LayerNormLoader()
        kernel = kernel_loader.load()
    """

    # Extension *classes* (not instances); each one is instantiated with no
    # arguments on every call to `load()`.
    REGISTRY: List[Type[_Extension]] = []

    @classmethod
    def register_extension(cls, extension: Type[_Extension]):
        """
        This classmethod is an extension point which allows users to register their customized
        kernel implementations to the loader.

        The class is appended to the ``REGISTRY`` of the loader class this
        method is called on (e.g. ``LayerNormLoader.register_extension(...)``
        only affects ``LayerNormLoader``).

        Args:
            extension (Type[_Extension]): the extension class to be registered.
        """
        cls.REGISTRY.append(extension)

    def load(self, ext_name: Optional[str] = None):
        """
        Load the kernel according to the current machine.

        Args:
            ext_name (Optional[str]): the name of the extension to be loaded. If given, the
                registered extension whose ``name`` matches is selected directly, without
                checking ``is_available()``. If not specified, the loader will look for
                a kernel available on the current machine.

        Returns:
            Whatever the selected extension's ``load()`` returns.

        Raises:
            AssertionError: if no registered extension matches ``ext_name``, or, when
                ``ext_name`` is not given, if no registered extension is available.
        """
        loader_name = self.__class__.__name__
        exts = [ext_cls() for ext_cls in self.__class__.REGISTRY]

        # look for exts which can be built/loaded on the current machine
        if ext_name:
            usable_exts = [ext for ext in exts if ext.name == ext_name]
            if not usable_exts:
                # Explicit raise (not `assert`) so the check survives `python -O`;
                # AssertionError is kept for callers that already catch it.
                raise AssertionError(
                    f"No kernel named '{ext_name}' found for {loader_name}; "
                    f"registered kernels: {[ext.name for ext in exts]}"
                )
        else:
            usable_exts = []
            for ext in exts:
                if ext.is_available():
                    # make sure the machine is compatible during kernel loading
                    ext.assert_compatible()
                    usable_exts.append(ext)
            if not usable_exts:
                # Explicit raise (not `assert`) so the check survives `python -O`,
                # instead of failing later with an IndexError on `usable_exts[0]`.
                raise AssertionError(f"No usable kernel found for {loader_name} on the current machine.")

        if len(usable_exts) > 1:
            # if more than one usable kernel is found, we will try to load the kernel with the highest priority
            # (sorted() is stable, so ties keep REGISTRY order)
            usable_exts = sorted(usable_exts, key=lambda ext: ext.priority, reverse=True)
            warnings.warn(
                f"More than one kernel is available, loading the kernel with the highest priority - {usable_exts[0].__class__.__name__}",
                stacklevel=2,
            )
        return usable_exts[0].load()


class CPUAdamLoader(KernelLoader):
    """Loader for the CPU Adam kernel (x86 and ARM extension candidates)."""

    REGISTRY = [CpuAdamX86Extension, CpuAdamArmExtension]


class LayerNormLoader(KernelLoader):
    """Loader for the LayerNorm kernel (CUDA extension)."""

    REGISTRY = [LayerNormCudaExtension]


class MoeLoader(KernelLoader):
    """Loader for the MoE kernel (CUDA extension)."""

    REGISTRY = [MoeCudaExtension]


class FusedOptimizerLoader(KernelLoader):
    """Loader for the fused optimizer kernel (CUDA extension)."""

    REGISTRY = [FusedOptimizerCudaExtension]


class InferenceOpsLoader(KernelLoader):
    """Loader for the inference ops kernel (CUDA extension)."""

    REGISTRY = [InferenceOpsCudaExtension]


class ScaledMaskedSoftmaxLoader(KernelLoader):
    """Loader for the scaled masked softmax kernel (CUDA extension)."""

    REGISTRY = [ScaledMaskedSoftmaxCudaExtension]


class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
    """Loader for the scaled upper-triangle masked softmax kernel (CUDA extension)."""

    REGISTRY = [ScaledUpperTriangleMaskedSoftmaxCudaExtension]


class FlashAttentionLoader(KernelLoader):
    """Loader for flash attention (NPU, Dao CUDA and SDPA CUDA extension candidates)."""

    REGISTRY = [
        FlashAttentionNpuExtension,
        FlashAttentionDaoCudaExtension,
        FlashAttentionSdpaCudaExtension,
    ]


class FlashAttentionWithCustomMaskLoader(KernelLoader):
    """Loader for flash attention with a custom mask (NPU and SDPA CUDA extension candidates)."""

    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]


class FlashAttentionForFloatAndCustomMaskLoader(KernelLoader):
    """Loader for flash attention with float inputs and a custom mask (SDPA CUDA extension only)."""

    REGISTRY = [FlashAttentionSdpaCudaExtension]