pull/6061/head
wangbluo 2024-09-14 10:40:35 +00:00
parent b582319273
commit 827ef3ee9a
1 changed file with 4 additions and 5 deletions

@@ -195,10 +195,6 @@ class ColoAttention:
                 b,
                 s_kv,
             ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
-            if memory_size < MEMORY_BOUND and not is_causal:
-                attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
-            else:
-                attention_mask = torch.empty((0,), dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -210,15 +206,18 @@ class ColoAttention:
                 }
             )
             if is_causal:
-                attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
                 if memory_size < MEMORY_BOUND:
                     if s_q != 1:
+                        attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
                         attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
                 else:
                     attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
                 if memory_size < MEMORY_BOUND:
+                    attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask