
[fix] fix attn

Branch: pull/6114/head
Author: duanjunwen, 1 week ago
Commit: 014afbdb59
1 changed file, 23 lines changed

colossalai/shardformer/layer/attn.py
@@ -6,6 +6,7 @@ import torch.distributed
 import torch.distributed as dist
 import torch.nn.functional as F
 from einops import rearrange
+from packaging import version
 
 from colossalai.kernel.kernel_loader import (
     FlashAttentionDaoLoader,
@@ -642,9 +643,7 @@ class RingAttention(torch.autograd.Function):
             max_seqlen_q = max_seqlen_kv = max_seqlen
             cu_seqlens_half = cu_seqlens // 2
             max_seqlen_half = max_seqlen // 2
 
         misc_kwargs = {
-            "window_size": (-1, -1),
             "alibi_slopes": None,
             "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale,
             "dropout_p": dropout_p,
@@ -652,6 +651,13 @@
             "softcap": 0.0,
             "return_softmax": False,
         }
+        import flash_attn
+
+        if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
+            misc_kwargs["window_size_left"] = -1
+            misc_kwargs["window_size_right"] = -1
+        else:
+            misc_kwargs["window_size"] = (-1, -1)
 
         if (
             RingAttention.HALF_INDICES is not None
@@ -707,6 +713,19 @@ class RingAttention(torch.autograd.Function):
         # Helper to pass args to FA
         def _forward(q, k, v, causal):
+            if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
+                (out, softmax_lse, S_dmask, rng_state) = _flash_attn_forward(
+                    q,
+                    k,
+                    v,
+                    cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
+                    cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
+                    max_seqlen_q if q.shape[0] == t else max_seqlen_half,
+                    max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
+                    causal=causal,
+                    **misc_kwargs,
+                )
+            else:
+                (
+                    _,
+                    _,
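
For reference, below is a minimal, self-contained sketch of the version gate the hunks above introduce. The helper name build_window_kwargs is hypothetical and not part of this patch; it only illustrates the pattern the commit applies: flash-attn releases newer than 2.6.3 expect separate window_size_left / window_size_right arguments (as the patch implies), while older releases take a single window_size tuple, so the kwargs dict is assembled once per installed version and splatted into the attention call unchanged.

# Sketch of the version-gated kwargs pattern used by this commit (illustrative only;
# build_window_kwargs is a hypothetical helper, not part of ColossalAI).
from packaging import version


def build_window_kwargs(fa_version: str, left: int = -1, right: int = -1) -> dict:
    """Return window-size kwargs matching the flash-attn API for `fa_version`."""
    if version.parse(fa_version) > version.parse("2.6.3"):
        # Newer API: two scalar keyword arguments.
        return {"window_size_left": left, "window_size_right": right}
    # Older API (<= 2.6.3): a single (left, right) tuple.
    return {"window_size": (left, right)}


if __name__ == "__main__":
    # The call site stays the same either way, e.g.
    #   _flash_attn_forward(q, k, v, ..., **misc_kwargs)
    print(build_window_kwargs("2.6.3"))  # {'window_size': (-1, -1)}
    print(build_window_kwargs("2.7.0"))  # {'window_size_left': -1, 'window_size_right': -1}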
