mirror of https://github.com/hpcaitech/ColossalAI
fix

parent 0b14a5512e
commit 0ad3129cb9
@@ -80,6 +80,7 @@ def get_pad_info(

 class ColoAttention:
     _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
+    _flash_kernel_dispatch: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None

     @staticmethod
     def _init_kernels_dispatch():
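This hunk adds a _flash_kernel_dispatch class attribute: a slot for caching the Dao flash-attention loader (and, once loaded, the kernel itself) across dispatch calls, instead of re-creating the loader on every call to _dispatch_kernel.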
@@ -105,6 +106,8 @@ class ColoAttention:
             torch.bfloat16: half_dispatch_map,
             torch.float32: float_dispatch_map,
         }
+        if ColoAttention._flash_kernel_dispatch is None:
+            ColoAttention._flash_kernel_dispatch = FlashAttentionDaoLoader()

     @staticmethod
     def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable:
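With the cache slot in place, _init_kernels_dispatch now constructs FlashAttentionDaoLoader at most once, guarded by the is-None check, rather than leaving construction to each dispatch.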
@@ -118,7 +121,7 @@ class ColoAttention:
             )

         if size >= MEMORY_BOUND:
-            flash_kernel = FlashAttentionDaoLoader().load()
+            ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
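Inside _dispatch_kernel, the per-call FlashAttentionDaoLoader().load() is replaced by loading the cached class attribute in place, so the load result is stored back onto the class rather than into a throwaway local.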
@@ -126,7 +129,7 @@ class ColoAttention:
             ].load()

         if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
-            return flash_kernel
+            return ColoAttention._flash_kernel_dispatch
         else:
             return ColoAttention._kernel_dispatch_map[dtype][mask_type]
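Taken together, the commit turns flash-attention kernel loading into a lazily initialized, class-level cache. Below is a minimal, self-contained sketch of that pattern. KernelLoader, FlashAttentionDaoLoader, and MEMORY_BOUND are hypothetical stand-ins for the real ColossalAI objects, and the isinstance guard is an illustrative safeguard, not something this commit adds.

from typing import Callable, Optional

MEMORY_BOUND = 10 * 1024**3  # assumed threshold; the real constant lives in ColossalAI


class KernelLoader:
    # Stand-in for ColossalAI's KernelLoader: building the kernel is deferred
    # until load() is called.
    def load(self) -> Callable:
        return lambda *args, **kwargs: "flash-attn output"


class FlashAttentionDaoLoader(KernelLoader):
    pass


class ColoAttentionSketch:
    # Holds the loader at first, then the loaded kernel after the first
    # large-size dispatch.
    _flash_kernel_dispatch: Optional[object] = None

    @staticmethod
    def _init_kernels_dispatch() -> None:
        # Construct the loader once per process, as the commit does.
        if ColoAttentionSketch._flash_kernel_dispatch is None:
            ColoAttentionSketch._flash_kernel_dispatch = FlashAttentionDaoLoader()

    @staticmethod
    def _dispatch_kernel(size: int) -> Callable:
        ColoAttentionSketch._init_kernels_dispatch()
        if size >= MEMORY_BOUND:
            # Lazy load: swap the cached loader for the loaded kernel the
            # first time a large problem size is seen. The isinstance guard
            # keeps repeated calls from re-loading (illustrative only).
            if isinstance(ColoAttentionSketch._flash_kernel_dispatch, KernelLoader):
                ColoAttentionSketch._flash_kernel_dispatch = ColoAttentionSketch._flash_kernel_dispatch.load()
            return ColoAttentionSketch._flash_kernel_dispatch
        raise NotImplementedError("small sizes go through the regular dispatch map")


kernel = ColoAttentionSketch._dispatch_kernel(MEMORY_BOUND)
print(kernel())  # -> "flash-attn output"

The payoff of this shape is that the expensive kernel load happens at most once per process, instead of on every attention dispatch above the memory bound.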