mirror of https://github.com/hpcaitech/ColossalAI
fix
parent
dc032172c3
commit
0b14a5512e
|
@ -118,7 +118,7 @@ class ColoAttention:
|
||||||
)
|
)
|
||||||
|
|
||||||
if size >= MEMORY_BOUND:
|
if size >= MEMORY_BOUND:
|
||||||
FlashAttentionDaoLoader().load()
|
flash_kernel = FlashAttentionDaoLoader().load()
|
||||||
# lazy load
|
# lazy load
|
||||||
if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
|
if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
|
||||||
ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
|
ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
|
||||||
|
@ -126,7 +126,7 @@ class ColoAttention:
|
||||||
].load()
|
].load()
|
||||||
|
|
||||||
if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
|
if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
|
||||||
return FlashAttentionDaoLoader()
|
return flash_kernel
|
||||||
else:
|
else:
|
||||||
return ColoAttention._kernel_dispatch_map[dtype][mask_type]
|
return ColoAttention._kernel_dispatch_map[dtype][mask_type]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue