
fix

pull/6061/head
wangbluo 2 months ago
parent commit 0b14a5512e

colossalai/shardformer/layer/attn.py (4 changed lines: 2 additions, 2 deletions)

@@ -118,7 +118,7 @@ class ColoAttention:
             )
         if size >= MEMORY_BOUND:
-            FlashAttentionDaoLoader().load()
+            flash_kernel = FlashAttentionDaoLoader().load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
@@ -126,7 +126,7 @@ class ColoAttention:
             ].load()
         if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
-            return FlashAttentionDaoLoader()
+            return flash_kernel
         else:
             return ColoAttention._kernel_dispatch_map[dtype][mask_type]
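Note: the change stores the kernel returned by FlashAttentionDaoLoader().load() in flash_kernel and returns it on the large-size causal path; before this commit the loaded result was discarded and a fresh, unloaded FlashAttentionDaoLoader() instance was returned instead of the kernel callable. A minimal sketch of that loader-vs-kernel distinction follows (DummyLoader, dispatch, and the threshold value are hypothetical stand-ins, not ColossalAI APIs):

    # Minimal sketch, assuming a loader object whose load() returns the kernel callable.
    # DummyLoader, dispatch(), and MEMORY_BOUND's value are illustrative, not ColossalAI code.
    MEMORY_BOUND = 10 * 1024**3  # illustrative threshold only

    class DummyLoader:
        def load(self):
            # The loaded kernel is a plain callable, unlike the loader itself.
            def kernel(q, k, v):
                return q + k + v  # stand-in for the real attention computation
            return kernel

    def dispatch(size: int):
        if size >= MEMORY_BOUND:
            # Keep the result of load(); returning DummyLoader() here would hand
            # callers an unloaded loader object instead of a usable kernel.
            flash_kernel = DummyLoader().load()
            return flash_kernel
        raise NotImplementedError("small-size path omitted in this sketch")

    kernel = dispatch(10 * 1024**3)
    print(kernel(1, 2, 3))  # 6 -> a callable kernel, not a loader object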
