mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] hotfix attn mask (#5947)
parent 9664b1bc19
commit 7b38964e3a
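From the hunk context (class ColoAttention), the hotfix appears to make three related changes: skip the causal tril when the query length is 1 (single-token KV-cache decoding), drop the q_padding_mask shape assert from the padded branch, and build the padding mask from kv_padding_mask in the (b, s_q, s_kv) layout rather than the transposed q_padding_mask expansion, with the same single-token guard applied on the padded-causal path. Short sketches after each hunk below illustrate the intent.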
@@ -139,12 +139,11 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
+            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
+            if s_q != 1:
+                attention_mask = attention_mask.tril(diagonal=0)
+            attention_mask = attention_mask.expand(b, s_q, s_kv)
         else:
-            assert q_padding_mask.shape == (
-                b,
-                s_q,
-            ), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})"
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
                 # self attention
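Why the s_q != 1 guard matters: with a KV cache, decoding runs a single query token against s_kv cached keys, and tril on a (1, s_kv) ones matrix keeps only the first column. A minimal sketch in plain torch (shapes are illustrative, not taken from the repo):

import torch

# Single-token decoding: one query attends to 4 cached keys.
s_q, s_kv = 1, 4

# Old behavior: tril zeroes every key except the first one.
old_mask = torch.ones(s_q, s_kv).tril(diagonal=0)
print(old_mask)  # tensor([[1., 0., 0., 0.]])  <- masks the KV cache away

# Fixed behavior: skip tril when s_q == 1.
new_mask = torch.ones(s_q, s_kv)
if s_q != 1:
    new_mask = new_mask.tril(diagonal=0)
print(new_mask)  # tensor([[1., 1., 1., 1.]])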
@@ -156,7 +155,7 @@ class ColoAttention:
                 b,
                 s_kv,
             ), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
-            attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device)
+            attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
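The old line built the mask from q_padding_mask and expanded it to the transposed (b, s_kv, s_q) shape, which only lines up for square self attention and breaks when s_q != s_kv. (The retained assert message still says "q_padding_mask" while printing kv_padding_mask.shape; the commit leaves that as is.) A standalone sketch of the corrected expansion, with illustrative sizes:

import torch

b, s_q, s_kv = 2, 3, 5
# 1 marks a real key token, 0 marks padding.
kv_padding_mask = torch.tensor([[1, 1, 1, 1, 0],
                                [1, 1, 0, 0, 0]])

# (b, s_kv) -> (b, 1, s_kv) -> (b, s_q, s_kv): every query position gets
# the same row of key-padding flags.
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv)
print(attention_mask.shape)  # torch.Size([2, 3, 5])

# The old q_padding_mask[:, None, :].expand(b, s_kv, s_q) produced a
# (b, s_kv, s_q) tensor of *query* flags, i.e. rows and columns swapped
# relative to what the later tril/invert steps expect.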
@@ -169,7 +168,8 @@ class ColoAttention:
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                if s_q != 1:
+                    attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
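The padded-causal path gets the same guard: for s_q == 1 the causal overlay would again erase every cached key from the padding mask. A standalone sketch (illustrative shapes; invert_mask and the surrounding outputs bookkeeping omitted):

import torch

b, s_q, s_kv = 1, 1, 4
# Padding mask for one query over four keys; the last key is padding.
attention_mask = torch.tensor([[[1., 1., 1., 0.]]])  # (b, s_q, s_kv)

# PADDED_CAUSAL: overlay the causal triangle only when there is more
# than one query token.
if s_q != 1:
    attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)

print(attention_mask)  # tensor([[[1., 1., 1., 0.]]]) -- padding preserved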