updated flash attention usage

pull/3185/head
zbian 2023-03-17 15:09:47 +08:00 committed by アマデウス
parent 085e7f4eff
commit 7bc0afc901
3 changed files with 302 additions and 165 deletions

70
LICENSE
View File

@ -326,3 +326,73 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
---------------- LICENSE FOR Flash Attention ----------------
BSD 3-Clause License
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
---------------- LICENSE FOR Facebook xFormers ----------------
From xFormers:
Copyright (c) Facebook, Inc. and its affiliates
===
BSD 3-Clause License
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
and IDIAP Research Institute nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,12 +1,6 @@
"""
The triton-based flash attention implementation is copied from the OpenAI/triton repository
You can find the repository in Triton https://github.com/openai/triton
You can find the source file in https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
Reference:
1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf
2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf
A general attention module using the flash attention kernels from xformers:
https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
"""
import math
@ -15,6 +9,159 @@ import subprocess
import torch
try:
from xformers.ops.fmha import memory_efficient_attention
HAS_MEM_EFF_ATTN = True
except ImportError:
HAS_MEM_EFF_ATTN = False
print('please install xformers from https://github.com/facebookresearch/xformers')
if HAS_MEM_EFF_ATTN:
from typing import Optional
from einops import rearrange
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias
from .scaled_softmax import AttnMaskType
allow_alibi = True
for op in MemoryEfficientAttentionCutlassOp:
allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
class Unpad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
ctx.save_for_backward(indices)
# [b, s, ...]
assert tensor.ndim >= 3
ctx.bsz = tensor.shape[0]
out = rearrange(tensor, 'b s ... -> (b s) ...')
ctx.shape = out.shape
# [1, ntokens, ...]
return out[indices].unsqueeze(0)
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
# [b*s, ...]
grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
grad[indices] = grad_output.squeeze(0)
grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz)
# [b, s, ...]
return grad, None
class Repad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
ctx.save_for_backward(indices)
# [ntokens, ...]
tensor = tensor.squeeze(0)
out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
# [b*s, ...]
out[indices] = tensor
# [b, s, ...]
out = rearrange(out, '(b s) ... -> b s ...', b=batch_size)
return out
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
# [b*s, ...]
grad_output = rearrange(grad_output, 'b s ... -> (b s) ...')
grad = grad_output[indices]
# [1, ntokens, ...]
return grad.unsqueeze(0), None, None, None
class ColoAttention(torch.nn.Module):
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
super().__init__()
assert embed_dim % num_heads == 0, \
f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
self.scale = 1 / math.sqrt(embed_dim // num_heads)
self.dropout = dropout
@staticmethod
def get_seq_info_from_mask(attn_mask: torch.Tensor):
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten()
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten().tolist()
return indices, seqlens
@staticmethod
def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
return Unpad.apply(tensor, indices)
@staticmethod
def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
return Repad.apply(tensor, indices, batch_size, seq_len)
def forward(self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
attn_mask_type: Optional[AttnMaskType] = None,
bias: Optional[torch.Tensor] = None):
batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
attn_bias = None
if attn_mask_type == AttnMaskType.padding: # bert style
assert attn_mask is not None, \
f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
assert attn_mask.dim() == 2, \
"attention mask is supposed to have shape (batch_size, seq_len), " + \
f"but got {attn_mask.dim()} dimensions."
if tgt_len == src_len:
q_indices, q_seqlen = self.get_seq_info_from_mask(attn_mask)
kv_seqlen = None
if batch_size > 1:
query, key, value = self.unpad(torch.stack([query, key, value], dim=2), q_indices).unbind(dim=2)
else:
q_indices = torch.arange(batch_size * tgt_len, dtype=torch.int32, device=query.device)
q_seqlen = torch.LongTensor([tgt_len] * batch_size, device=query.device)
kv_indices, kv_seqlen = self.get_seq_info_from_mask(attn_mask)
if batch_size > 1:
query = rearrange(query, "b s ... -> c (b s) ...", c=1)
key, value = self.unpad(torch.stack([query, key, value], dim=2), kv_indices).unbind(dim=2)
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
elif attn_mask_type == AttnMaskType.causal: # gpt style
attn_bias = LowerTriangularMask()
if bias is not None: # alibi / relative position emebedding
assert allow_alibi, "flash attention with bias is not supported in this system."
assert attn_mask_type == AttnMaskType.causal, \
"attention with bias is only supported for causal attention so far."
attn_bias = attn_bias.add_bias(bias)
out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale)
if attn_mask_type == AttnMaskType.padding and batch_size > 1:
out = self.repad(out, q_indices, batch_size, tgt_len)
out = rearrange(out, 'b s h d -> b s (h d)')
return out
##########################################################################
# the flash attention functions below that are copied
# from the OpenAI/triton repository will be deprecated
# You can find the repository in Triton https://github.com/openai/triton
# You can find the source file in https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
# Reference:
# 1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf
# 2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf
def triton_cuda_check():
cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
@ -52,13 +199,6 @@ except ImportError:
HAS_FLASH_ATTN = False
print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
try:
from xformers.ops.fmha import memory_efficient_attention
HAS_MEM_EFF_ATTN = True
except ImportError:
HAS_MEM_EFF_ATTN = False
print('please install xformers from https://github.com/facebookresearch/xformers')
if HAS_TRITON:
# the following functions are adapted from the OpenAI Triton tutorial
# https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
@ -422,25 +562,6 @@ if HAS_TRITON:
if HAS_FLASH_ATTN:
from einops import rearrange
class MaskedFlashAttention(torch.nn.Module):
def __init__(self, num_attention_heads: int, attention_head_size: int, attention_dropout: float) -> None:
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_size = attention_head_size
self.attention_func = FlashAttention(softmax_scale=math.sqrt(attention_head_size),
attention_dropout=attention_dropout)
def forward(self, query_key_value: torch.Tensor, attention_mask: torch.Tensor, causal=False):
if attention_mask.dtype is not torch.bool:
attention_mask = attention_mask.bool()
qkv = rearrange(query_key_value, 'b s (three h d) -> b s three h d', three=3, h=self.num_attention_heads)
context, _ = self.attention_func(qkv, key_padding_mask=attention_mask, causal=causal)
context = rearrange(context, 'b s h d -> b s (h d)')
return context
def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
"""
Arguments:
@ -511,20 +632,4 @@ if HAS_FLASH_ATTN:
causal)
if HAS_MEM_EFF_ATTN:
from einops import rearrange
from xformers.ops.fmha import LowerTriangularMask
class MemoryEfficientAttention(torch.nn.Module):
def __init__(self, hidden_size: int, num_attention_heads: int, attention_dropout: float = 0.0):
super().__init__()
attention_head_size = hidden_size // num_attention_heads
self.scale = 1 / attention_head_size**0.5
self.dropout = attention_dropout
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor):
context = memory_efficient_attention(query, key, value, attention_mask, self.dropout, self.scale)
context = rearrange(context, 'b s h d -> b s (h d)')
return context
##########################################################################

View File

@ -1,22 +1,13 @@
import random
import pytest
import torch
from einops import rearrange
from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN, HAS_TRITON
if HAS_FLASH_ATTN:
from colossalai.kernel.cuda_native.flash_attention import (
MaskedFlashAttention,
flash_attention_q_k_v,
flash_attention_q_kv,
flash_attention_qkv,
)
if HAS_TRITON:
from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN
if HAS_MEM_EFF_ATTN:
from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
@ -30,117 +21,88 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
return ref_out
@pytest.mark.skipif(HAS_TRITON == False, reason="triton is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
sm_scale = 0.3
dout = torch.randn_like(q)
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
D = H * D_HEAD
ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
attn = ColoAttention(D, H, dropout=0.1)
# triton implementation
tri_out = triton_flash_attention(q, k, v, sm_scale)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
assert torch.allclose(ref_out, tri_out, atol=1e-3)
assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
x = torch.randn((B, S, D), dtype=dtype, device="cuda")
qkv = c_attn(x)
q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)
y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
sm_scale = 0.3
dout = torch.randn_like(q)
assert list(y.shape) == [B, S, D]
# reference implementation
ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# flash implementation
q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v])
dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
for i in range(3):
if i == 0:
tri_out = flash_attention_q_k_v(q, k, v, sm_scale, Z, N_CTX, N_CTX, causal=True)
elif i == 1:
kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1)
tri_out = flash_attention_q_kv(q, kv, sm_scale, Z, N_CTX, N_CTX, causal=True)
else:
qkv = torch.cat((q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)), dim=1)
tri_out = flash_attention_qkv(qkv, sm_scale, Z, N_CTX, causal=True)
tri_out.backward(dout, retain_graph=True)
if i == 0:
tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
(tri_out, tri_dq, tri_dk, tri_dv))
elif i == 1:
tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout)
tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1)
tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
(tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
else:
tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout)
tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1)
tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
(tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))
# compare
assert torch.allclose(ref_out, tri_out, atol=1e-3)
assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
attn = MaskedFlashAttention(N_CTX, D_HEAD, 0.1)
qkv = torch.randn((Z, H, 3 * N_CTX * D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
attention_mask = torch.randint(2, (Z, H)).cuda().bool()
out = attn(qkv, attention_mask)
dout = torch.rand_like(out)
out.backward(dout)
dy = torch.rand_like(y)
y.backward(dy)
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 8, 4, 16)])
def test_memory_efficient_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
attn = MemoryEfficientAttention(N_CTX * D_HEAD, N_CTX, 0.1)
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
D = H * D_HEAD
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
attn = ColoAttention(D, H, dropout=0.1)
out = attn(q, k, v, attention_mask=LowerTriangularMask())
x = torch.randn((B, S, D), dtype=dtype, device="cuda")
# attention mask of shape [B, S] with zero padding to max length S
mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
dout = torch.rand_like(out)
out.backward(dout)
qkv = c_attn(x)
q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)
assert list(y.shape) == [B, S, D]
dy = torch.rand_like(y)
y.backward(dy)
if __name__ == '__main__':
test_flash_attention(3, 4, 2, 16)
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
D = H * D_HEAD
c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
attn = ColoAttention(D, H, dropout=0.1)
x = torch.randn((B, S, D), dtype=dtype, device="cuda")
qkv = c_attn(x)
q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
y = attn(q, k, v)
assert list(y.shape) == [B, S, D]
dy = torch.rand_like(y)
y.backward(dy)
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
D = H * D_HEAD
q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
kv_attn = torch.nn.Linear(D, 2 * D, dtype=dtype, device="cuda")
attn = ColoAttention(D, H, dropout=0.1)
src = torch.randn((B, S, D), dtype=dtype, device="cuda")
tgt = torch.randn((B, T, D), dtype=dtype, device="cuda")
q = q_attn(tgt)
kv = kv_attn(src)
q = rearrange(q, 'b s (h d) -> b s h d', h=H)
k, v = rearrange(kv, 'b s (n h d) -> b s n h d', n=2, h=H).unbind(dim=2)
y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
assert list(y.shape) == [B, T, D]
dy = torch.rand_like(y)
y.backward(dy)