[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pull/6061/head
pre-commit-ci[bot] 2024-09-13 02:38:39 +00:00
parent fdd84b9087
commit 216d54e374
2 changed files with 13 additions and 9 deletions

@@ -118,9 +118,11 @@ class FlashAttentionLoader(KernelLoader):
        FlashAttentionSdpaCudaExtension,
    ]


class FlashAttentionDaoLoader(KernelLoader):
    REGISTRY = [FlashAttentionDaoCudaExtension]


class FlashAttentionWithCustomMaskLoader(KernelLoader):
    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
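For context on this first file: each loader class follows the registry pattern used throughout the kernel loader module, where a `KernelLoader` subclass lists candidate extensions and loading picks the first candidate usable on the current hardware. Below is a minimal sketch of that pattern; the `is_available()`/`build()` names on the extension side are assumptions for illustration, not ColossalAI's exact extension API.

```python
# Minimal sketch of the registry-based loader pattern (illustrative only;
# extension method names are assumptions, not ColossalAI's API).
from typing import Callable, List, Type


class Extension:
    """One candidate kernel backend (e.g. Dao flash-attn, SDPA, NPU)."""

    def is_available(self) -> bool:
        raise NotImplementedError

    def build(self) -> Callable:
        raise NotImplementedError


class KernelLoader:
    REGISTRY: List[Type[Extension]] = []

    def load(self) -> Callable:
        # Walk the registry in order and return the first usable kernel.
        for ext_cls in self.REGISTRY:
            ext = ext_cls()
            if ext.is_available():
                return ext.build()
        raise RuntimeError(f"no usable extension among {self.REGISTRY}")
```

With that shape, `FlashAttentionDaoLoader` is simply a loader whose registry contains only the Dao CUDA extension, so `load()` either returns the Dao flash-attention kernel or fails loudly.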

@@ -8,9 +8,9 @@ import torch.nn.functional as F
from einops import rearrange

from colossalai.kernel.kernel_loader import (
+    FlashAttentionDaoLoader,
    FlashAttentionForFloatAndCustomMaskLoader,
    FlashAttentionLoader,
-    FlashAttentionDaoLoader,
    FlashAttentionWithCustomMaskLoader,
    KernelLoader,
)
@@ -116,7 +116,7 @@ class ColoAttention:
            raise ValueError(
                "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
            )

        if size > MEMORY_BOUND:
            FlashAttentionDaoLoader().load()
        # lazy load
@@ -124,8 +124,10 @@ class ColoAttention:
            ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                mask_type
            ].load()

-        return FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
+        return (
+            FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
+        )

    @staticmethod
    def prepare_attn_kwargs(
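The reformatted `return` above is the core of the size-aware dispatch: once the estimated mask size exceeds `MEMORY_BOUND`, `_dispatch_kernel` hands back the Dao flash-attention kernel (which never materializes the full score/mask matrix); otherwise it returns the lazily loaded kernel registered for the `(dtype, mask_type)` pair. A standalone sketch of that decision follows, assuming a plain nested-dict dispatch map and an illustrative threshold value.

```python
from typing import Callable, Dict, Optional

import torch

MEMORY_BOUND = 10 * 1024**3  # illustrative threshold in bytes, not the library's constant


def dispatch_kernel(
    dtype: torch.dtype,
    mask_type: Optional[str],
    size: int,
    dispatch_map: Dict[torch.dtype, Dict[Optional[str], Callable]],
    dao_kernel: Callable,
) -> Callable:
    if dtype not in dispatch_map or mask_type not in dispatch_map[dtype]:
        raise ValueError(
            "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
        )
    if size > MEMORY_BOUND:
        # A dense mask would be too large: fall back to the Dao kernel,
        # which handles causality/padding without a materialized mask.
        return dao_kernel
    # Small enough: use the kernel registered for this dtype / mask type.
    return dispatch_map[dtype][mask_type]
```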
@@ -204,15 +206,15 @@ class ColoAttention:
        if invert:
            attention_mask = invert_mask(attention_mask).unsqueeze(1)
        outputs["attention_mask"] = attention_mask

        element_size = torch.tensor([], dtype=dtype).element_size()
-        memory_size = (s_q * s_kv * element_size)
+        memory_size = s_q * s_kv * element_size
        if memory_size > MEMORY_BOUND:
            attention_mask = torch.empty((0,), dtype=dtype, device=device)
            outputs["attention_mask"] = attention_mask
            if outputs["attention_mask_type"] != AttnMaskType.PADDED_CAUSAL:
                outputs["attention_mask_type"] = AttnMaskType.CAUSAL

        return outputs

    @staticmethod
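The `memory_size = s_q * s_kv * element_size` estimate that this hunk and the next one clean up is simply the byte count of one dense (s_q, s_kv) mask in the working dtype; when it crosses `MEMORY_BOUND`, `prepare_attn_kwargs` swaps the mask for an empty placeholder and, unless the type is already PADDED_CAUSAL, marks the mask type as CAUSAL so the kernel applies causality itself. A quick worked example (the 1 GiB threshold is assumed for illustration, not the library's value):

```python
import torch

# Worked example of the mask-size estimate in the hunk above.
s_q = s_kv = 32768                                                    # long sequence
element_size = torch.tensor([], dtype=torch.float16).element_size()  # 2 bytes for fp16
memory_size = s_q * s_kv * element_size                              # 2_147_483_648 bytes (2 GiB)

MEMORY_BOUND = 1 * 1024**3  # assumed 1 GiB threshold, illustrative only
if memory_size > MEMORY_BOUND:
    # Skip materializing the dense mask; hand an empty tensor to the kernel
    # and let the flash-attention path handle causal masking internally.
    attention_mask = torch.empty((0,), dtype=torch.float16)
```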
@@ -297,11 +299,11 @@ class ColoAttention:
        b, _, s_q, _ = q.shape
        b, _, s_kv, _ = v.shape
        element_size = torch.tensor([], dtype=q.dtype).element_size()
-        memory_size = (s_q * s_kv * element_size)
+        memory_size = s_q * s_kv * element_size
        if memory_size > MEMORY_BOUND:
            attention_mask = torch.empty((0,), dtype=q.dtype, device=q.device)
            assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED

        mask_type = attention_mask_type if attention_mask is not None else None
        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
        is_causal = attention_mask is not None and attention_mask_type in (