pull/6061/head
wangbluo 2024-09-13 10:24:41 +00:00
parent 0ad3129cb9
commit b582319273
1 changed files with 1 additions and 0 deletions

View File

@ -210,6 +210,7 @@ class ColoAttention:
}
)
if is_causal:
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
if memory_size < MEMORY_BOUND:
if s_q != 1: