fix the attn

pull/6061/head
wangbluo 3 months ago
parent 216d54e374
commit 0a01e2a453

@ -18,7 +18,7 @@ from colossalai.logging import get_dist_logger
from .utils import RingComm, get_half_index, split_varlen_zigzag
MEMORY_BOUND = 10 * 1e9
MEMORY_BOUND = 1 * 1e9
__all__ = [
"AttnMaskType",
@ -125,9 +125,10 @@ class ColoAttention:
mask_type
].load()
return (
FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
)
if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
return FlashAttentionDaoLoader()
else:
return ColoAttention._kernel_dispatch_map[dtype][mask_type]
@staticmethod
def prepare_attn_kwargs(
@ -163,6 +164,8 @@ class ColoAttention:
return {}
assert len(shape_4d) == 4 and shape_4d[1] == 1
b, _, s_q, s_kv = shape_4d
element_size = torch.tensor([], dtype=dtype).element_size()
memory_size = s_q * s_kv * element_size
outputs = {}
if (q_padding_mask is None or q_padding_mask.bool().all()) and (
kv_padding_mask is None or kv_padding_mask.bool().all()
@ -170,10 +173,13 @@ class ColoAttention:
# no padding
assert is_causal
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
if s_q != 1:
attention_mask.tril_(diagonal=0)
attention_mask = attention_mask.expand(b, s_q, s_kv)
if memory_size > MEMORY_BOUND:
attention_mask = torch.empty((0,), dtype=dtype, device=device)
else:
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
if s_q != 1:
attention_mask.tril_(diagonal=0)
attention_mask = attention_mask.expand(b, s_q, s_kv)
else:
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
if kv_padding_mask is None:
@ -186,7 +192,10 @@ class ColoAttention:
b,
s_kv,
), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
if memory_size > MEMORY_BOUND:
attention_mask = torch.empty((0,), dtype=dtype, device=device)
else:
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
outputs.update(
{
"cu_seqlens_q": cu_seqlens_q,
@ -199,22 +208,16 @@ class ColoAttention:
)
if is_causal:
outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
if s_q != 1:
attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
if memory_size > MEMORY_BOUND:
attention_mask = torch.empty((0,), dtype=dtype, device=device)
else:
if s_q != 1:
attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
else:
outputs["attention_mask_type"] = AttnMaskType.PADDED
if invert:
attention_mask = invert_mask(attention_mask).unsqueeze(1)
outputs["attention_mask"] = attention_mask
element_size = torch.tensor([], dtype=dtype).element_size()
memory_size = s_q * s_kv * element_size
if memory_size > MEMORY_BOUND:
attention_mask = torch.empty((0,), dtype=dtype, device=device)
outputs["attention_mask"] = attention_mask
if outputs["attention_mask_type"] != AttnMaskType.PADDED_CAUSAL:
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
return outputs
@staticmethod
@ -301,7 +304,6 @@ class ColoAttention:
element_size = torch.tensor([], dtype=q.dtype).element_size()
memory_size = s_q * s_kv * element_size
if memory_size > MEMORY_BOUND:
attention_mask = torch.empty((0,), dtype=q.dtype, device=q.device)
assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
mask_type = attention_mask_type if attention_mask is not None else None

Loading…
Cancel
Save