mirror of https://github.com/hpcaitech/ColossalAI
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
branch: pull/6061/head
parent fdd84b9087
commit 216d54e374
@@ -118,9 +118,11 @@ class FlashAttentionLoader(KernelLoader):
         FlashAttentionSdpaCudaExtension,
     ]
 
+
 class FlashAttentionDaoLoader(KernelLoader):
     REGISTRY = [FlashAttentionDaoCudaExtension]
 
+
 class FlashAttentionWithCustomMaskLoader(KernelLoader):
     REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
 
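The loaders in this hunk follow the KernelLoader registry pattern: each subclass lists candidate extensions in REGISTRY, and loading picks one that is usable in the current environment. A minimal sketch of that pattern, assuming an availability probe on each extension (the is_available/load method names and the _SdpaLikeExtension stand-in below are illustrative assumptions, not the library's actual API):

    # Sketch only: a registry-based loader in the spirit of colossalai.kernel.kernel_loader.
    class KernelLoaderSketch:
        REGISTRY: list = []

        def load(self):
            # Return the kernel from the first registered extension that is usable.
            for ext_cls in self.REGISTRY:
                ext = ext_cls()
                if ext.is_available():          # assumed availability check
                    return ext.load()           # assumed: builds and returns the attention callable
            raise RuntimeError(f"{type(self).__name__}: no usable extension in REGISTRY")

    class _SdpaLikeExtension:                   # dummy stand-in for e.g. FlashAttentionSdpaCudaExtension
        def is_available(self) -> bool:
            return True

        def load(self):
            import torch.nn.functional as F
            return F.scaled_dot_product_attention

    class FlashAttentionLoaderSketch(KernelLoaderSketch):
        REGISTRY = [_SdpaLikeExtension]

    attn_fn = FlashAttentionLoaderSketch().load()   # -> torch's scaled_dot_product_attention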
@@ -8,9 +8,9 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from colossalai.kernel.kernel_loader import (
+    FlashAttentionDaoLoader,
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
-    FlashAttentionDaoLoader,
     FlashAttentionWithCustomMaskLoader,
     KernelLoader,
 )
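The only change here is the import sorter moving FlashAttentionDaoLoader into lexicographic position within the import block. A quick sanity check of the resulting order in plain Python:

    names = [
        "FlashAttentionDaoLoader",
        "FlashAttentionForFloatAndCustomMaskLoader",
        "FlashAttentionLoader",
        "FlashAttentionWithCustomMaskLoader",
        "KernelLoader",
    ]
    # "Dao" < "For" < "Loader" < "With" after the shared prefix, and "F" < "K".
    assert names == sorted(names)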
@@ -125,7 +125,9 @@ class ColoAttention:
             mask_type
         ].load()
 
-        return FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
+        return (
+            FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
+        )
 
     @staticmethod
     def prepare_attn_kwargs(
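The reformatted return is the dispatch rule itself: when size (the bytes a dense s_q x s_kv score/mask tile would occupy) exceeds MEMORY_BOUND, the Dao flash-attention loader is preferred; otherwise the kernel registered for this dtype and mask type is kept. A standalone sketch of that rule with an assumed 10 MiB threshold and a toy dispatch map (neither the threshold value nor the dispatch_kernel helper comes from the repository):

    import torch

    MEMORY_BOUND_SKETCH = 10 * 1024**2   # assumed threshold in bytes, illustration only

    def dispatch_kernel(size, dtype, mask_type, dispatch_map, dao_kernel):
        # Large attention matrices: take the memory-efficient Dao/flash-attn path.
        # Small ones: keep the kernel pre-registered for this dtype/mask combination.
        return dao_kernel if size > MEMORY_BOUND_SKETCH else dispatch_map[dtype][mask_type]

    toy_map = {torch.float16: {"causal": "sdpa_kernel"}}
    size = 4096 * 4096 * 2               # 32 MiB for an fp16 4096x4096 tile
    print(dispatch_kernel(size, torch.float16, "causal", toy_map, "dao_kernel"))  # -> dao_kernel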
@@ -206,7 +208,7 @@ class ColoAttention:
             outputs["attention_mask"] = attention_mask
 
         element_size = torch.tensor([], dtype=dtype).element_size()
-        memory_size = (s_q * s_kv * element_size)
+        memory_size = s_q * s_kv * element_size
         if memory_size > MEMORY_BOUND:
             attention_mask = torch.empty((0,), dtype=dtype, device=device)
             outputs["attention_mask"] = attention_mask
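memory_size is the byte count of a dense (s_q, s_kv) attention mask: s_q * s_kv * element_size. Above MEMORY_BOUND the code stores an empty (0,)-shaped tensor as a placeholder rather than materializing the full mask. A small worked example (the 8192-token lengths and the threshold below are arbitrary illustrations):

    import torch

    s_q = s_kv = 8192                                                    # illustrative sequence lengths
    element_size = torch.tensor([], dtype=torch.float16).element_size()  # 2 bytes for fp16
    memory_size = s_q * s_kv * element_size                              # 134_217_728 bytes = 128 MiB

    MEMORY_BOUND_SKETCH = 10 * 1024**2                                   # assumed threshold
    if memory_size > MEMORY_BOUND_SKETCH:
        # Placeholder that signals "do not build a dense mask", mirroring the diff above.
        attention_mask = torch.empty((0,), dtype=torch.float16)
    print(memory_size, attention_mask.shape)                             # 134217728 torch.Size([0])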
@@ -297,7 +299,7 @@ class ColoAttention:
         b, _, s_q, _ = q.shape
         b, _, s_kv, _ = v.shape
         element_size = torch.tensor([], dtype=q.dtype).element_size()
-        memory_size = (s_q * s_kv * element_size)
+        memory_size = s_q * s_kv * element_size
         if memory_size > MEMORY_BOUND:
             attention_mask = torch.empty((0,), dtype=q.dtype, device=q.device)
             assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
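One note on the final context line, which this commit leaves untouched: because of operator precedence, assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED parses as (a == X) or Y, and a truthy enum member makes the assertion vacuous. A membership test expresses the intended check; the enum below is only a stand-in for colossalai's AttnMaskType:

    from enum import Enum, auto

    class AttnMaskType(Enum):       # stand-in enum for the sketch
        CAUSAL = auto()
        PADDED = auto()
        PADDED_CAUSAL = auto()

    attention_mask_type = AttnMaskType.PADDED_CAUSAL
    # Vacuous: always passes, whatever attention_mask_type is, since AttnMaskType.PADDED is truthy.
    assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
    # Intended check: fails for anything other than the two padded variants.
    assert attention_mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.PADDED)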