pull/6061/head
wangbluo 2024-09-13 05:06:56 +00:00
parent 683179cefd
commit 6eb8832366
1 changed file with 6 additions and 9 deletions

@@ -117,7 +117,7 @@ class ColoAttention:
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )
-        if size > MEMORY_BOUND:
+        if size >= MEMORY_BOUND:
             FlashAttentionDaoLoader().load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
@@ -173,7 +173,7 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            if memory_size > MEMORY_BOUND:
+            if memory_size >= MEMORY_BOUND:
                 attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
@@ -192,10 +192,10 @@ class ColoAttention:
                 b,
                 s_kv,
             ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
-            if memory_size > MEMORY_BOUND:
-                attention_mask = torch.empty((0,), dtype=dtype, device=device)
-            else:
+            if memory_size < MEMORY_BOUND and not is_causal:
                 attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+            else:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -208,7 +208,7 @@ class ColoAttention:
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                if memory_size > MEMORY_BOUND:
+                if memory_size >= MEMORY_BOUND:
                     attention_mask = torch.empty((0,), dtype=dtype, device=device)
                 else:
                     if s_q != 1:
@@ -303,9 +303,6 @@ class ColoAttention:
         b, _, s_kv, _ = v.shape
         element_size = torch.tensor([], dtype=q.dtype).element_size()
         memory_size = s_q * s_kv * element_size
-        if memory_size > MEMORY_BOUND:
-            assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
         mask_type = attention_mask_type if attention_mask is not None else None
         attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (
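
For context, every branch touched here keys off the same quantity: memory_size = s_q * s_kv * element_size, compared against MEMORY_BOUND to decide whether a dense (s_q, s_kv) mask is materialized or an empty placeholder is handed to the FlashAttention kernel. The snippet below is a minimal, self-contained sketch of that decision, not the repository's code; the threshold value, function name, and example shapes are assumptions chosen only to illustrate that, after switching > to >=, a mask whose size lands exactly on the bound now takes the empty-mask path.

    import torch

    # Hypothetical threshold for illustration only; the real MEMORY_BOUND is defined in the repo.
    MEMORY_BOUND = 10 * 1024**2  # bytes

    def build_attention_mask(s_q: int, s_kv: int, dtype=torch.float16, device="cpu"):
        # Same size bookkeeping as in the diff: bytes a dense (s_q, s_kv) mask would occupy.
        element_size = torch.tensor([], dtype=dtype).element_size()
        memory_size = s_q * s_kv * element_size
        if memory_size >= MEMORY_BOUND:
            # Large case: return an empty placeholder and let the FlashAttention kernel
            # (loaded via FlashAttentionDaoLoader in the real code) handle masking itself.
            return torch.empty((0,), dtype=dtype, device=device)
        # Small case: materialize the dense mask eagerly.
        return torch.ones(s_q, s_kv, dtype=dtype, device=device)

    # With fp16 (2 bytes per element), 2048 * 2560 * 2 == 10 * 1024**2, i.e. exactly the assumed
    # bound, so this call now takes the empty-mask branch -- the behavioral difference between > and >=.
    print(build_attention_mask(2048, 2560).shape)  # torch.Size([0])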