|
|
|
@ -118,7 +118,7 @@ class ColoAttention:
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if size >= MEMORY_BOUND:
|
|
|
|
|
FlashAttentionDaoLoader().load()
|
|
|
|
|
flash_kernel = FlashAttentionDaoLoader().load()
|
|
|
|
|
# lazy load
|
|
|
|
|
if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
|
|
|
|
|
ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
|
|
|
|
@ -126,7 +126,7 @@ class ColoAttention:
|
|
|
|
|
].load()
|
|
|
|
|
|
|
|
|
|
if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
|
|
|
|
|
return FlashAttentionDaoLoader()
|
|
|
|
|
return flash_kernel
|
|
|
|
|
else:
|
|
|
|
|
return ColoAttention._kernel_dispatch_map[dtype][mask_type]
|
|
|
|
|
|
|
|
|
|