diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index c755ffa2f..129b04fb7 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -80,6 +80,7 @@ def get_pad_info(

 class ColoAttention:
     _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
+    _flash_kernel_dispatch: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None

     @staticmethod
     def _init_kernels_dispatch():
@@ -105,6 +106,8 @@ class ColoAttention:
                 torch.bfloat16: half_dispatch_map,
                 torch.float32: float_dispatch_map,
             }
+        if ColoAttention._flash_kernel_dispatch is None:
+            ColoAttention._flash_kernel_dispatch = FlashAttentionDaoLoader()

     @staticmethod
     def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable:
@@ -118,7 +121,7 @@ class ColoAttention:
             )

         if size >= MEMORY_BOUND:
-            flash_kernel = FlashAttentionDaoLoader().load()
+            ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
@@ -126,7 +129,7 @@ class ColoAttention:
             ].load()

         if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
-            return flash_kernel
+            return ColoAttention._flash_kernel_dispatch
         else:
             return ColoAttention._kernel_dispatch_map[dtype][mask_type]
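
The diff caches a single FlashAttentionDaoLoader on the class and, on first dispatch of a large attention call, replaces that loader with the loaded kernel, instead of constructing and loading a fresh loader on every call to _dispatch_kernel. The snippet below is a minimal, self-contained sketch of that lazy-load caching pattern in isolation; KernelLoader, Dispatcher, dispatch, and use_flash here are illustrative stand-ins, not the ColossalAI API, and the isinstance guard before load() is a design choice of the sketch so repeated calls reuse the cached callable.

# Sketch of the lazy-load caching pattern applied in this diff.
# All names below are hypothetical stand-ins, not ColossalAI identifiers.
from typing import Callable, Optional, Union


class KernelLoader:
    """Stand-in for a loader whose load() is expensive (builds/imports a kernel)."""

    def load(self) -> Callable:
        print("loading flash kernel (expensive, should happen once)")
        return lambda q, k, v: "flash-attn output"


class Dispatcher:
    # Class-level cache: starts as an unloaded loader, is swapped for the
    # loaded kernel on first use (mirrors caching the loader on the class).
    _flash_kernel: Optional[Union[KernelLoader, Callable]] = None

    @classmethod
    def _init(cls) -> None:
        if cls._flash_kernel is None:
            cls._flash_kernel = KernelLoader()  # construct once, defer loading

    @classmethod
    def dispatch(cls, use_flash: bool) -> Callable:
        cls._init()
        if use_flash and isinstance(cls._flash_kernel, KernelLoader):
            # Lazy load: replace the loader with the loaded kernel so later
            # dispatches return the cached callable without reloading.
            cls._flash_kernel = cls._flash_kernel.load()
        if use_flash:
            return cls._flash_kernel
        return lambda q, k, v: "fallback output"


if __name__ == "__main__":
    Dispatcher.dispatch(use_flash=True)  # prints the loading message once
    Dispatcher.dispatch(use_flash=True)  # cached; no second load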