|
|
|
@ -18,7 +18,7 @@ from colossalai.logging import get_dist_logger
|
|
|
|
|
|
|
|
|
|
from .utils import RingComm, get_half_index, split_varlen_zigzag
|
|
|
|
|
|
|
|
|
|
MEMORY_BOUND = 10 * 1e9
|
|
|
|
|
MEMORY_BOUND = 1 * 1e9
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
"AttnMaskType",
|
|
|
|
@ -125,9 +125,10 @@ class ColoAttention:
|
|
|
|
|
mask_type
|
|
|
|
|
].load()
|
|
|
|
|
|
|
|
|
|
return (
|
|
|
|
|
FlashAttentionDaoLoader() if size > MEMORY_BOUND else ColoAttention._kernel_dispatch_map[dtype][mask_type]
|
|
|
|
|
)
|
|
|
|
|
if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
|
|
|
|
|
return FlashAttentionDaoLoader()
|
|
|
|
|
else:
|
|
|
|
|
return ColoAttention._kernel_dispatch_map[dtype][mask_type]
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def prepare_attn_kwargs(
|
|
|
|
@ -163,6 +164,8 @@ class ColoAttention:
|
|
|
|
|
return {}
|
|
|
|
|
assert len(shape_4d) == 4 and shape_4d[1] == 1
|
|
|
|
|
b, _, s_q, s_kv = shape_4d
|
|
|
|
|
element_size = torch.tensor([], dtype=dtype).element_size()
|
|
|
|
|
memory_size = s_q * s_kv * element_size
|
|
|
|
|
outputs = {}
|
|
|
|
|
if (q_padding_mask is None or q_padding_mask.bool().all()) and (
|
|
|
|
|
kv_padding_mask is None or kv_padding_mask.bool().all()
|
|
|
|
@ -170,10 +173,13 @@ class ColoAttention:
|
|
|
|
|
# no padding
|
|
|
|
|
assert is_causal
|
|
|
|
|
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
|
|
|
|
|
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
|
|
|
|
|
if s_q != 1:
|
|
|
|
|
attention_mask.tril_(diagonal=0)
|
|
|
|
|
attention_mask = attention_mask.expand(b, s_q, s_kv)
|
|
|
|
|
if memory_size > MEMORY_BOUND:
|
|
|
|
|
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
|
|
|
|
else:
|
|
|
|
|
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
|
|
|
|
|
if s_q != 1:
|
|
|
|
|
attention_mask.tril_(diagonal=0)
|
|
|
|
|
attention_mask = attention_mask.expand(b, s_q, s_kv)
|
|
|
|
|
else:
|
|
|
|
|
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
|
|
|
|
|
if kv_padding_mask is None:
|
|
|
|
@ -186,7 +192,10 @@ class ColoAttention:
|
|
|
|
|
b,
|
|
|
|
|
s_kv,
|
|
|
|
|
), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
|
|
|
|
|
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
|
|
|
|
|
if memory_size > MEMORY_BOUND:
|
|
|
|
|
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
|
|
|
|
else:
|
|
|
|
|
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
|
|
|
|
|
outputs.update(
|
|
|
|
|
{
|
|
|
|
|
"cu_seqlens_q": cu_seqlens_q,
|
|
|
|
@ -199,22 +208,16 @@ class ColoAttention:
|
|
|
|
|
)
|
|
|
|
|
if is_causal:
|
|
|
|
|
outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
|
|
|
|
|
if s_q != 1:
|
|
|
|
|
attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
|
|
|
|
|
if memory_size > MEMORY_BOUND:
|
|
|
|
|
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
|
|
|
|
else:
|
|
|
|
|
if s_q != 1:
|
|
|
|
|
attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
|
|
|
|
|
else:
|
|
|
|
|
outputs["attention_mask_type"] = AttnMaskType.PADDED
|
|
|
|
|
if invert:
|
|
|
|
|
attention_mask = invert_mask(attention_mask).unsqueeze(1)
|
|
|
|
|
outputs["attention_mask"] = attention_mask
|
|
|
|
|
|
|
|
|
|
element_size = torch.tensor([], dtype=dtype).element_size()
|
|
|
|
|
memory_size = s_q * s_kv * element_size
|
|
|
|
|
if memory_size > MEMORY_BOUND:
|
|
|
|
|
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
|
|
|
|
outputs["attention_mask"] = attention_mask
|
|
|
|
|
if outputs["attention_mask_type"] != AttnMaskType.PADDED_CAUSAL:
|
|
|
|
|
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
|
|
|
|
|
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@ -301,7 +304,6 @@ class ColoAttention:
|
|
|
|
|
element_size = torch.tensor([], dtype=q.dtype).element_size()
|
|
|
|
|
memory_size = s_q * s_kv * element_size
|
|
|
|
|
if memory_size > MEMORY_BOUND:
|
|
|
|
|
attention_mask = torch.empty((0,), dtype=q.dtype, device=q.device)
|
|
|
|
|
assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
|
|
|
|
|
|
|
|
|
|
mask_type = attention_mask_type if attention_mask is not None else None
|
|
|
|
|