mirror of https://github.com/hpcaitech/ColossalAI
[feature] refactor colo attention (#5462)

* [extension] update api
* [feature] add colo attention
* [feature] update sdpa
* [feature] update npu attention
* [feature] update flash-attn
* [test] add flash attn test
* [test] update flash attn test
* [shardformer] update modeling to fit colo attention (#5465)
* [misc] refactor folder structure
* [shardformer] update llama flash-attn
* [shardformer] fix llama policy
* [devops] update tensornvme install
* [test] update llama test
* [shardformer] update colo attn kernel dispatch
* [shardformer] update blip2
* [shardformer] update chatglm
* [shardformer] update gpt2
* [shardformer] update gptj
* [shardformer] update opt
* [shardformer] update vit
* [shardformer] update colo attention mask prep
* [shardformer] update whisper
* [test] fix shardformer tests (#5514)
* [test] fix shardformer tests
* [test] fix shardformer tests
Hongxin Liu
8 months ago
committed by GitHub
45 changed files with 2538 additions and 1165 deletions
@@ -1,209 +0,0 @@
import enum
import math
import warnings
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange

from colossalai.accelerator import get_accelerator
from colossalai.kernel.kernel_loader import FlashAttentionLoader


@dataclass
class SeqLenInfo:
    seqlens: Iterable[int] = None
    indices: torch.Tensor = None
    max_seqlen: int = None
    cu_seqlens: torch.Tensor = None

    @staticmethod
    def materialize(
        attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
    ):
        if attn_mask is not None:
            indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
            seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
        else:
            batch_size, tgt_len = size[0], size[1]
            indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
            seqlens = torch.full((batch_size,), tgt_len, dtype=torch.long, device=device)
        max_seqlen = max(seqlens)
        cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
        return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
    paddedcausal = 3


class Unpad(torch.autograd.Function):
    """
    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
        ctx.save_for_backward(indices)
        # [b, s, ...]
        assert tensor.ndim >= 3
        ctx.bsz = tensor.shape[0]
        out = rearrange(tensor, "b s ... -> (b s) ...")
        ctx.shape = out.shape
        # [ntokens, ...]
        return out[indices]

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # [ntokens, ...]
        grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
        grad[indices] = grad_output
        grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
        # [b, s, ...]
        return grad, None


class Repad(torch.autograd.Function):
    """
    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
        ctx.save_for_backward(indices)
        # [ntokens, ...]
        out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
        # [b*s, ...]
        out[indices] = tensor
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # [b*s, ...]
        grad = grad_output[indices]
        # [ntokens, ...]
        return grad, None, None, None


class ColoAttention(torch.nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
        super().__init__()
        assert (
            embed_dim % num_heads == 0
        ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
        if scale is not None:
            self.scale = scale
        else:
            self.scale = 1 / math.sqrt(embed_dim // num_heads)
        self.dropout = dropout

        self.attn = FlashAttentionLoader().load()

    @staticmethod
    def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        return Unpad.apply(tensor, indices)

    @staticmethod
    def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
        return Repad.apply(tensor, indices, batch_size, seq_len)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        origin_attn_mask: Optional[torch.Tensor] = None,
        attn_mask_type: Optional[AttnMaskType] = None,
        bias: Optional[torch.Tensor] = None,
    ):
        """
        ColoAttention

        Args:
            query: (batch, q_seqlen, nheads, headdim)
            key: (batch, kv_seqlen, nheads, headdim)
            value: (batch, kv_seqlen, nheads, headdim)
            origin_attn_mask: (nheads, q_seqlen, kv_seqlen)
            bias: will not be used
        Return:
            attn_out: (batch, q_seqlen, nheads, headdim).
        """
        # if flash attention is not applicable, switch to memory efficient attention
        if self.attn.__name__ == "flash_attention" and (
            query.dtype not in [torch.float16, torch.bfloat16] or bias is not None
        ):
            warnings.warn(
                f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation."
            )
            self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda")

        padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
        causal = attn_mask_type is not None and attn_mask_type.value > 1

        batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
        # unpad
        seq_len_info_q = None
        seq_len_info_kv = None
        if padded:
            # bert style, unpad process
            assert (
                attn_mask is not None
            ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
            assert attn_mask.dim() == 2, (
                "attention mask is supposed to have shape (batch_size, seq_len), "
                + f"but got {attn_mask.dim()} dimensions."
            )

            # bert style
            if tgt_len == src_len:
                seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
                if batch_size > 1:
                    query, key, value = self.unpad(
                        torch.stack([query, key, value], dim=2), seq_len_info_q.indices
                    ).unbind(dim=1)
                else:
                    query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
                seq_len_info_kv = seq_len_info_q
            else:
                seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
                seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
                if batch_size > 1:
                    query = rearrange(query, "b s ... -> c (b s) ...", c=1)
                    key, value = self.unpad(torch.stack([key, value], dim=2), seq_len_info_kv.indices).unbind(
                        dim=1
                    )
                else:
                    query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)

        out = self.attn(
            query,
            key,
            value,
            seq_len_info_q=seq_len_info_q,
            seq_len_info_kv=seq_len_info_kv,
            origin_attn_mask=origin_attn_mask,
            dropout_p=self.dropout,
            scale=self.scale,
            causal=causal,
            padded=padded,
        )

        # repad
        if padded:
            if batch_size > 1:
                out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
            out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)

        if len(out.shape) == 4:
            out = rearrange(out, "b s h d -> b s (h d)")
        return out
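The module above couples unpadding and repadding to the attention call through the Unpad/Repad autograd functions. A minimal standalone sketch of the same gather/scatter round trip, with toy shapes that are not part of this diff, shows what SeqLenInfo.materialize, Unpad and Repad compute:

import torch

# Toy padded batch: B=2, S=4, hidden=3; the second sequence has two padding positions.
B, S, H = 2, 4, 3
x = torch.randn(B, S, H)
attn_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])

# Same index computation as SeqLenInfo.materialize / Unpad.forward.
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten()
unpadded = x.reshape(B * S, H)[indices]  # [ntokens, H], padding rows dropped

# Same scatter as Repad.forward: place tokens back into a zero-filled [B*S, H] buffer.
repadded = torch.zeros(B * S, H, dtype=x.dtype)
repadded[indices] = unpadded
repadded = repadded.reshape(B, S, H)

assert torch.equal(repadded * attn_mask[..., None], repadded)  # padding rows stay zero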
@@ -0,0 +1,269 @@
from enum import Enum
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.nn.functional as F

from colossalai.kernel.kernel_loader import (
    FlashAttentionForFloatAndCustomMaskLoader,
    FlashAttentionLoader,
    FlashAttentionWithCustomMaskLoader,
    FlashAttentionWithPaddingMaskLoader,
    KernelLoader,
)

__all__ = [
    "AttnMaskType",
    "ColoAttention",
]


class AttnMaskType(Enum):
    CUSTOM = 0
    PADDED = 1
    CAUSAL = 2
    PADDED_CAUSAL = 3


def invert_mask(mask: torch.Tensor) -> torch.Tensor:
    """Invert the mask tensor.

    Args:
        mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Skv].

    Returns:
        torch.Tensor: Inverted mask tensor.
    """
    inverted_mask = 1.0 - mask
    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(mask.dtype).min)


# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]:
    """Get padding information from a padding mask.

    Args:
        padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S].

    Returns:
        Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seqlen, cu_seqlens, indices).
    """
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return max_seqlen_in_batch, cu_seqlens, indices
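As a quick illustration of what these two helpers return, here is a sketch with a tiny hand-written mask (not part of the diff; the import path follows the new test file's colossalai.shardformer.layer.attn):

import torch
from colossalai.shardformer.layer.attn import get_pad_info, invert_mask

# Two sequences of length 4; the second has one padding token at the end.
padding_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
max_seqlen, cu_seqlens, indices = get_pad_info(padding_mask)
# max_seqlen == 4, cu_seqlens == [0, 4, 7], indices lists the 7 valid flattened positions.

# invert_mask turns a {0, 1} mask of shape [B, 1, Sq, Skv] into an additive bias:
# 1 -> 0.0 (attend), 0 -> dtype minimum (mask out), ready to be added to attention scores.
mask_4d = torch.einsum("bi,bj->bij", padding_mask.float(), padding_mask.float()).unsqueeze(1)
bias = invert_mask(mask_4d)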

class ColoAttention:
    _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None

    @staticmethod
    def _init_kernels_dispatch():
        if ColoAttention._kernel_dispatch_map is None:
            # fp16/bf16
            half_dispatch_map = {
                None: FlashAttentionLoader(),
                AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
                AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(),
                AttnMaskType.CAUSAL: FlashAttentionLoader(),
                AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(),
            }
            # fp32
            float_dispatch_map = {
                None: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
                AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
            }
            ColoAttention._kernel_dispatch_map = {
                torch.float16: half_dispatch_map,
                torch.bfloat16: half_dispatch_map,
                torch.float32: float_dispatch_map,
            }

    @staticmethod
    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
        ColoAttention._init_kernels_dispatch()
        if (
            dtype not in ColoAttention._kernel_dispatch_map
            or mask_type not in ColoAttention._kernel_dispatch_map[dtype]
        ):
            raise ValueError(
                "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
            )
        # lazy load
        if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
            ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                mask_type
            ].load()
        return ColoAttention._kernel_dispatch_map[dtype][mask_type]

    @staticmethod
    def prepare_attn_kwargs(
        shape_4d: Tuple[int],
        dtype: torch.dtype,
        device: torch.device,
        q_padding_mask: Optional[torch.Tensor] = None,
        kv_padding_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Return a dictionary of keyword arguments for the attention function. It supports 4 mask types.
        1. custom mask: no padding mask and is_causal=False, returns {}; users should handle the attention mask themselves.
        2. padded mask: padding mask is given and is_causal=False, returns {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
        3. causal mask: no padding mask and is_causal=True, returns {attention_mask, attention_mask_type}.
        4. padded causal mask: padding mask is given and is_causal=True, returns {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.

        Args:
            shape_4d (Tuple[int]): Should be (B, 1, Sq, Skv).
            dtype (torch.dtype): Dtype of the attention mask, generally should be ``hidden_states.dtype``.
            device (torch.device): Device of the attention mask, generally should be ``hidden_states.device``.
            q_padding_mask (Optional[torch.Tensor], optional): Padding mask of the query. It should be a long or int tensor.
                The shape should be [B, Sq]. ``1`` means a valid token, and ``0`` means a padding token. Defaults to None.
            kv_padding_mask (Optional[torch.Tensor], optional): Padding mask of the key and value. It should be a long or int tensor.
                The shape should be [B, Skv]. ``1`` means a valid token, and ``0`` means a padding token.
                If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
            is_causal (bool, optional): Whether to use a causal attention mask. Defaults to False.

        Returns:
            Dict[str, torch.Tensor]: Dictionary of keyword arguments for the attention function.
        """
        if q_padding_mask is None and not is_causal:
            return {}
        assert len(shape_4d) == 4 and shape_4d[1] == 1
        b, _, s_q, s_kv = shape_4d
        outputs = {}
        if (q_padding_mask is None or q_padding_mask.bool().all()) and (
            kv_padding_mask is None or kv_padding_mask.bool().all()
        ):
            # no padding
            assert is_causal
            outputs["attention_mask_type"] = AttnMaskType.CAUSAL
            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
        else:
            if kv_padding_mask is None:
                # self attention
                kv_padding_mask = q_padding_mask
            assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (
                b,
                s_kv,
            ), f"q_padding_mask shape {q_padding_mask.shape} should be (B, Sq) and kv_padding_mask shape {kv_padding_mask.shape} should be (B, Skv), got shape_4d {shape_4d}."
            attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device)
            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
            max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
            outputs.update(
                {
                    "cu_seqlens_q": cu_seqlens_q,
                    "cu_seqlens_kv": cu_seqlens_kv,
                    "max_seqlen_q": max_seqlen_q,
                    "max_seqlen_kv": max_seqlen_kv,
                    "q_indices": q_indices,
                    "kv_indices": kv_indices,
                }
            )
            if is_causal:
                outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
                attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
            else:
                outputs["attention_mask_type"] = AttnMaskType.PADDED
        attention_mask = invert_mask(attention_mask).unsqueeze(1)
        outputs["attention_mask"] = attention_mask
        return outputs
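A minimal sketch of how calling code might build these kwargs (shapes and the padding pattern are hypothetical; the import path is the one used by the new test file):

import torch
from colossalai.shardformer.layer import ColoAttention

B, S = 2, 16
dtype = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1/0 padding mask: the second sequence only has 12 valid tokens.
q_padding_mask = torch.ones(B, S, dtype=torch.int, device=device)
q_padding_mask[1, 12:] = 0

# Padded causal case: returns attention_mask, attention_mask_type and the
# cu_seqlens/max_seqlen/indices fields consumed by the varlen flash-attn kernels.
attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (B, 1, S, S), dtype, device, q_padding_mask=q_padding_mask, is_causal=True
)

# Custom case: no padding mask and is_causal=False returns an empty dict.
assert ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, device) == {}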

    @staticmethod
    def attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask_type: AttnMaskType = AttnMaskType.CUSTOM,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        q_indices: Optional[torch.Tensor] = None,
        kv_indices: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        scale: Optional[float] = None,
    ) -> torch.Tensor:
        """Flash Attention function. It supports 4 mask types.
        1. custom mask: expects attention_mask
        2. padded mask: expects attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices
        3. causal mask: expects attention_mask, attention_mask_type
        4. padded causal mask: expects attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices

        Args:
            q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D].
            k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D].
            v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D].
            attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None.
            attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
            cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
                Shape should be [B+1]. Defaults to None.
            cu_seqlens_kv (Optional[torch.Tensor], optional): The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
                Shape should be [B+1]. Defaults to None.
            max_seqlen_q (Optional[int], optional): Maximum query sequence length in the batch. Defaults to None.
            max_seqlen_kv (Optional[int], optional): Maximum key/value sequence length in the batch. Defaults to None.
            q_indices (Optional[torch.Tensor], optional): The indices of non-masked query tokens from the flattened input sequence.
                Shape should be [NUM_TOKENS]. Defaults to None.
            kv_indices (Optional[torch.Tensor], optional): The indices of non-masked key/value tokens from the flattened input sequence.
                Shape should be [NUM_TOKENS]. Defaults to None.
            dropout_p (float, optional): Dropout probability. Defaults to 0.0.
            scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None.

        Returns:
            torch.Tensor: Output tensor. Shape should be [B, N, Sq, D].
        """
        # known issue: sdpa does not support an attention mask that contains a whole row of masked tokens, which leads to nan
        # this case is usual when a padding mask is used and self attention is performed
        # thus, we don't use sdpa when a padding mask is used
        # sanity check
        if attention_mask is not None:
            assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor."
            if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL):
                assert (
                    cu_seqlens_q is None
                    and cu_seqlens_kv is None
                    and max_seqlen_q is None
                    and max_seqlen_kv is None
                    and q_indices is None
                    and kv_indices is None
                )
                if attention_mask_type == AttnMaskType.CUSTOM:
                    assert not torch.all(attention_mask != 0, dim=-1).any()
            elif attention_mask_type in (
                AttnMaskType.PADDED,
                AttnMaskType.PADDED_CAUSAL,
            ):
                assert (
                    cu_seqlens_q is not None
                    and cu_seqlens_kv is not None
                    and max_seqlen_q is not None
                    and max_seqlen_kv is not None
                    and q_indices is not None
                    and kv_indices is not None
                )
        else:
            # if attention_mask is None, attention_mask_type should be the default value
            assert attention_mask_type == AttnMaskType.CUSTOM
        # kernel dispatch
        mask_type = attention_mask_type if attention_mask is not None else None
        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
        is_causal = attention_mask is not None and attention_mask_type in (
            AttnMaskType.CAUSAL,
            AttnMaskType.PADDED_CAUSAL,
        )
        return attn_func(
            q,
            k,
            v,
            dropout_p=dropout_p,
            scale=scale,
            attention_mask=attention_mask,
            is_causal=is_causal,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            q_indices=q_indices,
            kv_indices=kv_indices,
        )
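Putting the two static methods together, here is a hedged end-to-end sketch of the new stateless API. It assumes a CUDA device with at least one registered flash-attention backend available; the shapes are illustrative, not taken from the diff.

import torch
from colossalai.shardformer.layer import ColoAttention

# The new API expects q/k/v in [B, N, S, D] layout (batch, heads, sequence, head dim).
B, N, S, D = 2, 8, 16, 32
q = torch.randn(B, N, S, D, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Causal mask without padding: prepare_attn_kwargs returns attention_mask + mask type,
# and ColoAttention.attention dispatches to the best available kernel for this dtype.
attn_kwargs = ColoAttention.prepare_attn_kwargs((B, 1, S, S), q.dtype, q.device, is_causal=True)
out = ColoAttention.attention(q, k, v, **attn_kwargs)  # [B, N, S, D]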
@@ -1,20 +1,14 @@
 from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension
 from .flash_attention_npu import FlashAttentionNpuExtension
-from .flash_attention_xformers_cuda import FlashAttentionXformersCudaExtension
+from .flash_attention_sdpa_cuda import FlashAttentionSdpaCudaExtension
 
 try:
+    # TODO: remove this after updating openmoe example
     import flash_attention  # noqa
 
     HAS_FLASH_ATTN = True
 except:
     HAS_FLASH_ATTN = False
 
-try:
-    import xformers  # noqa
-
-    HAS_MEM_EFF_ATTN = True
-except:
-    HAS_MEM_EFF_ATTN = False
-
 
-__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionXformersCudaExtension", "FlashAttentionNpuExtension"]
+__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionSdpaCudaExtension", "FlashAttentionNpuExtension"]
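For context, the extension registry that these exports feed can be probed the same way the new test below does. This is a sketch, not part of the diff; the printed names depend on what is installed on the machine.

from colossalai.kernel.kernel_loader import FlashAttentionLoader

# Instantiate each registered extension, keep the ones whose runtime requirements
# are met, and load their kernels.
available = []
for ext_cls in FlashAttentionLoader.REGISTRY:
    ext = ext_cls()
    if ext.is_available():
        ext.assert_compatible()
        available.append((ext.name, ext.load()))
print([name for name, _ in available])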
@@ -0,0 +1,56 @@
from ..base_extension import _Extension


class FlashAttentionSdpaCudaExtension(_Extension):
    def __init__(self):
        super().__init__(name="flash_attention_sdpa_cuda", support_aot=False, support_jit=False)

    def is_available(self) -> bool:
        # cuda extension can only be built if cuda is available
        try:
            import torch

            cuda_available = torch.cuda.is_available()
        except:
            cuda_available = False
        return cuda_available

    def assert_compatible(self) -> bool:
        pass

    def build_aot(self) -> None:
        raise NotImplementedError("Flash attention SDPA does not require ahead-of-time compilation.")

    def build_jit(self) -> None:
        raise NotImplementedError("Flash attention SDPA does not require just-in-time compilation.")

    def load(self):
        from typing import Optional

        import torch

        def flash_attention(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            dropout_p: float = 0.0,
            scale: Optional[float] = None,
            attention_mask: Optional[torch.Tensor] = None,
            is_causal: bool = False,
            cu_seqlens_q: Optional[torch.Tensor] = None,
            cu_seqlens_kv: Optional[torch.Tensor] = None,
            max_seqlen_q: Optional[int] = None,
            max_seqlen_kv: Optional[int] = None,
            q_indices: Optional[torch.Tensor] = None,
            kv_indices: Optional[torch.Tensor] = None,
        ):
            return torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attention_mask,
                dropout_p=dropout_p,
                scale=scale,
            )

        return flash_attention
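A sketch of loading this extension directly and calling the returned wrapper. The import path follows the old test file's colossalai.kernel.extensions.flash_attention and is an assumption; it may differ depending on where the extensions package is installed. Note, as the code above shows, that the wrapper accepts the varlen arguments (cu_seqlens_*, *_indices) only for interface parity; just attention_mask, dropout_p and scale reach SDPA.

import torch
# Import path assumed from the old test above; adjust to where the extension package lives.
from colossalai.kernel.extensions.flash_attention import FlashAttentionSdpaCudaExtension

attn_fn = FlashAttentionSdpaCudaExtension().load()  # assumes CUDA is available

B, N, S, D = 2, 8, 16, 32
q = torch.randn(B, N, S, D, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = attn_fn(q, k, v, dropout_p=0.0, scale=None)  # plain SDPA call under the hood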
@@ -1,94 +0,0 @@
from ..base_extension import _Extension


class FlashAttentionXformersCudaExtension(_Extension):
    def __init__(self):
        super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False)

    def is_hardware_available(self) -> bool:
        # cuda extension can only be built if cuda is available
        try:
            import torch

            cuda_available = torch.cuda.is_available()
        except:
            cuda_available = False
        return cuda_available

    def assert_hardware_compatible(self) -> bool:
        pass

    def build_aot(self) -> None:
        raise NotImplementedError(
            "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
        )

    def build_jit(self) -> None:
        raise NotImplementedError(
            "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
        )

    def load(self):
        try:
            from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
            from xformers.ops.fmha.attn_bias import (
                BlockDiagonalCausalMask,
                BlockDiagonalMask,
                LowerTriangularMask,
                LowerTriangularMaskWithTensorBias,
            )
        except ImportError:
            raise ModuleNotFoundError(
                (
                    "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
                )
            )
        from typing import Optional

        import torch

        allow_alibi = True
        for op in MemoryEfficientAttentionCutlassOp:
            allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)

        def mem_eff_attention(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            seq_len_info_q: "SeqLenInfo",
            seq_len_info_kv: "SeqLenInfo",
            origin_attn_mask: Optional[torch.Tensor] = None,
            bias: Optional[torch.Tensor] = None,
            dropout_p: float = 0.0,
            scale: float = None,
            causal: bool = False,
            padded: bool = False,
        ):
            attn_bias = None
            if padded:  # bert style
                if not causal:
                    attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
                else:
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
            elif causal:  # gpt style
                attn_bias = LowerTriangularMask()

            if bias is not None:  # alibi / relative position embedding
                assert allow_alibi, "flash attention with bias is not supported in this system."
                assert causal, "attention with bias is only supported for causal attention so far."
                attn_bias = attn_bias.add_bias(bias)

            if padded:
                q = q.unsqueeze(0)
                k = k.unsqueeze(0)
                v = v.unsqueeze(0)

            out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)

            # shape: (b*s, n, d)
            if padded:
                out = out.squeeze(0)

            return out

        return mem_eff_attention
@@ -0,0 +1,147 @@
import math
from copy import copy

import torch
from torch.testing import assert_close

from colossalai.kernel.kernel_loader import (
    FlashAttentionLoader,
    FlashAttentionWithCustomMaskLoader,
    FlashAttentionWithPaddingMaskLoader,
)
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer.attn import invert_mask
from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.utils import get_current_device, set_seed

DTYPE = [torch.float16, torch.bfloat16]
B, N, S, D = 2, 8, 256, 32

TOL_MAP = {
    torch.float16: {"atol": 5e-4, "rtol": 2e-3},
    torch.bfloat16: {},
}


def attention_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask=None, dropout_p=0.0):
    head_dim = q.size(-1)
    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float).to(q.dtype)
    attn_weights = torch.dropout(attn_weights, p=dropout_p, train=True)
    attn_output = torch.matmul(attn_weights, v)
    return attn_output


def gen_padded_kwargs(dtype: torch.dtype):
    padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device())
    padding_mask[0, : S // 4] = 0
    return (
        ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask),
        padding_mask,
    )


def gen_padded_causal_kwargs(dtype: torch.dtype):
    padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device())
    padding_mask[0, S // 2 :] = 0
    return (
        ColoAttention.prepare_attn_kwargs(
            (B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask, is_causal=True
        ),
        padding_mask,
    )


def gen_causal_kwargs(dtype: torch.dtype):
    return ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, get_current_device(), is_causal=True), None


def gen_custom_kwargs(dtype: torch.dtype):
    attn_mask = torch.ones((B, S, S), dtype=dtype, device=get_current_device())
    attn_mask[0, : S // 2, S // 2 :] = 0
    attn_mask[0, S // 2 :, : S // 2] = 0
    attn_mask[1, :, S // 4 :] = 0
    attn_mask = invert_mask(attn_mask).unsqueeze(1)
    assert not torch.all(attn_mask != 0, dim=-1).any()
    return {"attention_mask": attn_mask}, None


def post_process_kwargs_for_raw_attn(attn_kwargs: dict):
    if "attention_mask_type" in attn_kwargs:
        attn_kwargs = copy(attn_kwargs)
        mask_type = attn_kwargs.pop("attention_mask_type")
        attn_kwargs["is_causal"] = mask_type in (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL)
    return attn_kwargs


def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_mask=None):
    tols = TOL_MAP[dtype]
    q = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True)
    k = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True)
    v = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True)
    q_flash = q.clone().detach().requires_grad_(True)
    k_flash = k.clone().detach().requires_grad_(True)
    v_flash = v.clone().detach().requires_grad_(True)
    attn_mask = attn_kwargs.get("attention_mask", None)
    ref_output = attention_ref(q, k, v, attn_mask)
    output = attn_func(q_flash, k_flash, v_flash, **attn_kwargs)
    if padding_mask is not None:
        # [B, Sq] -> [B, 1, Sq, 1]
        padding_mask = padding_mask[:, None, :, None].logical_not()
        ref_output = ref_output.masked_fill(padding_mask, 0)
        output = output.masked_fill(padding_mask, 0)
    assert_close(output, ref_output, **tols)
    output.mean().backward()
    ref_output.mean().backward()
    assert_close(q.grad, q_flash.grad, **tols)
    assert_close(k.grad, k_flash.grad, **tols)
    assert_close(v.grad, v_flash.grad, **tols)


@clear_cache_before_run()
@parameterize("dtype", DTYPE)
def test_flash_attn_func(dtype: torch.dtype):
    torch.backends.cudnn.deterministic = True
    set_seed(0)
    # (func, name, need_postprocess)
    avail_attn_funcs = [(ColoAttention.attention, "coloattn", False)]
    avail_custom_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)]
    avail_padding_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)]
    for ext_cls in FlashAttentionLoader.REGISTRY:
        ext = ext_cls()
        if ext.is_available():
            ext.assert_compatible()
            avail_attn_funcs.append((ext.load(), ext.name, True))
    for ext_cls in FlashAttentionWithCustomMaskLoader.REGISTRY:
        ext = ext_cls()
        if ext.is_available():
            ext.assert_compatible()
            avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True))
    for ext_cls in FlashAttentionWithPaddingMaskLoader.REGISTRY:
        ext = ext_cls()
        if ext.is_available():
            ext.assert_compatible()
            avail_padding_mask_attn_funcs.append((ext.load(), ext.name, True))

    test_sets = {
        "none": (lambda dtype: ({}, None), avail_attn_funcs),
        "padded": (gen_padded_kwargs, avail_padding_mask_attn_funcs),
        "padded_causal": (gen_padded_causal_kwargs, avail_padding_mask_attn_funcs),
        "causal": (gen_causal_kwargs, avail_attn_funcs),
        "custom": (gen_custom_kwargs, avail_custom_mask_attn_funcs),
    }

    for mask_type, (gen_kwargs_func, attn_funcs) in test_sets.items():
        attn_kwargs, padding_mask = gen_kwargs_func(dtype)
        for attn_func, name, need_postprocess in attn_funcs:
            print(f"{dtype}, {name}, {mask_type}")
            if need_postprocess:
                check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask)
            else:
                check_attn_func(dtype, attn_func, attn_kwargs, padding_mask)


if __name__ == "__main__":
    test_flash_attn_func()
@@ -1,167 +0,0 @@
import math

import pytest
import torch
from einops import rearrange

from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize

if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention

DTYPE = [torch.float16, torch.bfloat16, torch.float32]


def attention_ref(q, k, v, attn_mask=None, causal=False):
    """
    attention output of the control group
    """
    dtype_og = q.dtype
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    d = q.shape[-1]
    scale = 1.0 / math.sqrt(d)
    scores = torch.einsum("bthd,bshd->bhts", q * scale, k)

    if attn_mask is not None:
        scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
    if causal:
        causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
        scores.masked_fill_(causal_mask, float("-inf"))
    attention = torch.softmax(scores, dim=-1)

    output = torch.einsum("bhts,bshd->bthd", attention, v)
    output = rearrange(output, "b s h d -> b s (h d)")

    # Modify the data at the positions of the mask to 0
    if attn_mask is not None:
        output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1"), 0.0)

    return output.to(dtype=dtype_og)


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_attention_gpt(proj_shape, dtype, dropout):
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

    mask = [torch.ones(S - i, dtype=torch.bool, device="cuda") for i in range(B)]
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)

    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)

    assert list(y.shape) == [B, S, D]

    out_ref = attention_ref(q, k, v, mask, causal=True)

    # check gradients
    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_attention_bert(proj_shape, dtype, dropout):
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

    # attention mask of shape [B, S] with zero padding to max length S
    mask = torch.randint(0, 2, (B, S), dtype=torch.bool, device="cuda")

    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)

    assert list(y.shape) == [B, S, D]

    out_ref = attention_ref(q, k, v, mask, causal=False)

    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_attention_no_mask(proj_shape, dtype, dropout):
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v)

    assert list(y.shape) == [B, S, D]

    out_ref = attention_ref(q, k, v, None, causal=False)

    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 24, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_cross_attention(proj_shape, dtype, dropout):
    (B, S, T, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    q = torch.randn((B, T, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

    assert list(y.shape) == [B, T, D]

    out_ref = attention_ref(q, k, v, None, causal=True)

    dy = torch.rand_like(y)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"