|
|
|
@ -118,7 +118,7 @@ class ColoAttention:
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if size >= MEMORY_BOUND: |
|
|
|
|
FlashAttentionDaoLoader().load() |
|
|
|
|
flash_kernel = FlashAttentionDaoLoader().load() |
|
|
|
|
# lazy load |
|
|
|
|
if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader): |
|
|
|
|
ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][ |
|
|
|
@ -126,7 +126,7 @@ class ColoAttention:
|
|
|
|
|
].load() |
|
|
|
|
|
|
|
|
|
if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL): |
|
|
|
|
return FlashAttentionDaoLoader() |
|
|
|
|
return flash_kernel |
|
|
|
|
else: |
|
|
|
|
return ColoAttention._kernel_dispatch_map[dtype][mask_type] |
|
|
|
|
|
|
|
|
|