# mirror of https://github.com/hpcaitech/ColossalAI
import enum
import math
import warnings
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange

from colossalai.accelerator import get_accelerator
from colossalai.kernel.kernel_loader import FlashAttentionLoader


@dataclass
class SeqLenInfo:
    """Variable-length sequence metadata in the layout expected by varlen attention kernels."""

    seqlens: Iterable[int] = None
    indices: torch.Tensor = None
    max_seqlen: int = None
    cu_seqlens: torch.Tensor = None

    @staticmethod
    def materialize(
        attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
    ):
        if attn_mask is not None:
            # indices of the non-padded tokens in the flattened (batch * seq_len) layout
            indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
            seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
        else:
            # no mask given: every sequence has the full target length
            batch_size, tgt_len = size[0], size[1]
            indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
            seqlens = torch.tensor([tgt_len] * batch_size, dtype=torch.long, device=device)
        max_seqlen = max(seqlens)
        # cumulative sequence lengths with a leading zero: [0, len_0, len_0 + len_1, ...]
        cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
        return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
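

# Illustrative sketch (editor-added, not part of the upstream module): shows the
# metadata that SeqLenInfo.materialize produces for a small padding mask. The
# function below is only an example and is never called by this file.
def _example_seq_len_info():
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    info = SeqLenInfo.materialize(attn_mask=mask, device=torch.device("cpu"))
    # info.seqlens    -> [3, 2]
    # info.indices    -> tensor([0, 1, 2, 4, 5])  (positions of the real tokens)
    # info.max_seqlen -> 3 (0-dim tensor)
    # info.cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32)
    return info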


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
    paddedcausal = 3
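
# Note (editor-added): the enum values are chosen so that ColoAttention.forward can
# decode them with simple arithmetic: `value % 2 == 1` means the batch contains
# padding (padding, paddedcausal) and `value > 1` means a causal mask is applied
# (causal, paddedcausal).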


class Unpad(torch.autograd.Function):
    """
    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
        ctx.save_for_backward(indices)
        # [b, s, ...]
        assert tensor.ndim >= 3
        ctx.bsz = tensor.shape[0]
        out = rearrange(tensor, "b s ... -> (b s) ...")
        ctx.shape = out.shape
        # [ntokens, ...]
        return out[indices]

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # [ntokens, ...]
        grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
        grad[indices] = grad_output
        grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
        # [b, s, ...]
        return grad, None
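

# Illustrative sketch (editor-added, not part of the upstream module): Unpad gathers
# only the non-padded tokens of a [batch, seq, ...] tensor into a packed
# [ntokens, ...] tensor, using the indices computed by SeqLenInfo.materialize.
# The function below is only an example and is never called by this file.
def _example_unpad():
    hidden = torch.randn(2, 4, 8, 16)  # [batch, seq, heads, head_dim]
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
    packed = Unpad.apply(hidden, indices)
    # packed.shape -> torch.Size([5, 8, 16]): the 5 real tokens, padding dropped
    return packed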


class Repad(torch.autograd.Function):
    """
    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
        ctx.save_for_backward(indices)
        # [ntokens, ...]
        out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
        # [b*s, ...]
        out[indices] = tensor
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # [b*s, ...]
        grad = grad_output[indices]
        # [ntokens, ...]
        return grad, None, None, None
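

# Illustrative sketch (editor-added, not part of the upstream module): Repad is the
# inverse of Unpad; it scatters the packed [ntokens, ...] tensor back into a
# zero-initialised [batch * seq, ...] buffer. The function below is only an example
# and is never called by this file.
def _example_unpad_repad_roundtrip():
    hidden = torch.randn(2, 4, 8, 16)  # [batch, seq, heads, head_dim]
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
    packed = Unpad.apply(hidden, indices)          # [5, 8, 16]
    restored = Repad.apply(packed, indices, 2, 4)  # [8, 8, 16], zeros at padded slots
    restored = restored.view(2, 4, 8, 16)
    # restored matches `hidden` at real-token positions and is zero elsewhere
    return restored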


class ColoAttention(torch.nn.Module):
    """Attention module that runs a kernel loaded via FlashAttentionLoader and handles padded (variable-length) batches."""

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
        super().__init__()
        assert (
            embed_dim % num_heads == 0
        ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
        if scale is not None:
            self.scale = scale
        else:
            # default to the standard 1 / sqrt(head_dim) scaling
            self.scale = 1 / math.sqrt(embed_dim // num_heads)
        self.dropout = dropout

        self.attn = FlashAttentionLoader().load()

    @staticmethod
    def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        return Unpad.apply(tensor, indices)

    @staticmethod
    def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
        return Repad.apply(tensor, indices, batch_size, seq_len)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        origin_attn_mask: Optional[torch.Tensor] = None,
        attn_mask_type: Optional[AttnMaskType] = None,
        bias: Optional[torch.Tensor] = None,
    ):
        """
        ColoAttention

        Args:
            query: (batch, q_seqlen, nheads, headdim)
            key: (batch, kv_seqlen, nheads, headdim)
            value: (batch, kv_seqlen, nheads, headdim)
            attn_mask: (batch, seqlen) padding mask, 1 for real tokens and 0 for padding
            origin_attn_mask: (nheads, q_seqlen, kv_seqlen)
            attn_mask_type: padding, causal, or paddedcausal
            bias: will not be used; a non-None bias only triggers the fallback to the xformers kernel
        Return:
            attn_out: (batch, q_seqlen, nheads * headdim).
        """
        # if flash attention is not applicable, switch to memory-efficient attention
        if self.attn.__name__ == "flash_attention" and (
            query.dtype not in [torch.float16, torch.bfloat16] or bias is not None
        ):
            warnings.warn(
                f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation."
            )
            self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda")

        padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
        causal = attn_mask_type is not None and attn_mask_type.value > 1

        batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
        # unpad
        seq_len_info_q = None
        seq_len_info_kv = None
        if padded:
            # bert style, unpad process
            assert (
                attn_mask is not None
            ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
            assert attn_mask.dim() == 2, (
                "attention mask is supposed to have shape (batch_size, seq_len), "
                + f"but got {attn_mask.dim()} dimensions."
            )

            # bert style
            if tgt_len == src_len:
                seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
                if batch_size > 1:
                    # pack query, key and value so they can be unpadded with one gather
                    query, key, value = self.unpad(
                        torch.stack([query, key, value], dim=2), seq_len_info_q.indices
                    ).unbind(dim=1)
                else:
                    query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
                seq_len_info_kv = seq_len_info_q
            else:
                seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
                seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
                if batch_size > 1:
                    query = rearrange(query, "b s ... -> c (b s) ...", c=1)
                    # only key and value carry padding here; unpad them with the kv indices
                    key, value = self.unpad(torch.stack([key, value], dim=2), seq_len_info_kv.indices).unbind(
                        dim=1
                    )
                else:
                    query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)

        out = self.attn(
            query,
            key,
            value,
            seq_len_info_q=seq_len_info_q,
            seq_len_info_kv=seq_len_info_kv,
            origin_attn_mask=origin_attn_mask,
            dropout_p=self.dropout,
            scale=self.scale,
            causal=causal,
            padded=padded,
        )

        # repad
        if padded:
            if batch_size > 1:
                out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
            out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)

        if len(out.shape) == 4:
            out = rearrange(out, "b s h d -> b s (h d)")
        return out
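

# Illustrative usage sketch (editor-added, not part of the upstream module): runs
# ColoAttention on a padded batch. It assumes a CUDA device with a flash-attention
# extension installed, since FlashAttentionLoader picks a kernel from the available
# extensions. The function below is only an example and is never called by this file.
def _example_colo_attention():
    embed_dim, num_heads, head_dim = 64, 8, 8
    attn = ColoAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=0.0)
    batch_size, seq_len = 2, 16
    q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # 1 marks a real token, 0 marks padding; the second sample has 4 padded positions
    attn_mask = torch.ones(batch_size, seq_len, dtype=torch.int32, device="cuda")
    attn_mask[1, 12:] = 0
    out = attn(q, k, v, attn_mask=attn_mask, attn_mask_type=AttnMaskType.paddedcausal)
    # out.shape -> torch.Size([2, 16, 64])  (heads merged back into embed_dim)
    return out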