mirror of https://github.com/hpcaitech/ColossalAI
updated flash attention usage
parent 085e7f4eff
commit 7bc0afc901

LICENSE: +70
@@ -326,3 +326,73 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

---------------- LICENSE FOR Flash Attention ----------------

BSD 3-Clause License

Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

---------------- LICENSE FOR Facebook xFormers ----------------

From xFormers:

Copyright (c) Facebook, Inc. and its affiliates


===

BSD 3-Clause License

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
   and IDIAP Research Institute nor the names of its contributors may be
   used to endorse or promote products derived from this software without
   specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
@@ -1,12 +1,6 @@
"""
The triton-based flash attention implementation is copied from the OpenAI/triton repository

You can find the repository at https://github.com/openai/triton
You can find the source file at https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py

Reference:
1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf
2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf
A general attention module using the flash attention kernels from xformers:
https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
"""

import math

@@ -15,6 +9,159 @@ import subprocess

import torch

try:
    from xformers.ops.fmha import memory_efficient_attention
    HAS_MEM_EFF_ATTN = True
except ImportError:
    HAS_MEM_EFF_ATTN = False
    print('please install xformers from https://github.com/facebookresearch/xformers')

if HAS_MEM_EFF_ATTN:

    from typing import Optional

    from einops import rearrange
    from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
    from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias

    from .scaled_softmax import AttnMaskType

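    # ALiBi-style biases can only be fused if every CUTLASS fMHA op supports a tensor
    # attention bias (LowerTriangularMaskWithTensorBias), checked below.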
    allow_alibi = True
    for op in MemoryEfficientAttentionCutlassOp:
        allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)

|
||||
"""
|
||||
Adapted from
|
||||
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
|
||||
ctx.save_for_backward(indices)
|
||||
# [b, s, ...]
|
||||
assert tensor.ndim >= 3
|
||||
ctx.bsz = tensor.shape[0]
|
||||
out = rearrange(tensor, 'b s ... -> (b s) ...')
|
||||
ctx.shape = out.shape
|
||||
# [1, ntokens, ...]
|
||||
return out[indices].unsqueeze(0)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
indices, = ctx.saved_tensors
|
||||
# [b*s, ...]
|
||||
grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
|
||||
grad[indices] = grad_output.squeeze(0)
|
||||
grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz)
|
||||
# [b, s, ...]
|
||||
return grad, None
|
||||
|
||||
class Repad(torch.autograd.Function):
|
||||
"""
|
||||
Adapted from
|
||||
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
|
||||
ctx.save_for_backward(indices)
|
||||
# [ntokens, ...]
|
||||
tensor = tensor.squeeze(0)
|
||||
out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
|
||||
# [b*s, ...]
|
||||
out[indices] = tensor
|
||||
# [b, s, ...]
|
||||
out = rearrange(out, '(b s) ... -> b s ...', b=batch_size)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
indices, = ctx.saved_tensors
|
||||
# [b*s, ...]
|
||||
grad_output = rearrange(grad_output, 'b s ... -> (b s) ...')
|
||||
grad = grad_output[indices]
|
||||
# [1, ntokens, ...]
|
||||
return grad.unsqueeze(0), None, None, None
|
||||
|
||||
class ColoAttention(torch.nn.Module):
|
||||
|
||||
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
assert embed_dim % num_heads == 0, \
|
||||
f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
|
||||
self.scale = 1 / math.sqrt(embed_dim // num_heads)
|
||||
self.dropout = dropout
|
||||
|
||||
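        # Returns the flat indices of the non-padding tokens in the flattened (batch * seq)
        # layout, plus the per-sample valid lengths used by BlockDiagonalMask.from_seqlens.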
        @staticmethod
        def get_seq_info_from_mask(attn_mask: torch.Tensor):
            indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten()
            seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten().tolist()
            return indices, seqlens

        @staticmethod
        def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
            return Unpad.apply(tensor, indices)

        @staticmethod
        def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
            return Repad.apply(tensor, indices, batch_size, seq_len)

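        # query, key and value are expected in [batch, seq, num_heads, head_dim] layout;
        # the output is flattened to [batch, seq, num_heads * head_dim].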
        def forward(self,
                    query: torch.Tensor,
                    key: torch.Tensor,
                    value: torch.Tensor,
                    attn_mask: Optional[torch.Tensor] = None,
                    attn_mask_type: Optional[AttnMaskType] = None,
                    bias: Optional[torch.Tensor] = None):
            batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
            attn_bias = None
            if attn_mask_type == AttnMaskType.padding:    # bert style
                assert attn_mask is not None, \
                    f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
                assert attn_mask.dim() == 2, \
                    "attention mask is supposed to have shape (batch_size, seq_len), " + \
                    f"but got {attn_mask.dim()} dimensions."
                if tgt_len == src_len:
                    q_indices, q_seqlen = self.get_seq_info_from_mask(attn_mask)
                    kv_seqlen = None
                    if batch_size > 1:
                        query, key, value = self.unpad(torch.stack([query, key, value], dim=2), q_indices).unbind(dim=2)
                else:
                    q_indices = torch.arange(batch_size * tgt_len, dtype=torch.int32, device=query.device)
                    q_seqlen = [tgt_len] * batch_size
                    kv_indices, kv_seqlen = self.get_seq_info_from_mask(attn_mask)
                    if batch_size > 1:
                        query = rearrange(query, "b s ... -> c (b s) ...", c=1)
                        key, value = self.unpad(torch.stack([key, value], dim=2), kv_indices).unbind(dim=2)
                attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
            elif attn_mask_type == AttnMaskType.causal:    # gpt style
                attn_bias = LowerTriangularMask()

            if bias is not None:    # alibi / relative position embedding
                assert allow_alibi, "flash attention with bias is not supported in this system."
                assert attn_mask_type == AttnMaskType.causal, \
                    "attention with bias is only supported for causal attention so far."
                attn_bias = attn_bias.add_bias(bias)

            out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale)

            if attn_mask_type == AttnMaskType.padding and batch_size > 1:
                out = self.repad(out, q_indices, batch_size, tgt_len)

            out = rearrange(out, 'b s h d -> b s (h d)')
            return out

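    # Illustrative usage sketch (mirrors test_attention_gpt in the test file below; the
    # values B, S, H, D_HEAD are example shapes, not part of this module):
    #
    #     B, S, H, D_HEAD = 6, 8, 4, 16
    #     D = H * D_HEAD
    #     c_attn = torch.nn.Linear(D, 3 * D, dtype=torch.float16, device="cuda")
    #     attn = ColoAttention(embed_dim=D, num_heads=H, dropout=0.1)
    #     x = torch.randn((B, S, D), dtype=torch.float16, device="cuda")
    #     q, k, v = rearrange(c_attn(x), 'b s (n h d) -> n b s h d', n=3, h=H)
    #     y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)    # [B, S, D]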
##########################################################################
# The flash attention functions below, copied from the OpenAI/triton
# repository, will be deprecated.
# You can find the repository at https://github.com/openai/triton
# You can find the source file at https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
# Reference:
# 1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf
# 2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf


def triton_cuda_check():
    cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")

@@ -52,13 +199,6 @@ except ImportError:
    HAS_FLASH_ATTN = False
    print('please install flash_attn from https://github.com/HazyResearch/flash-attention')

try:
    from xformers.ops.fmha import memory_efficient_attention
    HAS_MEM_EFF_ATTN = True
except ImportError:
    HAS_MEM_EFF_ATTN = False
    print('please install xformers from https://github.com/facebookresearch/xformers')

if HAS_TRITON:
    # the following functions are adapted from the OpenAI Triton tutorial
    # https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py

@@ -422,25 +562,6 @@ if HAS_TRITON:

if HAS_FLASH_ATTN:

    from einops import rearrange

    class MaskedFlashAttention(torch.nn.Module):

        def __init__(self, num_attention_heads: int, attention_head_size: int, attention_dropout: float) -> None:
            super().__init__()
            self.num_attention_heads = num_attention_heads
            self.attention_head_size = attention_head_size
            self.attention_func = FlashAttention(softmax_scale=math.sqrt(attention_head_size),
                                                 attention_dropout=attention_dropout)

        def forward(self, query_key_value: torch.Tensor, attention_mask: torch.Tensor, causal=False):
            if attention_mask.dtype is not torch.bool:
                attention_mask = attention_mask.bool()
            qkv = rearrange(query_key_value, 'b s (three h d) -> b s three h d', three=3, h=self.num_attention_heads)
            context, _ = self.attention_func(qkv, key_padding_mask=attention_mask, causal=causal)
            context = rearrange(context, 'b s h d -> b s (h d)')
            return context

    def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
        """
        Arguments:

@@ -511,20 +632,4 @@ if HAS_FLASH_ATTN:
                                                causal)


if HAS_MEM_EFF_ATTN:

    from einops import rearrange
    from xformers.ops.fmha import LowerTriangularMask

    class MemoryEfficientAttention(torch.nn.Module):

        def __init__(self, hidden_size: int, num_attention_heads: int, attention_dropout: float = 0.0):
            super().__init__()
            attention_head_size = hidden_size // num_attention_heads
            self.scale = 1 / attention_head_size**0.5
            self.dropout = attention_dropout

        def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor):
            context = memory_efficient_attention(query, key, value, attention_mask, self.dropout, self.scale)
            context = rearrange(context, 'b s h d -> b s (h d)')
            return context
##########################################################################

@@ -1,22 +1,13 @@
import random

import pytest
import torch
from einops import rearrange

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN, HAS_TRITON

if HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import (
        MaskedFlashAttention,
        flash_attention_q_k_v,
        flash_attention_q_kv,
        flash_attention_qkv,
    )

if HAS_TRITON:
    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN

if HAS_MEM_EFF_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention
    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention


def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):

@@ -30,117 +21,88 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
    return ref_out


@pytest.mark.skipif(HAS_TRITON == False, reason="triton is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)

    # triton implementation
    tri_out = triton_flash_attention(q, k, v, sm_scale)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None
    # compare
    assert torch.allclose(ref_out, tri_out, atol=1e-3)
    assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
    assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
    assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
    x = torch.randn((B, S, D), dtype=dtype, device="cuda")

    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)
    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
    q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)
    assert list(y.shape) == [B, S, D]

    # reference implementation
    ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None

    # flash implementation
    q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v])
    dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
    for i in range(3):
        if i == 0:
            tri_out = flash_attention_q_k_v(q, k, v, sm_scale, Z, N_CTX, N_CTX, causal=True)
        elif i == 1:
            kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1)
            tri_out = flash_attention_q_kv(q, kv, sm_scale, Z, N_CTX, N_CTX, causal=True)
        else:
            qkv = torch.cat((q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)), dim=1)
            tri_out = flash_attention_qkv(qkv, sm_scale, Z, N_CTX, causal=True)

        tri_out.backward(dout, retain_graph=True)

        if i == 0:
            tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
                                                  (tri_out, tri_dq, tri_dk, tri_dv))
        elif i == 1:
            tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout)
            tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1)
            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
                                                  (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
        else:
            tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout)
            tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1)
            tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
                                                  (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))

        # compare
        assert torch.allclose(ref_out, tri_out, atol=1e-3)
        assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
        assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
        assert torch.allclose(ref_dq, tri_dq, atol=1e-3)


@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    attn = MaskedFlashAttention(N_CTX, D_HEAD, 0.1)

    qkv = torch.randn((Z, H, 3 * N_CTX * D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    attention_mask = torch.randint(2, (Z, H)).cuda().bool()

    out = attn(qkv, attention_mask)

    dout = torch.rand_like(out)
    out.backward(dout)
    dy = torch.rand_like(y)
    y.backward(dy)


@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 8, 4, 16)])
def test_memory_efficient_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    attn = MemoryEfficientAttention(N_CTX * D_HEAD, N_CTX, 0.1)
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)

    out = attn(q, k, v, attention_mask=LowerTriangularMask())
    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
    # attention mask of shape [B, S] with zero padding to max length S
    mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)

    dout = torch.rand_like(out)
    out.backward(dout)
    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)

    assert list(y.shape) == [B, S, D]

    dy = torch.rand_like(y)
    y.backward(dy)


if __name__ == '__main__':
    test_flash_attention(3, 4, 2, 16)
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)

    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
    y = attn(q, k, v)

    assert list(y.shape) == [B, S, D]

    dy = torch.rand_like(y)
    y.backward(dy)


@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD

    q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
    kv_attn = torch.nn.Linear(D, 2 * D, dtype=dtype, device="cuda")

    attn = ColoAttention(D, H, dropout=0.1)

    src = torch.randn((B, S, D), dtype=dtype, device="cuda")
    tgt = torch.randn((B, T, D), dtype=dtype, device="cuda")

    q = q_attn(tgt)
    kv = kv_attn(src)
    q = rearrange(q, 'b s (h d) -> b s h d', h=H)
    k, v = rearrange(kv, 'b s (n h d) -> b s n h d', n=2, h=H).unbind(dim=2)
    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

    assert list(y.shape) == [B, T, D]

    dy = torch.rand_like(y)
    y.backward(dy)