@@ -6,6 +6,7 @@ import torch.distributed
 import torch.distributed as dist
 import torch.nn.functional as F
 from einops import rearrange
+from packaging import version
 
 from colossalai.kernel.kernel_loader import (
     FlashAttentionDaoLoader,
@@ -642,9 +643,7 @@ class RingAttention(torch.autograd.Function):
         max_seqlen_q = max_seqlen_kv = max_seqlen
         cu_seqlens_half = cu_seqlens // 2
         max_seqlen_half = max_seqlen // 2
-
         misc_kwargs = {
-            "window_size": (-1, -1),
             "alibi_slopes": None,
             "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale,
             "dropout_p": dropout_p,
@@ -652,6 +651,13 @@ class RingAttention(torch.autograd.Function):
             "softcap": 0.0,
             "return_softmax": False,
         }
+        import flash_attn
+
+        if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
+            misc_kwargs["window_size_left"] = -1
+            misc_kwargs["window_size_right"] = -1
+        else:
+            misc_kwargs["window_size"] = (-1, -1)
 
         if (
             RingAttention.HALF_INDICES is not None
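
The gate above exists because flash-attn releases after 2.6.3 split the (left, right) window_size tuple of the private _flash_attn_forward API into separate window_size_left / window_size_right integer arguments. A minimal standalone sketch of the same technique, assuming only that flash-attn is installed; build_window_kwargs is a hypothetical helper name, not part of this patch:

    # Sketch only: version-gated kwargs for flash-attn's private forward API.
    # Hypothetical helper illustrating the gate used in the patch above.
    import flash_attn
    from packaging import version

    def build_window_kwargs(left=-1, right=-1):
        # flash-attn > 2.6.3 takes two ints instead of a (left, right) tuple.
        if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
            return {"window_size_left": left, "window_size_right": right}
        return {"window_size": (left, right)}

Building the kwargs dict once and splatting it into the call (as the patch does with misc_kwargs) keeps a single call site working against either release line.
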
@@ -707,6 +713,19 @@ class RingAttention(torch.autograd.Function):
 
         # Helper to pass args to FA
         def _forward(q, k, v, causal):
-            (
-                _,
-                _,
+            if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
+                (out, softmax_lse, S_dmask, rng_state) = _flash_attn_forward(
+                    q,
+                    k,
+                    v,
+                    cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
+                    cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
+                    max_seqlen_q if q.shape[0] == t else max_seqlen_half,
+                    max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
+                    causal=causal,
+                    **misc_kwargs,
+                )
+            else:
+                (
+                    _,
+                    _,