import warnings
from typing import List

from .extensions import (
    CpuAdamArmExtension,
    CpuAdamX86Extension,
    FlashAttentionDaoCudaExtension,
    FlashAttentionNpuExtension,
    FlashAttentionSdpaCudaExtension,
    FusedOptimizerCudaExtension,
    InferenceOpsCudaExtension,
    LayerNormCudaExtension,
    MoeCudaExtension,
    ScaledMaskedSoftmaxCudaExtension,
    ScaledUpperTriangleMaskedSoftmaxCudaExtension,
)
from .extensions.base_extension import _Extension

__all__ = [
    "KernelLoader",
    "CPUAdamLoader",
    "LayerNormLoader",
    "MoeLoader",
    "FusedOptimizerLoader",
    "InferenceOpsLoader",
    "ScaledMaskedSoftmaxLoader",
    "ScaledUpperTriangleMaskedSoftmaxLoader",
    "FlashAttentionLoader",
    "FlashAttentionDaoLoader",
    "FlashAttentionWithCustomMaskLoader",
    "FlashAttentionForFloatAndCustomMaskLoader",
]


class KernelLoader:
    """
    An abstract class which encapsulates the kernel loading process.

    Usage:
        kernel_loader = KernelLoader()
        kernel = kernel_loader.load()
    """

    REGISTRY: List[_Extension] = []

    @classmethod
    def register_extension(cls, extension: _Extension):
        """
        This classmethod is an extension point that allows users to register their customized
        kernel implementations with the loader.

        Args:
            extension (_Extension): the extension to be registered.
        """
        cls.REGISTRY.append(extension)

    def load(self, ext_name: str = None):
        """
        Load the kernel according to the current machine.

        Args:
            ext_name (str): the name of the extension to be loaded. If not specified, the loader
                will try to look for a kernel available on the current machine.
        """
        exts = [ext_cls() for ext_cls in self.__class__.REGISTRY]

        # look for exts which can be built/loaded on the current machine
        if ext_name:
            usable_exts = list(filter(lambda ext: ext.name == ext_name, exts))
        else:
            usable_exts = []
            for ext in exts:
                if ext.is_available():
                    # make sure the machine is compatible during kernel loading
                    ext.assert_compatible()
                    usable_exts.append(ext)

        assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."

        if len(usable_exts) > 1:
            # if more than one usable kernel is found, load the kernel with the highest priority
            usable_exts = sorted(usable_exts, key=lambda ext: ext.priority, reverse=True)
            warnings.warn(
                f"More than one kernel is available, loading the kernel with the highest priority - {usable_exts[0].__class__.__name__}"
            )
        return usable_exts[0].load()
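
# Usage sketch (comments only, nothing here runs at import time). Concrete loaders
# override REGISTRY with a list of extension *classes*; `load()` instantiates them,
# keeps the ones usable on the current machine, and returns the built kernel of the
# highest-priority candidate. `MyCpuAdamExtension` is a hypothetical _Extension
# subclass, shown only to illustrate the `register_extension` hook:
#
#   CPUAdamLoader.register_extension(MyCpuAdamExtension)  # CPUAdamLoader is defined below
#   kernel = CPUAdamLoader().load()                                     # auto-select
#   kernel = CPUAdamLoader().load(ext_name=CpuAdamX86Extension().name)  # or pick by name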


class CPUAdamLoader(KernelLoader):
    REGISTRY = [CpuAdamX86Extension, CpuAdamArmExtension]
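
# The concrete loaders in this module only override REGISTRY with the extension classes
# that can back them. A loader for a new kernel would follow the same pattern; the names
# below are hypothetical and shown only as a sketch:
#
#   class MyKernelLoader(KernelLoader):
#       REGISTRY = [MyKernelCudaExtension, MyKernelNpuExtension]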


class LayerNormLoader(KernelLoader):
    REGISTRY = [LayerNormCudaExtension]


class MoeLoader(KernelLoader):
    REGISTRY = [MoeCudaExtension]


class FusedOptimizerLoader(KernelLoader):
    REGISTRY = [FusedOptimizerCudaExtension]


class InferenceOpsLoader(KernelLoader):
    REGISTRY = [InferenceOpsCudaExtension]


class ScaledMaskedSoftmaxLoader(KernelLoader):
    REGISTRY = [ScaledMaskedSoftmaxCudaExtension]


class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
    REGISTRY = [ScaledUpperTriangleMaskedSoftmaxCudaExtension]


class FlashAttentionLoader(KernelLoader):
    REGISTRY = [
        FlashAttentionNpuExtension,
        FlashAttentionDaoCudaExtension,
        FlashAttentionSdpaCudaExtension,
    ]
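
# Illustrative sketch: FlashAttentionLoader can resolve to any of its registered backends
# (NPU, Dao CUDA, or SDPA CUDA), whichever is usable on the current machine:
#
#   flash_attn_kernel = FlashAttentionLoader().load()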


class FlashAttentionDaoLoader(KernelLoader):
    REGISTRY = [FlashAttentionDaoCudaExtension]


class FlashAttentionWithCustomMaskLoader(KernelLoader):
    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]


class FlashAttentionForFloatAndCustomMaskLoader(KernelLoader):
    REGISTRY = [FlashAttentionSdpaCudaExtension]