pull/6061/head
wangbluo 2024-09-13 09:01:26 +00:00
parent 0b14a5512e
commit 0ad3129cb9
1 changed file with 5 additions and 2 deletions

@@ -80,6 +80,7 @@ def get_pad_info(
 class ColoAttention:
     _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
+    _flash_kernel_dispatch: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
 
     @staticmethod
     def _init_kernels_dispatch():
@@ -105,6 +106,8 @@ class ColoAttention:
                 torch.bfloat16: half_dispatch_map,
                 torch.float32: float_dispatch_map,
             }
+        if ColoAttention._flash_kernel_dispatch is None:
+            ColoAttention._flash_kernel_dispatch = FlashAttentionDaoLoader()
 
     @staticmethod
     def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable:
@@ -118,7 +121,7 @@ class ColoAttention:
             )
         if size >= MEMORY_BOUND:
-            flash_kernel = FlashAttentionDaoLoader().load()
+            ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
@@ -126,7 +129,7 @@ class ColoAttention:
             ].load()
         if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
-            return flash_kernel
+            return ColoAttention._flash_kernel_dispatch
         else:
             return ColoAttention._kernel_dispatch_map[dtype][mask_type]
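
The hunks above replace a per-call FlashAttentionDaoLoader().load() with a loader that is created once in _init_kernels_dispatch(), cached on the class as _flash_kernel_dispatch, and loaded lazily the first time a large input (size >= MEMORY_BOUND) is dispatched. The standalone sketch below illustrates that cache-and-lazy-load pattern under simplified assumptions: DummyLoader, Dispatcher, and the MEMORY_BOUND value are placeholders invented for the example, not ColossalAI's API, and the sketch adds an isinstance guard so repeated calls reuse the already-loaded kernel.

    from typing import Callable, Optional

    MEMORY_BOUND = 10 * 1024**3  # hypothetical threshold; the real constant is defined elsewhere in ColossalAI


    class DummyLoader:
        """Stand-in for a KernelLoader such as FlashAttentionDaoLoader (illustrative only)."""

        def load(self) -> Callable:
            # The real loader imports/initializes the flash-attention kernel;
            # here we return a trivial callable so the example runs anywhere.
            return lambda *args, **kwargs: "flash-attention kernel"


    class Dispatcher:
        # Loader (or, after the first load, the loaded kernel) cached on the class,
        # mirroring the role of ColoAttention._flash_kernel_dispatch in the diff.
        _flash_kernel_dispatch: Optional[object] = None

        @staticmethod
        def _init_kernels_dispatch() -> None:
            # Create the loader once; defer the expensive load() until it is actually needed.
            if Dispatcher._flash_kernel_dispatch is None:
                Dispatcher._flash_kernel_dispatch = DummyLoader()

        @staticmethod
        def dispatch(size: int) -> Callable:
            Dispatcher._init_kernels_dispatch()
            if size >= MEMORY_BOUND:
                # Lazy load: swap the cached loader for the loaded kernel on first use,
                # then keep returning the same kernel object on later calls.
                if isinstance(Dispatcher._flash_kernel_dispatch, DummyLoader):
                    Dispatcher._flash_kernel_dispatch = Dispatcher._flash_kernel_dispatch.load()
                return Dispatcher._flash_kernel_dispatch
            return lambda *args, **kwargs: "default kernel"


    if __name__ == "__main__":
        big_input = 16 * 1024**3
        kernel = Dispatcher.dispatch(big_input)
        # The second call hits the cache and returns the very same kernel object.
        assert kernel is Dispatcher.dispatch(big_input)

Compared with the removed flash_kernel = FlashAttentionDaoLoader().load(), the cached class attribute avoids rebuilding and reloading the kernel on every dispatch call for large inputs.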