2023-08-04 05:46:22 +00:00
|
|
|
from dataclasses import dataclass
|
|
|
|
from typing import Iterable, Tuple
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
from einops import rearrange
|
|
|
|
|
|
|
|
from colossalai.utils.cuda import get_current_device
|
|
|
|
|
|
|
|
|
|
|
|
class Unpad(torch.autograd.Function):
    """
    Gather selected rows from a ``[b, s, ...]`` tensor after merging its first two dims.

    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
        """Flatten ``tensor`` to ``[(b s), ...]`` and return the rows picked by ``indices``.

        Args:
            tensor: input laid out as ``[b, s, ...]``; must have at least 3 dims.
            indices: index into the flattened ``(b s)`` axis
                (presumably a 1-D integer tensor -- confirm against callers).

        Returns:
            Tensor of layout ``[ntokens, ...]``.
        """
        ctx.save_for_backward(indices)
        # Expected layout: [b, s, ...]
        assert tensor.ndim >= 3
        ctx.bsz = tensor.shape[0]
        flat = rearrange(tensor, "b s ... -> (b s) ...")
        # Remember the flattened shape so backward can rebuild a dense gradient.
        ctx.shape = flat.shape
        # Result layout: [ntokens, ...]
        return flat[indices]

    @staticmethod
    def backward(ctx, grad_output):
        """Scatter ``grad_output`` back into a zero tensor shaped like the forward input.

        Returns:
            ``(grad_tensor, None)`` -- ``indices`` receives no gradient.
        """
        (indices,) = ctx.saved_tensors
        # grad_output layout: [ntokens, ...]; unselected rows keep a zero gradient.
        dense_grad = grad_output.new_zeros(ctx.shape)
        dense_grad[indices] = grad_output
        # Restore the [b, s, ...] layout using the batch size saved in forward.
        return rearrange(dense_grad, "(b s) ... -> b s ...", b=ctx.bsz), None
|
|
|
|
|
|
|
|
|
|
|
|
class Repad(torch.autograd.Function):
    """
    Scatter ``[ntokens, ...]`` rows into a zero-filled ``[batch_size * seq_len, ...]`` tensor.

    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
        """Write the rows of ``tensor`` at positions ``indices`` of a zero buffer.

        Args:
            tensor: input laid out as ``[ntokens, ...]``.
            indices: destination rows along the first axis of the output.
            batch_size: batch dimension of the padded layout.
            seq_len: sequence dimension of the padded layout.

        Returns:
            Tensor of shape ``[batch_size * seq_len, ...]`` with the same dtype
            and device as ``tensor``; rows not addressed by ``indices`` are zero.
        """
        ctx.save_for_backward(indices)
        # Input layout: [ntokens, ...]
        trailing_dims = tensor.shape[1:]
        padded = tensor.new_zeros((batch_size * seq_len, *trailing_dims))
        # Output layout: [b*s, ...]
        padded[indices] = tensor
        return padded

    @staticmethod
    def backward(ctx, grad_output):
        """Pick the gradient rows that correspond to the scattered tokens.

        Returns:
            ``(grad_tensor, None, None, None)`` -- only ``tensor`` is differentiable.
        """
        (indices,) = ctx.saved_tensors
        # grad_output layout: [b*s, ...]  ->  token_grad layout: [ntokens, ...]
        token_grad = grad_output[indices]
        return token_grad, None, None, None
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class SeqLenInfo:
    """Per-batch sequence-length metadata built by :meth:`materialize`.

    Fields:
        seqlens: length of every sequence, as plain Python ints.
        indices: positions of the kept tokens in the flattened mask / batch.
        max_seqlen: largest entry of ``seqlens`` (0 for an empty batch).
        cu_seqlens: int32 cumulative sum of ``seqlens`` with a leading 0.
    """

    seqlens: Iterable[int] = None
    indices: torch.Tensor = None
    max_seqlen: int = None
    cu_seqlens: torch.Tensor = None

    @staticmethod
    def materialize(attn_mask: torch.Tensor = None, size: Tuple[int, ...] = None, device=None):
        """Build a ``SeqLenInfo`` from an attention mask or from a dense shape.

        Args:
            attn_mask: mask whose non-zero entries mark kept tokens; lengths are
                obtained by summing over its last dim. Takes precedence over ``size``.
            size: ``(batch_size, tgt_len, ...)``; used when ``attn_mask`` is ``None``
                to describe a batch in which every position is kept.
            device: target device for ``indices`` and ``cu_seqlens``. ``None``
                (the default) resolves to ``get_current_device()`` at call time.

        Returns:
            A populated ``SeqLenInfo``.

        Raises:
            ValueError: if neither ``attn_mask`` nor ``size`` is given.
        """
        # Resolve the default lazily: a default of `get_current_device()` in the
        # signature would be evaluated once, when the class body is executed.
        if device is None:
            device = get_current_device()

        if attn_mask is not None:
            indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
            seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
        elif size is not None:
            batch_size, tgt_len = size[0], size[1]
            indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
            # torch.full honours `device`; the legacy torch.LongTensor constructor
            # is CPU-typed and rejects non-CPU devices.
            seqlens = torch.full((batch_size,), tgt_len, dtype=torch.long, device=device)
        else:
            raise ValueError("either attn_mask or size must be provided")

        seqlen_list = seqlens.tolist()
        # Take the max over Python ints so max_seqlen is an int (as declared)
        # rather than a 0-dim tensor; an empty batch yields 0.
        max_seqlen = max(seqlen_list, default=0)
        cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
        return SeqLenInfo(seqlen_list, indices, max_seqlen, cu_seqlens)
|