mirror of https://github.com/hpcaitech/ColossalAI
fix
parent
0b14a5512e
commit
0ad3129cb9
|
@ -80,6 +80,7 @@ def get_pad_info(
|
||||||
|
|
||||||
class ColoAttention:
|
class ColoAttention:
|
||||||
_kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
|
_kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
|
||||||
|
_flash_kernel_dispatch: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _init_kernels_dispatch():
|
def _init_kernels_dispatch():
|
||||||
|
@ -105,6 +106,8 @@ class ColoAttention:
|
||||||
torch.bfloat16: half_dispatch_map,
|
torch.bfloat16: half_dispatch_map,
|
||||||
torch.float32: float_dispatch_map,
|
torch.float32: float_dispatch_map,
|
||||||
}
|
}
|
||||||
|
if ColoAttention._flash_kernel_dispatch is None:
|
||||||
|
ColoAttention._flash_kernel_dispatch = FlashAttentionDaoLoader()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable:
|
def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable:
|
||||||
|
@ -118,7 +121,7 @@ class ColoAttention:
|
||||||
)
|
)
|
||||||
|
|
||||||
if size >= MEMORY_BOUND:
|
if size >= MEMORY_BOUND:
|
||||||
flash_kernel = FlashAttentionDaoLoader().load()
|
ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
|
||||||
# lazy load
|
# lazy load
|
||||||
if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
|
if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
|
||||||
ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
|
ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
|
||||||
|
@ -126,7 +129,7 @@ class ColoAttention:
|
||||||
].load()
|
].load()
|
||||||
|
|
||||||
if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
|
if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
|
||||||
return flash_kernel
|
return ColoAttention._flash_kernel_dispatch
|
||||||
else:
|
else:
|
||||||
return ColoAttention._kernel_dispatch_map[dtype][mask_type]
|
return ColoAttention._kernel_dispatch_map[dtype][mask_type]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue