diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 3202ebf25..019a6b140 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -6,6 +6,7 @@ import torch.distributed
 import torch.distributed as dist
 import torch.nn.functional as F
 from einops import rearrange
+from packaging import version
 
 from colossalai.kernel.kernel_loader import (
     FlashAttentionDaoLoader,
@@ -642,9 +643,7 @@ class RingAttention(torch.autograd.Function):
         max_seqlen_q = max_seqlen_kv = max_seqlen
         cu_seqlens_half = cu_seqlens // 2
         max_seqlen_half = max_seqlen // 2
-
         misc_kwargs = {
-            "window_size": (-1, -1),
             "alibi_slopes": None,
             "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale,
             "dropout_p": dropout_p,
@@ -652,6 +651,13 @@ class RingAttention(torch.autograd.Function):
             "softcap": 0.0,
             "return_softmax": False,
         }
+        import flash_attn
+
+        if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
+            misc_kwargs["window_size_left"] = -1
+            misc_kwargs["window_size_right"] = -1
+        else:
+            misc_kwargs["window_size"] = (-1, -1)
 
         if (
             RingAttention.HALF_INDICES is not None
@@ -707,26 +713,39 @@ class RingAttention(torch.autograd.Function):
 
         # Helper to pass args to FA
         def _forward(q, k, v, causal):
-            (
-                _,
-                _,
-                _,
-                _,
-                out,
-                softmax_lse,
-                _,
-                rng_state,
-            ) = _flash_attn_forward(
-                q,
-                k,
-                v,
-                cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
-                cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
-                max_seqlen_q if q.shape[0] == t else max_seqlen_half,
-                max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
-                causal=causal,
-                **misc_kwargs,
-            )
+            if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
+                (out, softmax_lse, S_dmask, rng_state) = _flash_attn_forward(
+                    q,
+                    k,
+                    v,
+                    cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
+                    cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
+                    max_seqlen_q if q.shape[0] == t else max_seqlen_half,
+                    max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
+                    causal=causal,
+                    **misc_kwargs,
+                )
+            else:
+                (
+                    _,
+                    _,
+                    _,
+                    _,
+                    out,
+                    softmax_lse,
+                    _,
+                    rng_state,
+                ) = _flash_attn_forward(
+                    q,
+                    k,
+                    v,
+                    cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
+                    cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
+                    max_seqlen_q if q.shape[0] == t else max_seqlen_half,
+                    max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
+                    causal=causal,
+                    **misc_kwargs,
+                )
             return out, softmax_lse, rng_state
 
         def _kv_comm(i):
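
Note: the patch gates on the installed flash-attn version because, per the branches above, releases newer than 2.6.3 take the attention window as two scalar arguments (window_size_left / window_size_right) instead of a single window_size tuple, and their _flash_attn_forward returns a 4-tuple (out, softmax_lse, S_dmask, rng_state) rather than the longer tuple unpacked in the else branch. The snippet below is only a minimal standalone sketch of that same version gate; the helper name build_fa_varlen_kwargs is hypothetical and not part of this patch or of ColossalAI.

    # Standalone sketch of the version gate used in this patch.
    # build_fa_varlen_kwargs is a hypothetical helper; it mirrors how
    # misc_kwargs is assembled in RingAttention.forward above.
    import flash_attn
    from packaging import version


    def build_fa_varlen_kwargs(head_dim: int, dropout_p: float, softmax_scale=None) -> dict:
        """Keyword args for flash-attn's varlen forward, keyed on the installed version."""
        kwargs = {
            "alibi_slopes": None,
            "softmax_scale": head_dim**-0.5 if softmax_scale is None else softmax_scale,
            "dropout_p": dropout_p,
            "block_table": None,
            "softcap": 0.0,
            "return_softmax": False,
        }
        if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
            # newer flash-attn: window passed as two scalars
            kwargs["window_size_left"] = -1
            kwargs["window_size_right"] = -1
        else:
            # older flash-attn: window passed as a (left, right) tuple
            kwargs["window_size"] = (-1, -1)
        return kwargs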