pull/6061/head
wangbluo 2024-09-13 06:00:58 +00:00
parent f393867cff
commit dc032172c3
1 changed file with 1 addition and 1 deletion

@@ -173,7 +173,7 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            if memory_size < MEMORY_BOUND and not is_causal:
+            if memory_size < MEMORY_BOUND:
                 attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
                 if s_q != 1:
                     attention_mask.tril_(diagonal=0)
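
Note on the fix: the preceding `assert is_causal` guarantees `is_causal` is True in this branch, so the old condition's `and not is_causal` could never hold and the dense causal mask below it was never built; dropping that clause makes the branch reachable. For reference, a minimal standalone sketch of what the enabled branch computes is given below (the `MEMORY_BOUND` value and the helper name are illustrative assumptions, not the actual ColossalAI definitions):

import torch

# Assumed placeholder threshold; the real MEMORY_BOUND is defined elsewhere in ColossalAI.
MEMORY_BOUND = 10 * 1024**3

def build_dense_causal_mask(s_q: int, s_kv: int, dtype=torch.float16, device="cpu"):
    # Hypothetical helper mirroring the diff's branch: materialize a dense
    # lower-triangular mask only when it is small enough to hold in memory.
    memory_size = s_q * s_kv
    if memory_size < MEMORY_BOUND:
        attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
        if s_q != 1:
            # Keep the lower triangle so each query position attends only to
            # itself and earlier key positions (causal masking).
            attention_mask.tril_(diagonal=0)
        return attention_mask
    return None  # mask too large to materialize; leave the causal mask implicit

For example, build_dense_causal_mask(4, 4) returns a 4x4 lower-triangular matrix of ones.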