mirror of https://github.com/hpcaitech/ColossalAI
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

pull/6061/head
parent fdd84b9087
commit 216d54e374
@@ -118,9 +118,11 @@ class FlashAttentionLoader(KernelLoader):
         FlashAttentionSdpaCudaExtension,
     ]
 
+
 class FlashAttentionDaoLoader(KernelLoader):
     REGISTRY = [FlashAttentionDaoCudaExtension]
 
+
 class FlashAttentionWithCustomMaskLoader(KernelLoader):
     REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
 
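The hunk above only adds the blank lines the pre-commit hooks require between top-level classes; the loader classes themselves follow the same KernelLoader registry pattern as FlashAttentionLoader. A minimal usage sketch, assuming (based on how the dispatch code later in this diff calls it) that load() picks an available extension from REGISTRY and returns a callable kernel:

# Sketch only: how the loaders above are typically consumed. The import path
# matches the import hunk below; what load() does internally is an assumption
# drawn from the calls visible in this diff.
from colossalai.kernel.kernel_loader import FlashAttentionDaoLoader, FlashAttentionLoader

dao_attention = FlashAttentionDaoLoader().load()    # flash-attn (Dao) kernel only
flash_attention = FlashAttentionLoader().load()     # NPU / Dao / SDPA, whichever is available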
@@ -8,9 +8,9 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from colossalai.kernel.kernel_loader import (
+    FlashAttentionDaoLoader,
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
-    FlashAttentionDaoLoader,
     FlashAttentionWithCustomMaskLoader,
     KernelLoader,
 )
@@ -116,7 +116,7 @@ class ColoAttention:
             raise ValueError(
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )
 
         if size > MEMORY_BOUND:
             FlashAttentionDaoLoader().load()
         # lazy load
@@ -124,8 +124,10 @@ class ColoAttention:
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                 mask_type
             ].load()
 
-        return FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
+        return (
+            FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
+        )
 
     @staticmethod
     def prepare_attn_kwargs(
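Together, the two hunks above make the kernel dispatch size-aware: when the dense attention mask would exceed MEMORY_BOUND, the Dao (flash-attn) loader is used; otherwise the existing per-dtype/mask-type dispatch map is consulted. A hedged sketch of that decision in isolation follows; the real value of MEMORY_BOUND is not visible in this diff, so the figure below is a placeholder.

import torch

MEMORY_BOUND = 10 * 1024**2  # placeholder assumption, not the repository's actual constant


def mask_memory_size(s_q: int, s_kv: int, dtype: torch.dtype) -> int:
    # Same arithmetic as the diff: bytes required for a dense (s_q, s_kv) mask.
    element_size = torch.tensor([], dtype=dtype).element_size()
    return s_q * s_kv * element_size


# Above the bound, prefer the Dao kernel (the later hunks then leave the mask
# unmaterialized and fall back to causal masking); below it, keep using the
# dtype/mask-type dispatch map.
size = mask_memory_size(4096, 4096, torch.float16)
use_dao_kernel = size > MEMORY_BOUND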
@@ -204,15 +206,15 @@ class ColoAttention:
         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask
 
         element_size = torch.tensor([], dtype=dtype).element_size()
-        memory_size = (s_q * s_kv * element_size)
+        memory_size = s_q * s_kv * element_size
         if memory_size > MEMORY_BOUND:
             attention_mask = torch.empty((0,), dtype=dtype, device=device)
             outputs["attention_mask"] = attention_mask
             if outputs["attention_mask_type"] != AttnMaskType.PADDED_CAUSAL:
                 outputs["attention_mask_type"] = AttnMaskType.CAUSAL
 
         return outputs
 
     @staticmethod
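As a quick sanity check on the memory estimate above: element_size() is just the byte width of the dtype, so the dense-mask budget grows with the product of the two sequence lengths. The sequence lengths below are arbitrary examples.

import torch

# Bytes per element for common dtypes.
assert torch.tensor([], dtype=torch.float16).element_size() == 2
assert torch.tensor([], dtype=torch.float32).element_size() == 4

# A dense fp16 mask covering 8192 query x 8192 key tokens:
s_q = s_kv = 8192
memory_size = s_q * s_kv * 2  # 134_217_728 bytes, i.e. 128 MiB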
@@ -297,11 +299,11 @@ class ColoAttention:
         b, _, s_q, _ = q.shape
         b, _, s_kv, _ = v.shape
         element_size = torch.tensor([], dtype=q.dtype).element_size()
-        memory_size = (s_q * s_kv * element_size)
+        memory_size = s_q * s_kv * element_size
         if memory_size > MEMORY_BOUND:
             attention_mask = torch.empty((0,), dtype=q.dtype, device=q.device)
             assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
 
         mask_type = attention_mask_type if attention_mask is not None else None
         attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (
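One note on a context line carried through unchanged above: the assertion "attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED" always passes, because the bare enum member AttnMaskType.PADDED is truthy on its own. If the intent is to accept either padded variant, a membership test expresses that (sketch only, not part of this commit):

assert attention_mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.PADDED)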