pull/6061/head
wangbluo 2024-09-13 09:01:26 +00:00
parent 0b14a5512e
commit 0ad3129cb9
1 changed file with 5 additions and 2 deletions

View File

@ -80,6 +80,7 @@ def get_pad_info(
class ColoAttention: class ColoAttention:
_kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
_flash_kernel_dispatch: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
@staticmethod @staticmethod
def _init_kernels_dispatch(): def _init_kernels_dispatch():
@ -105,6 +106,8 @@ class ColoAttention:
torch.bfloat16: half_dispatch_map, torch.bfloat16: half_dispatch_map,
torch.float32: float_dispatch_map, torch.float32: float_dispatch_map,
} }
if ColoAttention._flash_kernel_dispatch is None:
ColoAttention._flash_kernel_dispatch = FlashAttentionDaoLoader()
@staticmethod @staticmethod
def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable: def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable:
@ -118,7 +121,7 @@ class ColoAttention:
) )
if size >= MEMORY_BOUND: if size >= MEMORY_BOUND:
flash_kernel = FlashAttentionDaoLoader().load() ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
# lazy load # lazy load
if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader): if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][ ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
@ -126,7 +129,7 @@ class ColoAttention:
].load() ].load()
if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL): if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
return flash_kernel return ColoAttention._flash_kernel_dispatch
else: else:
return ColoAttention._kernel_dispatch_map[dtype][mask_type] return ColoAttention._kernel_dispatch_map[dtype][mask_type]