mirror of https://github.com/hpcaitech/ColossalAI
fix

parent 683179cefd
commit 6eb8832366
@@ -117,7 +117,7 @@ class ColoAttention:
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )

-        if size > MEMORY_BOUND:
+        if size >= MEMORY_BOUND:
             FlashAttentionDaoLoader().load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
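The first hunk only widens the comparison: with the strict `>`, an input whose dense (s_q, s_kv) mask footprint equals MEMORY_BOUND exactly would not trigger the FlashAttention (Dao) loader, while the commit's other checks now route that boundary case to the large-mask path. A minimal sketch of the boundary behavior; the threshold value and helper name here are hypothetical:

```python
# Hypothetical threshold for illustration only; the real MEMORY_BOUND
# constant lives in the ColossalAI source.
MEMORY_BOUND = 10 * 1024**3  # assume 10 GiB

def takes_large_mask_path(s_q: int, s_kv: int, element_size: int) -> bool:
    """Mirrors the check in the diff: is the dense mask too big to materialize?"""
    memory_size = s_q * s_kv * element_size  # bytes for a dense (s_q, s_kv) mask
    return memory_size >= MEMORY_BOUND      # was `>`: the equality case fell through

# An fp16 (2-byte) shape that lands exactly on the assumed bound:
assert takes_large_mask_path(s_q=2**17, s_kv=5 * 2**13, element_size=2)
```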
@@ -173,7 +173,7 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            if memory_size > MEMORY_BOUND:
+            if memory_size >= MEMORY_BOUND:
                 attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
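Same one-character change on the causal, no-padding path: at exactly MEMORY_BOUND the code now keeps the zero-element placeholder instead of materializing a dense s_q x s_kv mask. A toy comparison of the two conventions the branch chooses between (sizes are made up; reading the empty tensor as a "no materialized mask" sentinel is inferred from the surrounding diff):

```python
import torch

s_q, s_kv, dtype, device = 4, 6, torch.float16, "cpu"  # toy sizes

# At or above the bound: a zero-element placeholder, so no O(s_q * s_kv)
# memory is spent; the attention kernel is expected to mask causally itself.
placeholder = torch.empty((0,), dtype=dtype, device=device)

# Below the bound: a dense base mask that later steps can refine.
dense = torch.ones(s_q, s_kv, dtype=dtype, device=device)

print(placeholder.numel(), dense.numel())  # 0 vs. 24 elements
```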
@@ -192,10 +192,10 @@ class ColoAttention:
                 b,
                 s_kv,
             ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
-            if memory_size > MEMORY_BOUND:
-                attention_mask = torch.empty((0,), dtype=dtype, device=device)
-            else:
+            if memory_size < MEMORY_BOUND and not is_causal:
                 attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+            else:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
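This hunk changes behavior beyond the boundary case: the expanded padding mask is now built only when it is both small enough to materialize and the attention is not causal; causal inputs fall through to the placeholder here and get their mask from the `if is_causal:` block in the next hunk. A toy example of the expansion itself, with made-up shapes and values:

```python
import torch

b, s_q, s_kv = 2, 3, 4  # toy sizes
kv_padding_mask = torch.tensor([[1, 1, 1, 0],
                                [1, 1, 0, 0]])  # 1 = real token, 0 = padding

# Broadcast each batch's per-key padding row across every query position.
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=torch.float16)
print(attention_mask.shape)  # torch.Size([2, 3, 4])
print(attention_mask[1])     # all 3 query rows repeat [1., 1., 0., 0.]
```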
@@ -208,7 +208,7 @@ class ColoAttention:
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                if memory_size > MEMORY_BOUND:
+                if memory_size >= MEMORY_BOUND:
                     attention_mask = torch.empty((0,), dtype=dtype, device=device)
                 else:
                     if s_q != 1:
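The padded-causal branch gets the same `>` to `>=` widening, so all of the commit's MEMORY_BOUND comparisons now agree: a mask that would occupy exactly MEMORY_BOUND bytes takes the placeholder path (the previous hunk states this as the complementary `<`).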
@@ -303,9 +303,6 @@ class ColoAttention:
         b, _, s_kv, _ = v.shape
         element_size = torch.tensor([], dtype=q.dtype).element_size()
         memory_size = s_q * s_kv * element_size
-        if memory_size > MEMORY_BOUND:
-            assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED
-
         mask_type = attention_mask_type if attention_mask is not None else None
         attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (
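The guard deleted in the last hunk could never fire anyway: `a == X or Y` parses as `(a == X) or Y`, and a bare enum member is truthy, so the assertion always passed. Rather than fix the precedence, the commit removes the check entirely. A short demonstration with a stand-in enum (the real AttnMaskType lives in ColossalAI; only the member names are taken from the diff):

```python
from enum import Enum

class AttnMaskType(Enum):  # stand-in for the real enum
    CAUSAL = 1
    PADDED = 2
    PADDED_CAUSAL = 3

attention_mask_type = AttnMaskType.CAUSAL  # not a padded type at all

# The deleted line: `or AttnMaskType.PADDED` is evaluated as its own
# truthy operand, so this assertion can never fail.
assert attention_mask_type == AttnMaskType.PADDED_CAUSAL or AttnMaskType.PADDED

# What the check presumably meant:
print(attention_mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.PADDED))  # False
```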