[shardformer] hotfix attn mask (#5947)

Hongxin Liu 2024-07-29 19:10:06 +08:00 committed by GitHub
parent 9664b1bc19
commit 7b38964e3a
1 changed file with 7 additions and 7 deletions


@@ -139,12 +139,11 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
+            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
+            if s_q != 1:
+                attention_mask = attention_mask.tril(diagonal=0)
+            attention_mask = attention_mask.expand(b, s_q, s_kv)
         else:
-            assert q_padding_mask.shape == (
-                b,
-                s_q,
-            ), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})"
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
                 # self attention
@@ -156,7 +155,7 @@ class ColoAttention:
                 b,
                 s_kv,
             ), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
-            attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device)
+            attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -169,7 +168,8 @@ class ColoAttention:
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                if s_q != 1:
+                    attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
         attention_mask = invert_mask(attention_mask).unsqueeze(1)
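
For context, a minimal standalone sketch of the corrected padded-mask path (plain PyTorch, not the ColoAttention API; the names b, s_q, s_kv and kv_padding_mask follow the diff, the concrete values are only illustrative). It shows the two fixed behaviors: the mask is built from kv_padding_mask and broadcast to (b, s_q, s_kv), and the causal tril is skipped when s_q == 1 so a single decoding query is not restricted to the first key position.

import torch

# Illustrative shapes: s_q == 1 mimics single-token decoding against a cached KV sequence.
b, s_q, s_kv = 2, 1, 5
dtype, device = torch.float16, "cpu"

# 1 marks a real token, 0 marks padding, per key/value position.
kv_padding_mask = torch.tensor([[1, 1, 1, 0, 0],
                                [1, 1, 1, 1, 0]])

# Corrected construction from the diff: broadcast the key/value padding mask over the query axis.
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)

# Corrected causal handling: skip the tril when s_q == 1, otherwise the single
# decoding query would only be allowed to attend to the first key position.
if s_q != 1:
    attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)

print(attention_mask)  # shape (2, 1, 5); only padded key positions are zeroed out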