pull/6061/head
wangbluo 2 months ago
parent dc032172c3
commit 0b14a5512e

@ -118,7 +118,7 @@ class ColoAttention:
)
if size >= MEMORY_BOUND:
FlashAttentionDaoLoader().load()
flash_kernel = FlashAttentionDaoLoader().load()
# lazy load
if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
@ -126,7 +126,7 @@ class ColoAttention:
].load()
if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
return FlashAttentionDaoLoader()
return flash_kernel
else:
return ColoAttention._kernel_dispatch_map[dtype][mask_type]

Loading…
Cancel
Save