pull/6061/head
wangbluo 2024-09-13 05:24:52 +00:00
parent 6eb8832366
commit f393867cff
1 changed file with 6 additions and 6 deletions

@@ -173,13 +173,13 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            if memory_size >= MEMORY_BOUND:
-                attention_mask = torch.empty((0,), dtype=dtype, device=device)
-            else:
+            if memory_size < MEMORY_BOUND and not is_causal:
                 attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
                 if s_q != 1:
                     attention_mask.tril_(diagonal=0)
                 attention_mask = attention_mask.expand(b, s_q, s_kv)
+            else:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
         else:
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
@@ -208,11 +208,11 @@ class ColoAttention:
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                if memory_size >= MEMORY_BOUND:
-                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
-                else:
+                if memory_size < MEMORY_BOUND:
                     if s_q != 1:
                         attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                else:
+                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
         if invert:
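
For orientation, here is a minimal standalone sketch of the gating both hunks converge on. The MEMORY_BOUND value and the way memory_size is computed are assumptions here, since neither is defined inside the hunks, and the extra `and not is_causal` guard from the first hunk is omitted for brevity. The idea is to materialize the dense mask only when it fits under the memory bound, and otherwise hand back an empty placeholder so the attention kernel can apply causality itself.

import torch

# Hypothetical stand-ins: the real MEMORY_BOUND and memory_size computation are
# defined outside the hunks shown above.
MEMORY_BOUND = 10 * 1024**3  # bytes

def build_causal_mask(b, s_q, s_kv, dtype=torch.float16, device="cpu"):
    # Approximate footprint of the dense (b, s_q, s_kv) mask in bytes.
    memory_size = b * s_q * s_kv * torch.finfo(dtype).bits // 8
    if memory_size < MEMORY_BOUND:
        # Small enough: materialize the dense lower-triangular (causal) mask
        # and broadcast it across the batch, as in the kept branch above.
        mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
        if s_q != 1:
            mask.tril_(diagonal=0)
        return mask.expand(b, s_q, s_kv)
    # Too large: return an empty placeholder and leave causal masking to the
    # attention kernel (e.g. a flash-attention style is_causal flag).
    return torch.empty((0,), dtype=dtype, device=device)

# Example: a 4 x 8192 x 8192 fp16 mask (~512 MiB) stays under a 10 GiB bound,
# so the dense path is taken; raise the sequence length and the placeholder wins.
print(build_causal_mask(4, 8192, 8192).shape)  # torch.Size([4, 8192, 8192])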