pull/6061/head
wangbluo 2024-09-13 05:06:56 +00:00
parent 683179cefd
commit 6eb8832366
1 changed files with 6 additions and 9 deletions

View File

@ -117,7 +117,7 @@ class ColoAttention:
"FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
)
if size > MEMORY_BOUND:
if size >= MEMORY_BOUND:
FlashAttentionDaoLoader().load()
# lazy load
if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
@ -173,7 +173,7 @@ class ColoAttention:
# no padding
assert is_causal
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
if memory_size > MEMORY_BOUND:
if memory_size >= MEMORY_BOUND:
attention_mask = torch.empty((0,), dtype=dtype, device=device)
else:
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
@ -192,10 +192,10 @@ class ColoAttention:
b,
s_kv,
), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
if memory_size > MEMORY_BOUND:
attention_mask = torch.empty((0,), dtype=dtype, device=device)
else:
if memory_size < MEMORY_BOUND and not is_causal:
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
else:
attention_mask = torch.empty((0,), dtype=dtype, device=device)
outputs.update(
{
"cu_seqlens_q": cu_seqlens_q,
@ -208,7 +208,7 @@ class ColoAttention:
)
if is_causal:
outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
if memory_size > MEMORY_BOUND:
if memory_size >= MEMORY_BOUND:
attention_mask = torch.empty((0,), dtype=dtype, device=device)
else:
if s_q != 1:
@ -303,9 +303,6 @@ class ColoAttention:
b, _, s_kv, _ = v.shape
element_size = torch.tensor([], dtype=q.dtype).element_size()
memory_size = s_q * s_kv * element_size
if memory_size > MEMORY_BOUND:
assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
mask_type = attention_mask_type if attention_mask is not None else None
attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
is_causal = attention_mask is not None and attention_mask_type in (