mirror of https://github.com/hpcaitech/ColossalAI
fix
parent
6eb8832366
commit
f393867cff
|
@ -173,13 +173,13 @@ class ColoAttention:
|
||||||
# no padding
|
# no padding
|
||||||
assert is_causal
|
assert is_causal
|
||||||
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
|
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
|
||||||
if memory_size >= MEMORY_BOUND:
|
if memory_size < MEMORY_BOUND and not is_causal:
|
||||||
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
|
||||||
else:
|
|
||||||
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
|
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
|
||||||
if s_q != 1:
|
if s_q != 1:
|
||||||
attention_mask.tril_(diagonal=0)
|
attention_mask.tril_(diagonal=0)
|
||||||
attention_mask = attention_mask.expand(b, s_q, s_kv)
|
attention_mask = attention_mask.expand(b, s_q, s_kv)
|
||||||
|
else:
|
||||||
|
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
||||||
else:
|
else:
|
||||||
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
|
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
|
||||||
if kv_padding_mask is None:
|
if kv_padding_mask is None:
|
||||||
|
@ -208,11 +208,11 @@ class ColoAttention:
|
||||||
)
|
)
|
||||||
if is_causal:
|
if is_causal:
|
||||||
outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
|
outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
|
||||||
if memory_size >= MEMORY_BOUND:
|
if memory_size < MEMORY_BOUND:
|
||||||
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
|
||||||
else:
|
|
||||||
if s_q != 1:
|
if s_q != 1:
|
||||||
attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
|
attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
|
||||||
|
else:
|
||||||
|
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
||||||
else:
|
else:
|
||||||
outputs["attention_mask_type"] = AttnMaskType.PADDED
|
outputs["attention_mask_type"] = AttnMaskType.PADDED
|
||||||
if invert:
|
if invert:
|
||||||
|
|
Loading…
Reference in New Issue