mirror of https://github.com/hpcaitech/ColossalAI
fix
parent
683179cefd
commit
6eb8832366
|
@ -117,7 +117,7 @@ class ColoAttention:
|
|||
"FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
|
||||
)
|
||||
|
||||
if size > MEMORY_BOUND:
|
||||
if size >= MEMORY_BOUND:
|
||||
FlashAttentionDaoLoader().load()
|
||||
# lazy load
|
||||
if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
|
||||
|
@ -173,7 +173,7 @@ class ColoAttention:
|
|||
# no padding
|
||||
assert is_causal
|
||||
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
|
||||
if memory_size > MEMORY_BOUND:
|
||||
if memory_size >= MEMORY_BOUND:
|
||||
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
||||
else:
|
||||
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
|
||||
|
@ -192,10 +192,10 @@ class ColoAttention:
|
|||
b,
|
||||
s_kv,
|
||||
), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
|
||||
if memory_size > MEMORY_BOUND:
|
||||
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
||||
else:
|
||||
if memory_size < MEMORY_BOUND and not is_causal:
|
||||
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
|
||||
else:
|
||||
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
||||
outputs.update(
|
||||
{
|
||||
"cu_seqlens_q": cu_seqlens_q,
|
||||
|
@ -208,7 +208,7 @@ class ColoAttention:
|
|||
)
|
||||
if is_causal:
|
||||
outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
|
||||
if memory_size > MEMORY_BOUND:
|
||||
if memory_size >= MEMORY_BOUND:
|
||||
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
||||
else:
|
||||
if s_q != 1:
|
||||
|
@ -303,9 +303,6 @@ class ColoAttention:
|
|||
b, _, s_kv, _ = v.shape
|
||||
element_size = torch.tensor([], dtype=q.dtype).element_size()
|
||||
memory_size = s_q * s_kv * element_size
|
||||
if memory_size > MEMORY_BOUND:
|
||||
assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
|
||||
|
||||
mask_type = attention_mask_type if attention_mask is not None else None
|
||||
attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
|
||||
is_causal = attention_mask is not None and attention_mask_type in (
|
||||
|
|
Loading…
Reference in New Issue