mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix flash attn window_size err (#6132)
* [fix] fix flash attn
* [hotfix] fix flash-atten version
* [fix] fix flash_atten version
* [fix] fix flash-atten versions
* [fix] fix flash-attn not enough values to unpack error
* [fix] fix test_ring_attn
* [fix] fix test ring attn
parent a2596519fd
commit c2fe3137e2
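The whole fix gates flash-attn-specific keyword arguments and return-value unpacking on the installed flash-attn release, compared numerically with packaging.version (imported in the first hunk below) rather than as raw strings. A tiny illustration of the difference, not part of the commit:

# Illustration only: version strings do not order correctly when compared as
# plain strings, while packaging.version.parse compares release components
# numerically, which is what the gate against "2.6.3" relies on.
from packaging import version

print("2.10.0" > "2.6.3")                                # False: lexicographic order
print(version.parse("2.10.0") > version.parse("2.6.3"))  # True: numeric order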
@@ -6,6 +6,7 @@ import torch.distributed
 import torch.distributed as dist
 import torch.nn.functional as F
 from einops import rearrange
+from packaging import version
 
 from colossalai.kernel.kernel_loader import (
     FlashAttentionDaoLoader,
@@ -642,9 +643,7 @@ class RingAttention(torch.autograd.Function):
         max_seqlen_q = max_seqlen_kv = max_seqlen
         cu_seqlens_half = cu_seqlens // 2
         max_seqlen_half = max_seqlen // 2
 
         misc_kwargs = {
-            "window_size": (-1, -1),
             "alibi_slopes": None,
             "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale,
             "dropout_p": dropout_p,
@@ -652,6 +651,13 @@ class RingAttention(torch.autograd.Function):
             "softcap": 0.0,
             "return_softmax": False,
         }
+        import flash_attn
+
+        if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
+            misc_kwargs["window_size_left"] = -1
+            misc_kwargs["window_size_right"] = -1
+        else:
+            misc_kwargs["window_size"] = (-1, -1)
 
         if (
             RingAttention.HALF_INDICES is not None
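A minimal standalone sketch of the gate the hunk above introduces: depending on the installed flash-attn, "no sliding window" is spelled either as a single window_size tuple or as separate left/right arguments. build_window_kwargs is a hypothetical helper name used only for illustration; the commit performs the same branch inline on misc_kwargs.

# Hypothetical helper mirroring the version gate above; not ColossalAI or
# flash-attn API.
from packaging import version


def build_window_kwargs(flash_attn_version: str) -> dict:
    """Return the "no sliding window" kwargs for the given flash-attn version."""
    if version.parse(flash_attn_version) > version.parse("2.6.3"):
        # Newer flash-attn releases take separate left/right window arguments.
        return {"window_size_left": -1, "window_size_right": -1}
    # Older releases take a single (left, right) tuple.
    return {"window_size": (-1, -1)}


print(build_window_kwargs("2.6.3"))  # {'window_size': (-1, -1)}
print(build_window_kwargs("2.7.0"))  # {'window_size_left': -1, 'window_size_right': -1}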
@@ -707,26 +713,39 @@ class RingAttention(torch.autograd.Function):
 
         # Helper to pass args to FA
         def _forward(q, k, v, causal):
-            (
-                _,
-                _,
-                _,
-                _,
-                out,
-                softmax_lse,
-                _,
-                rng_state,
-            ) = _flash_attn_forward(
-                q,
-                k,
-                v,
-                cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
-                cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
-                max_seqlen_q if q.shape[0] == t else max_seqlen_half,
-                max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
-                causal=causal,
-                **misc_kwargs,
-            )
+            if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
+                (out, softmax_lse, S_dmask, rng_state) = _flash_attn_forward(
+                    q,
+                    k,
+                    v,
+                    cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
+                    cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
+                    max_seqlen_q if q.shape[0] == t else max_seqlen_half,
+                    max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
+                    causal=causal,
+                    **misc_kwargs,
+                )
+            else:
+                (
+                    _,
+                    _,
+                    _,
+                    _,
+                    out,
+                    softmax_lse,
+                    _,
+                    rng_state,
+                ) = _flash_attn_forward(
+                    q,
+                    k,
+                    v,
+                    cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
+                    cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
+                    max_seqlen_q if q.shape[0] == t else max_seqlen_half,
+                    max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
+                    causal=causal,
+                    **misc_kwargs,
+                )
             return out, softmax_lse, rng_state
 
         def _kv_comm(i):
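The hunk above is the "not enough values to unpack" fix from the commit message: the old code always unpacked eight return values from _flash_attn_forward, while newer flash-attn releases return only four (out, softmax_lse, S_dmask, rng_state). A minimal sketch of the same arity switch in isolation; unpack_fa_result and the placeholder tuples below are hypothetical stand-ins, not the real flash-attn return values.

# Hypothetical helper showing the two return layouts handled in the hunk above.
from packaging import version


def unpack_fa_result(result, flash_attn_version: str):
    """Select (out, softmax_lse, rng_state) from either flash-attn return layout."""
    if version.parse(flash_attn_version) > version.parse("2.6.3"):
        out, softmax_lse, _s_dmask, rng_state = result       # 4-value layout
    else:
        _, _, _, _, out, softmax_lse, _, rng_state = result  # 8-value layout
    return out, softmax_lse, rng_state


# Placeholder tuples, only to show that both layouts unpack cleanly.
print(unpack_fa_result(("out", "lse", "S_dmask", "rng"), "2.7.0"))
print(unpack_fa_result(("a", "b", "c", "d", "out", "lse", "S_dmask", "rng"), "2.6.3"))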