diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index b01d15490..5bdadca78 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -117,7 +117,7 @@ jobs: cd TensorNVMe conda install cmake pip install -r requirements.txt - pip install -v . + DISABLE_URING=1 pip install -v . - name: Store TensorNVMe Cache run: | @@ -201,4 +201,4 @@ jobs: uses: actions/upload-artifact@v3 with: name: report - path: report/ \ No newline at end of file + path: report/ diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index 3ff19b37b..e560d0c00 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -44,7 +44,7 @@ jobs: cd TensorNVMe conda install cmake pip install -r requirements.txt - pip install -v . + DISABLE_URING=1 pip install -v . - uses: actions/checkout@v2 if: steps.check-avai.outputs.avai == 'true' diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml index 764938806..95a94c27b 100644 --- a/.github/workflows/compatiblity_test_on_dispatch.yml +++ b/.github/workflows/compatiblity_test_on_dispatch.yml @@ -66,7 +66,7 @@ jobs: cd TensorNVMe apt update && apt install -y cmake pip install -r requirements.txt - pip install -v . + DISABLE_URING=1 pip install -v . - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index f582b3090..aef4816ef 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -60,7 +60,7 @@ jobs: cd TensorNVMe apt update && apt install -y cmake pip install -r requirements.txt - pip install -v . + DISABLE_URING=1 pip install -v . - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml index 3348b51ec..3dc8a5a32 100644 --- a/.github/workflows/compatiblity_test_on_schedule.yml +++ b/.github/workflows/compatiblity_test_on_schedule.yml @@ -56,7 +56,7 @@ jobs: cd TensorNVMe apt update && apt install -y cmake pip install -r requirements.txt - pip install -v . + DISABLE_URING=1 pip install -v . - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py index 148c3e3fc..353e29b3d 100644 --- a/colossalai/kernel/kernel_loader.py +++ b/colossalai/kernel/kernel_loader.py @@ -6,7 +6,7 @@ from .extensions import ( CpuAdamX86Extension, FlashAttentionDaoCudaExtension, FlashAttentionNpuExtension, - FlashAttentionXformersCudaExtension, + FlashAttentionSdpaCudaExtension, FusedOptimizerCudaExtension, LayerNormCudaExtension, MoeCudaExtension, @@ -65,9 +65,9 @@ class KernelLoader: else: usable_exts = [] for ext in exts: - if ext.is_hardware_available(): + if ext.is_available(): # make sure the machine is compatible during kernel loading - ext.assert_hardware_compatible() + ext.assert_compatible() usable_exts.append(ext) assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine." @@ -106,4 +106,20 @@ class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader): class FlashAttentionLoader(KernelLoader): - REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension] + REGISTRY = [ + FlashAttentionNpuExtension, + FlashAttentionDaoCudaExtension, + FlashAttentionSdpaCudaExtension, + ] + + +class FlashAttentionWithPaddingMaskLoader(KernelLoader): + REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension] + + +class FlashAttentionWithCustomMaskLoader(KernelLoader): + REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension] + + +class FlashAttentionForFloatAndCustomMaskLoader(KernelLoader): + REGISTRY = [FlashAttentionSdpaCudaExtension] diff --git a/colossalai/nn/layer/colo_attention.py b/colossalai/nn/layer/colo_attention.py deleted file mode 100644 index 0b7011e8e..000000000 --- a/colossalai/nn/layer/colo_attention.py +++ /dev/null @@ -1,209 +0,0 @@ -import enum -import math -import warnings -from dataclasses import dataclass -from typing import Iterable, Optional, Tuple - -import torch -import torch.nn.functional as F -from einops import rearrange - -from colossalai.accelerator import get_accelerator -from colossalai.kernel.kernel_loader import FlashAttentionLoader - - -@dataclass -class SeqLenInfo: - seqlens: Iterable[int] = None - indices: torch.Tensor = None - max_seqlen: int = None - cu_seqlens: torch.Tensor = None - - @staticmethod - def materialize( - attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device() - ): - if attn_mask is not None: - indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) - seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() - else: - batch_size, tgt_len = size[0], size[1] - indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) - seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) - max_seqlen = max(seqlens) - cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) - return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) - - -class AttnMaskType(enum.Enum): - padding = 1 - causal = 2 - paddedcausal = 3 - - -class Unpad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): - ctx.save_for_backward(indices) - # [b, s, ...] - assert tensor.ndim >= 3 - ctx.bsz = tensor.shape[0] - out = rearrange(tensor, "b s ... -> (b s) ...") - ctx.shape = out.shape - # [ntokens, ...] - return out[indices] - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [ntokens, ...] - grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) - grad[indices] = grad_output - grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) - # [b, s, ...] - return grad, None - - -class Repad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): - ctx.save_for_backward(indices) - # [ntokens, ...] - tensor = tensor - out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) - # [b*s, ...] - out[indices] = tensor - return out - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [b*s, ...] - grad = grad_output[indices] - # [ntokens, ...] - return grad, None, None, None - - -class ColoAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): - super().__init__() - assert ( - embed_dim % num_heads == 0 - ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." - if scale is not None: - self.scale = scale - else: - self.scale = 1 / math.sqrt(embed_dim // num_heads) - self.dropout = dropout - - self.attn = FlashAttentionLoader().load() - - @staticmethod - def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - return Unpad.apply(tensor, indices) - - @staticmethod - def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: - return Repad.apply(tensor, indices, batch_size, seq_len) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - origin_attn_mask: Optional[torch.Tensor] = None, - attn_mask_type: Optional[AttnMaskType] = None, - bias: Optional[torch.Tensor] = None, - ): - """ - ColoAttention - - Args: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - origin_attn_mask: (nheads, q_seqlen, kv_seqlen) - bias: will not be used - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - # if flash attention is not applicable, switch to memory effcient attention - if self.attn.__name__ == "flash_attention" and ( - query.dtype not in [torch.float16, torch.bfloat16] or bias != None - ): - warnings.warn( - f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation." - ) - self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda") - - padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 - causal = attn_mask_type is not None and attn_mask_type.value > 1 - - batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] - # unpad - seq_len_info_q = None - seq_len_info_kv = None - if padded: - # bert style, unpad process - assert ( - attn_mask is not None - ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." - assert attn_mask.dim() == 2, ( - "attention mask is supposed to have shape (batch_size, seq_len), " - + f"but got {attn_mask.dim()} dimensions." - ) - - # bert style - if tgt_len == src_len: - seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) - if batch_size > 1: - query, key, value = self.unpad( - torch.stack([query, key, value], dim=2), seq_len_info_q.indices - ).unbind(dim=1) - else: - query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - seq_len_info_kv = seq_len_info_q - else: - seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device) - seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) - if batch_size > 1: - query = rearrange(query, "b s ... -> c (b s) ...", c=1) - key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind( - dim=1 - ) - else: - query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - - out = self.attn( - query, - key, - value, - seq_len_info_q=seq_len_info_q, - seq_len_info_kv=seq_len_info_kv, - origin_attn_mask=origin_attn_mask, - dropout_p=self.dropout, - scale=self.scale, - causal=causal, - padded=padded, - ) - - # repad - if padded: - if batch_size > 1: - out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) - out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) - - if len(out.shape) == 4: - out = rearrange(out, "b s h d -> b s (h d)") - return out diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 56e8b08c4..c9b4317a6 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,3 +1,4 @@ +from .attn import AttnMaskType, ColoAttention from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row @@ -23,4 +24,6 @@ __all__ = [ "FusedRMSNorm", "FusedLinear1D_Col", "ParallelModule", + "AttnMaskType", + "ColoAttention", ] diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py new file mode 100644 index 000000000..f3f6e59d3 --- /dev/null +++ b/colossalai/shardformer/layer/attn.py @@ -0,0 +1,269 @@ +from enum import Enum +from typing import Callable, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F + +from colossalai.kernel.kernel_loader import ( + FlashAttentionForFloatAndCustomMaskLoader, + FlashAttentionLoader, + FlashAttentionWithCustomMaskLoader, + FlashAttentionWithPaddingMaskLoader, + KernelLoader, +) + +__all__ = [ + "AttnMaskType", + "ColoAttention", +] + + +class AttnMaskType(Enum): + CUSTOM = 0 + PADDED = 1 + CAUSAL = 2 + PADDED_CAUSAL = 3 + + +def invert_mask(mask: torch.Tensor) -> torch.Tensor: + """Invert the mask tensor. + + Args: + mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Skv] + + Returns: + torch.Tensor: Inverted mask tensor. + """ + inverted_mask = 1.0 - mask + return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(mask.dtype).min) + + +# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py +def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]: + """Get padding information from padding mask. + + Args: + padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S] + + Returns: + Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices) + """ + seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return max_seqlen_in_batch, cu_seqlens, indices + + +class ColoAttention: + _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None + + @staticmethod + def _init_kernels_dispatch(): + if ColoAttention._kernel_dispatch_map is None: + # fp16/bf16 + half_dispatch_map = { + None: FlashAttentionLoader(), + AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(), + AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(), + AttnMaskType.CAUSAL: FlashAttentionLoader(), + AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(), + } + # fp32 + float_dispatch_map = { + None: FlashAttentionForFloatAndCustomMaskLoader(), + AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(), + AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(), + } + ColoAttention._kernel_dispatch_map = { + torch.float16: half_dispatch_map, + torch.bfloat16: half_dispatch_map, + torch.float32: float_dispatch_map, + } + + @staticmethod + def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable: + ColoAttention._init_kernels_dispatch() + if ( + dtype not in ColoAttention._kernel_dispatch_map + or mask_type not in ColoAttention._kernel_dispatch_map[dtype] + ): + raise ValueError( + "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type) + ) + # lazy load + if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader): + ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][ + mask_type + ].load() + return ColoAttention._kernel_dispatch_map[dtype][mask_type] + + @staticmethod + def prepare_attn_kwargs( + shape_4d: Tuple[int], + dtype: torch.dtype, + device: torch.device, + q_padding_mask: Optional[torch.Tensor] = None, + kv_padding_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + ) -> Dict[str, torch.Tensor]: + """Return a dictionary of keyword arguments for attention function. It supports 4 mask type. + 1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves. + 2. padded mask: recv padding mask and is_causal=False, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}. + 3. causal mask: no padding mask and is_causal=True, return {attention_mask, attention_mask_type}. + 4. padded causal mask: recv padding mask and is_causal=True, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}. + + Args: + shape_4d (Tuple[int]): Should be (B, 1, Sq, Skv) + dtype (torch.dtype): Dtype of attention mask, generally should be ``hidden_states.dtype`` + device (torch.device): Device of attention mask, generally should be ``hidden_states.device`` + q_padding_mask (Optional[torch.Tensor], optional): Padding mask of query. It should be a long tensor or int tensor. + The shape should be [B, Sq]. ``1`` means valid token, and ``0`` means padding token. Defaults to None. + kv_padding_mask (Optional[torch.Tensor], optional): Padding mask of key and value. It should be a long tensor or int tensor. + The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token. + If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None. + is_causal (bool, optional): Whether to use causal attention mask. Defaults to False. + + Returns: + Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function. + """ + if q_padding_mask is None and not is_causal: + return {} + assert len(shape_4d) == 4 and shape_4d[1] == 1 + b, _, s_q, s_kv = shape_4d + outputs = {} + if (q_padding_mask is None or q_padding_mask.bool().all()) and ( + kv_padding_mask is None or kv_padding_mask.bool().all() + ): + # no padding + assert is_causal + outputs["attention_mask_type"] = AttnMaskType.CAUSAL + attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv) + else: + if kv_padding_mask is None: + # self attention + kv_padding_mask = q_padding_mask + assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == ( + b, + s_kv, + ), f"q_padding_mask shape {q_padding_mask.shape} and kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})" + attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device) + max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask) + max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask) + outputs.update( + { + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_kv": cu_seqlens_kv, + "max_seqlen_q": max_seqlen_q, + "max_seqlen_kv": max_seqlen_kv, + "q_indices": q_indices, + "kv_indices": kv_indices, + } + ) + if is_causal: + outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL + attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0) + else: + outputs["attention_mask_type"] = AttnMaskType.PADDED + attention_mask = invert_mask(attention_mask).unsqueeze(1) + outputs["attention_mask"] = attention_mask + return outputs + + @staticmethod + def attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_type: AttnMaskType = AttnMaskType.CUSTOM, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + q_indices: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: Optional[float] = None, + ) -> torch.Tensor: + """Flash Attention function. It supports 4 mask type. + 1. custom mask: recv attention_mask + 2. padded mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices + 3. causal mask: recv attention_mask, attention_mask_type + 4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices + + Args: + q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D] + k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D] + v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D] + attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None. + attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM. + cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths + of the sequences in the batch, used to index into q. + Shape should be [B+1]. Defaults to None. + cu_seqlens_kv (Optional[torch.Tensor], optional): The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + Shape should be [B+1]. Defaults to None. + max_seqlen_q (Optional[int], optional): Maximum query sequence length in the batch. Defaults to None. + max_seqlen_kv (Optional[int], optional): Maximum key/value sequence length in the batch. Defaults to None. + indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from the flattened input sequence. + Shape should be [NUM_TOKENS]. Defaults to None. + dropout_p (float, optional): Dropout probability. Defaults to 0.0. + scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None. + + Returns: + torch.Tensor: Output tensor. Shape should be [B, N, Sq, D] + """ + # known issue: sdpa does not support attention mask which contains whole row of masked tokens, which leads to nan + # this case is usaul when padding mask is used and self attention is performed + # thus, we don't use sdpa when padding mask is used + # sanity check + if attention_mask is not None: + assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor." + if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL): + assert ( + cu_seqlens_q is None + and cu_seqlens_kv is None + and max_seqlen_q is None + and max_seqlen_kv is None + and q_indices is None + and kv_indices is None + ) + if attention_mask_type == AttnMaskType.CUSTOM: + assert not torch.all(attention_mask != 0, dim=-1).any() + elif attention_mask_type in ( + AttnMaskType.PADDED, + AttnMaskType.PADDED_CAUSAL, + ): + assert ( + cu_seqlens_q is not None + and cu_seqlens_kv is not None + and max_seqlen_q is not None + and max_seqlen_kv is not None + and q_indices is not None + and kv_indices is not None + ) + else: + # if attention_mask is None, attention_mask_type should be the default value + assert attention_mask_type == AttnMaskType.CUSTOM + # kernel dispatch + mask_type = attention_mask_type if attention_mask is not None else None + attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type) + is_causal = attention_mask is not None and attention_mask_type in ( + AttnMaskType.CAUSAL, + AttnMaskType.PADDED_CAUSAL, + ) + return attn_func( + q, + k, + v, + dropout_p=dropout_p, + scale=scale, + attention_mask=attention_mask, + is_causal=is_causal, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + q_indices=q_indices, + kv_indices=kv_indices, + ) diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py index d5c10541a..bd84c87c6 100644 --- a/colossalai/shardformer/modeling/blip2.py +++ b/colossalai/shardformer/modeling/blip2.py @@ -3,6 +3,8 @@ from typing import Optional, Tuple import torch import torch.nn as nn +from colossalai.shardformer.layer import ColoAttention + def forward_fn(): def forward( @@ -62,8 +64,6 @@ def forward_fn(): def get_blip2_flash_attention_forward(): from transformers.models.blip_2.modeling_blip_2 import Blip2Attention - from colossalai.nn.layer.colo_attention import ColoAttention - def forward( self: Blip2Attention, hidden_states: torch.Tensor, @@ -71,16 +71,25 @@ def get_blip2_flash_attention_forward(): output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - + assert head_mask is None, "head_mask is not supported in FlashAttention" bsz, tgt_len, embed_dim = hidden_states.size() mixed_qkv = self.qkv(hidden_states) - mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) - query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] + mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + query_states, key_states, value_states = ( + mixed_qkv[0], + mixed_qkv[1], + mixed_qkv[2], + ) - attention = ColoAttention( - embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout.p, scale=self.scale + dropout_p = self.dropout.p if self.training else 0.0 + context_layer = ColoAttention.attention( + query_states, + key_states, + value_states, + dropout_p=dropout_p, + scale=self.scale, ) - context_layer = attention(query_states, key_states, value_states) + context_layer = context_layer.permute(0, 2, 1, 3).reshape(bsz, tgt_len, self.embed_dim) output = self.projection(context_layer) outputs = (output, None) @@ -93,7 +102,11 @@ def get_blip2_flash_attention_forward(): def get_jit_fused_blip2_QFormer_self_output_forward(): from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput - def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + def forward( + self: Blip2QFormerSelfOutput, + hidden_states: torch.Tensor, + input_tensor: torch.Tensor, + ) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) hidden_states = self.LayerNorm(hidden_states) @@ -105,7 +118,11 @@ def get_jit_fused_blip2_QFormer_self_output_forward(): def get_jit_fused_blip2_QFormer_output_forward(): from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput - def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + def forward( + self: Blip2QFormerOutput, + hidden_states: torch.Tensor, + input_tensor: torch.Tensor, + ) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) hidden_states = self.LayerNorm(hidden_states) diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index d13bd3492..a3e000e6e 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -1,4 +1,5 @@ """ PyTorch ChatGLM model. """ + from typing import List, Optional, Tuple import torch @@ -9,63 +10,49 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig +from colossalai.shardformer.layer import AttnMaskType, ColoAttention from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel def get_flash_core_attention_forward(): - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - from .chatglm2_6b.modeling_chatglm import CoreAttention def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask): - pytorch_major_version = int(torch.__version__.split(".")[0]) - if pytorch_major_version >= 2: - query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention( - query_layer, key_layer, value_layer, is_causal=True - ) - else: - if attention_mask is not None: - attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention( - query_layer, key_layer, value_layer, attention_mask - ) - context_layer = context_layer.permute(2, 0, 1, 3) - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) - else: - # Raw attention scores - query_layer = query_layer.permute(1, 0, 2, 3).contiguous() - key_layer = key_layer.permute(1, 0, 2, 3).contiguous() - value_layer = value_layer.permute(1, 0, 2, 3).contiguous() - - scale = 1.0 / self.norm_factor - if self.coeff is not None: - scale = scale * self.coeff - - flash_attention_mask = None - attn_mask_type = None - if attention_mask is None: - attn_mask_type = AttnMaskType.causal - else: - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - if not torch.all(flash_attention_mask): - attn_mask_type = AttnMaskType.paddedcausal - - attention = ColoAttention( - embed_dim=self.hidden_size_per_partition, - num_heads=self.num_attention_heads_per_partition, - dropout=self.attention_dropout.p, - scale=scale, + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + attention_mask_type = AttnMaskType.CAUSAL + attn_bias = torch.zeros( + query_layer.shape[0], + 1, + query_layer.shape[2], + key_layer.shape[2], + dtype=query_layer.dtype, + device=query_layer.device, ) - context_layer = attention( - query_layer, key_layer, value_layer, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type + temp_mask = ( + torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device) + .tril(diagonal=0) + .expand(query_layer.shape[0], 1, -1, -1) ) - - context_layer = context_layer.permute(1, 0, -1).contiguous() - + attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(query_layer.dtype).min) + else: + attention_mask_type = AttnMaskType.CUSTOM + if attention_mask is not None: + attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype) + attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min) + dropout_p = self.attention_dropout.p if self.training else 0.0 + context_layer = ColoAttention.attention( + query_layer, + key_layer, + value_layer, + attention_mask=attn_bias, + attention_mask_type=attention_mask_type, + dropout_p=dropout_p, + ) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) return context_layer return forward @@ -169,11 +156,17 @@ class ChatGLMPipelineForwards: if self.pre_seq_len is not None: if past_key_values is None: past_key_values = self.get_prompt( - batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, ) if attention_mask is not None: attention_mask = torch.cat( - [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1 + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, ) if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): @@ -200,7 +193,9 @@ class ChatGLMPipelineForwards: if shard_config.enable_sequence_parallelism: hidden_states = split_forward_gather_backward( - hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group, ) for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) @@ -208,7 +203,12 @@ class ChatGLMPipelineForwards: all_hidden_states = all_hidden_states + (hidden_states,) if self.encoder.gradient_checkpointing and self.encoder.training: layer_ret = torch.utils.checkpoint.checkpoint( - layer, hidden_states, attention_mask, rotary_pos_emb, past_key_values[idx], use_cache + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + past_key_values[idx], + use_cache, ) else: layer_ret = layer( @@ -224,7 +224,9 @@ class ChatGLMPipelineForwards: if shard_config.enable_sequence_parallelism: hidden_states = gather_forward_split_backward( - hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group, ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -234,7 +236,14 @@ class ChatGLMPipelineForwards: hidden_states = self.encoder.final_layernorm(hidden_states) if not return_dict: return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -368,7 +377,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): # Run encoder. # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size] inputs_embeds = split_forward_gather_backward( - inputs_embeds, dim=0, process_group=shard_config.tensor_parallel_process_group + inputs_embeds, + dim=0, + process_group=shard_config.tensor_parallel_process_group, ) hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( inputs_embeds, @@ -380,7 +391,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): ) hidden_states = gather_forward_split_backward( - hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group, ) if not return_dict: diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 407338b16..72f923bf0 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -21,12 +21,82 @@ from transformers.models.gpt2.modeling_gpt2 import ( from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig from ..layer import cross_entropy_1d from ..layer._operation import gather_forward_split_backward +logger = logging.get_logger(__name__) + + +def _get_attention_mask( + self: GPT2Model, + shard_config: ShardConfig, + hidden_states: torch.Tensor, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]], + attention_mask: Optional[torch.FloatTensor], + encoder_hidden_states: Optional[torch.Tensor], + encoder_attention_mask: Optional[torch.FloatTensor], +) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]: + batch_size, seq_len = hidden_states.shape[:2] + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + if shard_config.enable_flash_attention: + encoder_attention_mask = ColoAttention.prepare_attn_kwargs( + (encoder_batch_size, 1, seq_len, encoder_sequence_length), + dtype=hidden_states.dtype, + dtype2=encoder_hidden_states.dtype, + q_padding_mask=attention_mask, + kv_padding_mask=encoder_attention_mask, + ) + else: + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + if shard_config.enable_flash_attention: + encoder_attention_mask = {"attention_mask": None} + else: + encoder_attention_mask = None + # GPT2Attention mask. + past_key_values_length = 0 + if past_key_values is not None and past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] + if shard_config.enable_flash_attention: + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = ColoAttention.prepare_attn_kwargs( + (batch_size, 1, seq_len, seq_len + past_key_values_length), + hidden_states.dtype, + hidden_states.device, + attention_mask, + is_causal=True, + ) + elif attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + return attention_mask, encoder_attention_mask + class GPT2PipelineForwards: """ @@ -83,10 +153,10 @@ class GPT2PipelineForwards: elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] + input_ids.shape[0] elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] + inputs_embeds.shape[0] else: raise ValueError("You have to specify either input_ids or inputs_embeds") @@ -99,38 +169,7 @@ class GPT2PipelineForwards: input_shape = hidden_states.size()[:-1] device = hidden_states.device hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:]) - batch_size = hidden_states.shape[0] - - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_attention_mask = None + hidden_states.shape[0] # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -156,6 +195,16 @@ class GPT2PipelineForwards: output_shape = input_shape + (hidden_states.size(-1),) + attention_mask, encoder_attention_mask = _get_attention_mask( + self, + shard_config, + hidden_states, + past_key_values, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( @@ -171,7 +220,9 @@ class GPT2PipelineForwards: # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config.enable_sequence_parallelism: hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) # Going through held blocks. @@ -180,7 +231,7 @@ class GPT2PipelineForwards: block = self.h[i] torch.cuda.set_device(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: + if torch.is_tensor(attention_mask): attention_mask = attention_mask.to(hidden_states.device) if isinstance(head_mask, torch.Tensor): head_mask = head_mask.to(hidden_states.device) @@ -229,7 +280,9 @@ class GPT2PipelineForwards: # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config.enable_sequence_parallelism: hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) if stage_manager.is_last_stage(): @@ -245,7 +298,13 @@ class GPT2PipelineForwards: if not return_dict: return tuple( v - for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None ) @@ -333,7 +392,9 @@ class GPT2PipelineForwards: shift_labels = shift_labels.view(-1) if shard_config.enable_tensor_parallelism and shard_config.parallel_output: loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, ) else: loss = loss_fct(shift_logits, shift_labels) @@ -733,27 +794,18 @@ class GPT2PipelineForwards: def get_gpt2_flash_attention_forward(): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - - def split_heads(tensor, num_heads, attn_head_size): - """ - Splits hidden_size dim into attn_head_size and num_heads - """ - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(new_shape) - return tensor - def forward( self: GPT2Attention, hidden_states: Optional[Tuple[torch.FloatTensor]], layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[dict] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[dict] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + assert head_mask is None, "FlashAttention does not support head_mask" if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): raise ValueError( @@ -766,10 +818,9 @@ def get_gpt2_flash_attention_forward(): attention_mask = encoder_attention_mask else: query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - - query = split_heads(query, self.num_heads, self.head_dim) - key = split_heads(key, self.num_heads, self.head_dim) - value = split_heads(value, self.num_heads, self.head_dim) + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) if layer_past is not None: past_key, past_value = layer_past @@ -781,29 +832,14 @@ def get_gpt2_flash_attention_forward(): else: present = None - if not self.is_cross_attention: - attn_mask_type = AttnMaskType.causal - flash_attention_mask = None - if attention_mask != None: - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - if not torch.all(flash_attention_mask): - if attn_mask_type == AttnMaskType.causal: - attn_mask_type == AttnMaskType.paddedcausal - else: - attn_mask_type = AttnMaskType.padding - - scale = value.size(-1) ** -0.5 + scale = 1.0 + if self.scale_attn_weights: + scale /= value.size(-1) ** 0.5 if self.scale_attn_by_inverse_layer_idx: - scale = scale * (1 / float(self.layer_idx + 1)) - - # use coloattention - if not hasattr(self, "attention"): - self.attention = ColoAttention( - embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale - ) - - attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) - + scale /= float(self.layer_idx + 1) + dropout_p = self.attn_dropout.p if self.training else 0.0 + attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present, None) @@ -813,9 +849,9 @@ def get_gpt2_flash_attention_forward(): return forward -def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): +def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig): def forward( - self, + self: GPT2Model, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, attention_mask: Optional[torch.FloatTensor] = None, @@ -840,12 +876,13 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] + input_ids.shape[0] elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] + inputs_embeds.shape[0] else: raise ValueError("You have to specify either input_ids or inputs_embeds") @@ -862,39 +899,201 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): else: past_length = past_key_values[0][0].size(-2) if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + attention_mask, encoder_attention_mask = _get_attention_mask( + self, + shard_config, + hidden_states, + past_key_values, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if torch.is_tensor(attention_mask): + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + return forward + + +def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + inputs_embeds.shape[0] else: - encoder_attention_mask = None + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -914,6 +1113,15 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) + attention_mask, encoder_attention_mask = _get_attention_mask( + self, + shard_config, + hidden_states, + past_key_values, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -931,7 +1139,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -942,7 +1152,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): if layer_past is not None: layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: + if torch.is_tensor(attention_mask): attention_mask = attention_mask.to(hidden_states.device) if isinstance(head_mask, torch.Tensor): head_mask = head_mask.to(hidden_states.device) @@ -996,7 +1206,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): # When sequence parallelism done, gather the output tensor in forward and split it in backward hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) hidden_states = self.ln_f(hidden_states) @@ -1008,7 +1220,13 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): if not return_dict: return tuple( v - for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None ) diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index 1990d7df3..5c254d1e7 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -19,9 +19,54 @@ from transformers.models.gptj.modeling_gptj import ( from transformers.utils import is_torch_fx_proxy, logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig +logger = logging.get_logger(__name__) + + +def _get_attention_mask( + self: GPTJModel, + shard_config: ShardConfig, + hidden_states: torch.Tensor, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]], + attention_mask: Optional[torch.FloatTensor], +) -> Optional[Union[torch.Tensor, dict]]: + batch_size, seq_len = hidden_states.shape[:2] + past_key_values_length = 0 + if past_key_values is not None and past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] + if shard_config.enable_flash_attention: + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = ColoAttention.prepare_attn_kwargs( + (batch_size, 1, seq_len, seq_len + past_key_values_length), + hidden_states.dtype, + hidden_states.device, + attention_mask, + is_causal=True, + ) + elif attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + return attention_mask + class GPTJPipelineForwards: """ @@ -96,26 +141,6 @@ class GPTJPipelineForwards: batch_size, seq_length = input_shape[0], input_shape[1] device = hidden_states.device - # Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x num_attention_heads x N x N @@ -139,6 +164,8 @@ class GPTJPipelineForwards: output_shape = input_shape + (hidden_states.size(-1),) + attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask) + if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( @@ -154,7 +181,9 @@ class GPTJPipelineForwards: # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config.enable_sequence_parallelism: hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) # Going through held blocks. @@ -209,7 +238,9 @@ class GPTJPipelineForwards: # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config.enable_sequence_parallelism: hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) if stage_manager.is_last_stage(): @@ -223,7 +254,14 @@ class GPTJPipelineForwards: if stage_manager.is_last_stage(): if not return_dict: return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None ) return BaseModelOutputWithPast( @@ -530,24 +568,11 @@ class GPTJPipelineForwards: def get_gptj_flash_attention_forward(): from transformers.models.gptj.modeling_gptj import GPTJAttention - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - - def split_heads(tensor, num_attention_heads, attn_head_size, rotary): - """ - Splits hidden dim into attn_head_size and num_attention_heads - """ - new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) - tensor = tensor.view(new_shape) - if rotary or len(tensor.shape) in [4, 5]: - return tensor - else: - raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") - def forward( self: GPTJAttention, hidden_states: torch.FloatTensor, layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[dict] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, @@ -556,13 +581,14 @@ def get_gptj_flash_attention_forward(): Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], ]: + assert head_mask is None, "head_mask is not supported for FlashAttention" query = self.q_proj(hidden_states) key = self.k_proj(hidden_states) value = self.v_proj(hidden_states) - query = split_heads(query, self.num_attention_heads, self.head_dim, True) - key = split_heads(key, self.num_attention_heads, self.head_dim, True) - value = split_heads(value, self.num_attention_heads, self.head_dim, False) + query = self._split_heads(query, self.num_attention_heads, self.head_dim, True) + key = self._split_heads(key, self.num_attention_heads, self.head_dim, True) + value = self._split_heads(value, self.num_attention_heads, self.head_dim, False) if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing(): # The logic to conditionally copy to GPU could not be traced, so we do this @@ -591,46 +617,202 @@ def get_gptj_flash_attention_forward(): key = apply_rotary_pos_emb(key, sin, cos) query = apply_rotary_pos_emb(query, sin, cos) - # key = key.permute(0, 2, 1, 3) - # query = query.permute(0, 2, 1, 3) - key = key.to(dtype=value.dtype) # fp16 compatibility - query = query.to(dtype=value.dtype) + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) if layer_past is not None: past_key = layer_past[0] past_value = layer_past[1] - key = torch.cat((past_key, key), dim=1) - value = torch.cat((past_value, value), dim=1) + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) if use_cache is True: present = (key, value) else: present = None - # use AttnMaskType and ColoAttention - attn_mask_type = AttnMaskType.causal - flash_attention_mask = None - if attention_mask != None: - if attn_mask_type == AttnMaskType.causal: - attn_mask_type == AttnMaskType.paddedcausal - else: - attn_mask_type = AttnMaskType.padding - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + dropout_p = self.attn_dropout.p if self.training else 0.0 + attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p) + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + outputs = (attn_output, present, None) - # use coloattention - scale = value.size(-1) ** -0.5 + return outputs # a, present, (attentions) + + return forward - attention = ColoAttention( - embed_dim=self.embed_dim, num_heads=self.num_attention_heads, dropout=self.attn_dropout.p, scale=scale + +def gptj_model_forward_for_flash_attention(shard_config: ShardConfig): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") - attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) + device = input_ids.device if input_ids is not None else inputs_embeds.device - attn_output = self.out_proj(attn_output) - attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, present, None) + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) - return outputs # a, present, (attentions) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]).long() + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape n_layer x batch x num_attention_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + position_ids, + head_mask[i], + ) + else: + outputs = block( + hidden_states=hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) return forward @@ -662,10 +844,10 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] + input_ids.shape[0] elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] + inputs_embeds.shape[0] else: raise ValueError("You have to specify either input_ids or inputs_embeds") @@ -684,29 +866,14 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): past_length = past_key_values[0][0].size(-2) if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - # Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x num_attention_heads x N x N @@ -725,6 +892,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) + attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask) if self.gradient_checkpointing and self.training: if use_cache: @@ -740,7 +908,9 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -801,7 +971,9 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): # When sequence parallelism done, gather the output tensor in forward and split it in backward hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) hidden_states = self.ln_f(hidden_states) @@ -812,7 +984,16 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index d5e02b64c..1f17144f5 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -15,7 +15,9 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig -from ..layer import cross_entropy_1d + +from ..layer import ColoAttention, cross_entropy_1d + try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask @@ -105,18 +107,25 @@ class LlamaPipelineForwards: # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device - ) - if LATEST_VERSION: - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + if shard_config.enable_flash_attention: + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True ) else: - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length - ) + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + if LATEST_VERSION: + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + else: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -262,6 +271,7 @@ class LlamaPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) past_key_values = None @@ -352,6 +362,7 @@ class LlamaPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) if input_ids is not None: @@ -420,8 +431,6 @@ class LlamaPipelineForwards: def get_llama_flash_attention_forward(shard_config: ShardConfig): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - llama_version = 2 try: from transformers.models.llama.modeling_llama import repeat_kv @@ -432,7 +441,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): def forward( self: LlamaAttention, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[dict] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, @@ -466,31 +475,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) - query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) - key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) - value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape) - - flash_attention_mask = None - attn_mask_type = AttnMaskType.causal - if not getattr(shard_config, "causal_lm", False) and attention_mask != None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - attn_mask_type = AttnMaskType.paddedcausal - - if not hasattr(self, "attention"): - self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) - attn_output = self.attention( - query_states, - key_states, - value_states, - attn_mask=flash_attention_mask, - attn_mask_type=attn_mask_type, - origin_attn_mask=attention_mask, - ) + assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." + attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) @@ -499,6 +487,137 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): return forward +def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig): + logger = logging.get_logger(__name__) + assert shard_config.enable_flash_attention, "Flash Attention is not enabled." + + def forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + hidden_states = inputs_embeds + + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return forward + + def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): from transformers import LlamaForCausalLM diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index d0e267eac..a26526430 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -18,6 +18,37 @@ from transformers.models.opt.modeling_opt import ( from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import ColoAttention +from colossalai.shardformer.shard import ShardConfig + +logger = logging.get_logger(__name__) + + +def _get_attention_mask( + self: OPTModel, + shard_config: ShardConfig, + hidden_states: torch.Tensor, + past_key_values_length: int, + attention_mask: Optional[torch.FloatTensor], +): + batch_size, seq_length = hidden_states.shape[:2] + mask_seq_length = past_key_values_length + seq_length + if shard_config.enable_flash_attention: + attention_mask = ColoAttention.prepare_attn_kwargs( + (batch_size, 1, seq_length, mask_seq_length), + hidden_states.dtype, + hidden_states.device, + attention_mask, + is_causal=True, + ) + else: + attention_mask = self.decoder._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + ) + return attention_mask class OPTPipelineForwards: @@ -26,46 +57,6 @@ class OPTPipelineForwards: under pipeline setting. """ - @staticmethod - def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - from transformers.models.opt.modeling_opt import _make_causal_mask - - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - _dtype, - device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, tgt_len=input_shape[-1]).to( - device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @staticmethod - def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - @staticmethod def opt_model_forward( self: OPTModel, @@ -81,6 +72,7 @@ class OPTPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: Optional[ShardConfig] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """ This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward @@ -119,7 +111,7 @@ class OPTPipelineForwards: if decoder.project_in is not None: inputs_embeds = decoder.project_in(inputs_embeds) device = input_ids.device if input_ids is not None else inputs_embeds.device - _dtype = inputs_embeds.dtype + inputs_embeds.dtype else: if hidden_states is None: @@ -127,7 +119,7 @@ class OPTPipelineForwards: input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape[0], input_shape[1] device = hidden_states.device - _dtype = hidden_states.dtype + hidden_states.dtype past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 # required mask seq length can be calculated via length of past @@ -141,13 +133,24 @@ class OPTPipelineForwards: f"{mask_seq_length} (sum of the lengths of current and past inputs)" ) - causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask( - attention_mask, input_shape, _dtype, device, past_key_values_length - ) - if stage_manager.is_first_stage(): + causal_attention_mask = _get_attention_mask( + self, + shard_config, + inputs_embeds, + past_key_values_length, + attention_mask, + ) pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length) hidden_states = inputs_embeds + pos_embeds + else: + causal_attention_mask = _get_attention_mask( + self, + shard_config, + hidden_states, + past_key_values_length, + attention_mask, + ) if decoder.gradient_checkpointing and decoder.training: if use_cache: @@ -249,7 +252,16 @@ class OPTPipelineForwards: if stage_manager.is_last_stage(): if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + ] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -276,6 +288,7 @@ class OPTPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: Optional[ShardConfig] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForCausalLM.forward. @@ -303,6 +316,7 @@ class OPTPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): logits = self.lm_head(outputs[0]).contiguous() @@ -347,6 +361,7 @@ class OPTPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: Optional[ShardConfig] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForSequenceClassification.forward. @@ -371,6 +386,7 @@ class OPTPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): @@ -448,6 +464,7 @@ class OPTPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: Optional[ShardConfig] = None, ) -> Union[Tuple, QuestionAnsweringModelOutput]: r""" This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForQuestionAnswering.forward. @@ -469,6 +486,7 @@ class OPTPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] @@ -511,49 +529,47 @@ class OPTPipelineForwards: return {"hidden_states": hidden_states} -def get_opt_flash_attention_forward(): +def get_opt_flash_attention_forward(shard_config: ShardConfig): from transformers.models.opt.modeling_opt import OPTAttention - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - def forward( self: OPTAttention, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[dict] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - + assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() - attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(*attention_input_shape) + query_states = self.q_proj(hidden_states) # get key, value proj if is_cross_attention and past_key_value is not None: - # reuse k, v, cross_attentions - key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) - value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self.k_proj(key_value_states).view(*attention_input_shape) - value_states = self.v_proj(key_value_states).view(*attention_input_shape) + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*attention_input_shape) - value_states = self.v_proj(hidden_states).view(*attention_input_shape) - key_states = torch.cat([past_key_value[0], key_states], dim=1) - value_states = torch.cat([past_key_value[1], value_states], dim=1) + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self.k_proj(hidden_states).view(*attention_input_shape) - value_states = self.v_proj(hidden_states).view(*attention_input_shape) + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -565,38 +581,181 @@ def get_opt_flash_attention_forward(): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - src_len = key_states.size(1) - if layer_head_mask != None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - - flash_attention_mask = None - attn_mask_type = AttnMaskType.causal - if attention_mask != None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - if not torch.all(flash_attention_mask): - attn_mask_type = AttnMaskType.paddedcausal + query_states = self._shape(query_states, tgt_len, bsz) - attention = ColoAttention( - embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling - ) - attn_output = attention( - query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type + dropout_p = self.dropout if self.training else 0.0 + attn_output = ColoAttention.attention( + query_states, + key_states, + value_states, + **attention_mask, + dropout_p=dropout_p, + scale=self.scaling, ) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value return forward +def get_opt_decoder_forward_for_flash_attention(shard_config: ShardConfig): + from transformers.models.opt.modeling_opt import OPTDecoder + + def forward( + self: OPTDecoder, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length + + # embed positions + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) + causal_attention_mask = _get_attention_mask( + self, shard_config, inputs_embeds, past_key_values_length, attention_mask + ) + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + causal_attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return forward + + def get_jit_fused_opt_decoder_layer_forward(): from transformers.models.opt.modeling_opt import OPTDecoderLayer diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index ab141a74a..e9c256a13 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -1,4 +1,3 @@ -import math from typing import List, Optional, Tuple, Union import torch @@ -6,6 +5,7 @@ from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import ColoAttention def _encoder_forward( @@ -98,7 +98,9 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: pixel_values = pixel_values.to(expected_dtype) embedding_output = self.embeddings( - pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + pixel_values, + bool_masked_pos=bool_masked_pos, + interpolate_pos_encoding=interpolate_pos_encoding, ) hidden_states = embedding_output else: @@ -336,34 +338,27 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag def get_vit_flash_self_attention_forward(): from transformers.models.vit.modeling_vit import ViTSelfAttention - from colossalai.nn.layer.colo_attention import ColoAttention - - def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size) - x = x.view(new_x_shape) - return x - def forward( self: ViTSelfAttention, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + assert head_mask is None, "head_mask is not supported for FlashAttention" mixed_query_layer = self.query(hidden_states) - key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size) - value_layer = transpose_for_scores( - self.value(hidden_states), self.num_attention_heads, self.attention_head_size - ) - query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) - scale = 1.0 / math.sqrt(self.attention_head_size) - attention = ColoAttention( - embed_dim=self.all_head_size, num_heads=self.num_attention_heads, dropout=self.dropout.p, scale=scale - ) - context_layer = attention(query_layer, key_layer, value_layer) + dropout_p = self.dropout.p if self.training else 0.0 + context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, dropout_p=dropout_p) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer,) + outputs = (context_layer, None) if output_attentions else (context_layer,) return outputs diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index cb8b45ae7..7ccc79276 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -13,41 +13,74 @@ from transformers.modeling_outputs import ( SequenceClassifierOutput, ) from transformers.models.whisper.modeling_whisper import ( + WhisperDecoder, WhisperEncoder, WhisperForAudioClassification, WhisperForConditionalGeneration, WhisperModel, + shift_tokens_right, ) from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import ColoAttention +from colossalai.shardformer.shard import ShardConfig + +logger = logging.get_logger(__name__) + + +def _get_attention_mask( + self: WhisperDecoder, + shard_config: ShardConfig, + hidden_states: torch.Tensor, + past_key_values_length: int, + attention_mask: Optional[torch.FloatTensor], +): + batch_size, seq_length = hidden_states.shape[:2] + mask_seq_length = past_key_values_length + seq_length + if shard_config.enable_flash_attention: + attention_mask = ColoAttention.prepare_attn_kwargs( + (batch_size, 1, seq_length, mask_seq_length), + hidden_states.dtype, + hidden_states.device, + attention_mask, + is_causal=True, + ) + else: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + ) + return attention_mask def get_whisper_flash_attention_forward(): from transformers.models.whisper.modeling_whisper import WhisperAttention - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - - def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int): - return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous() - def forward( self: WhisperAttention, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[dict] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - + assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention" + # for encoder, attention_mask is None + if attention_mask is None: + attention_mask = {} # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() + # get query proj + query_states = self.q_proj(hidden_states) # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -55,25 +88,25 @@ def get_whisper_flash_attention_forward(): if ( is_cross_attention and past_key_value is not None - and past_key_value[0].shape[1] == key_value_states.shape[1] + and past_key_value[0].shape[2] == key_value_states.shape[1] ): # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) - value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention - key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) - value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) - key_states = torch.cat([past_key_value[0], key_states], dim=1) - value_states = torch.cat([past_key_value[1], value_states], dim=1) + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) - value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -85,42 +118,178 @@ def get_whisper_flash_attention_forward(): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - # get query proj - query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz) - src_len = key_states.size(1) - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) + dropout_p = self.dropout if self.training else 0.0 + attn_output = ColoAttention.attention( + query_states, + key_states, + value_states, + **attention_mask, + dropout_p=dropout_p, + scale=self.scaling, + ) + attn_output = attn_output.transpose(1, 2) - attn_type = None - flash_attention_mask = None + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - if self.is_decoder: - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous()) - if not torch.all(flash_attention_mask): - attn_type = AttnMaskType.paddedcausal - else: - attn_type = AttnMaskType.causal + attn_output = self.out_proj(attn_output) - attention = ColoAttention( - embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling - ) - attn_output = attention( - query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_type + return attn_output, None, past_key_value + + return forward + + +def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig): + def forward( + self: WhisperDecoder, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - attn_output = self.out_proj(attn_output) + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - return attn_output, None, past_key_value + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + attention_mask = _get_attention_mask(self, shard_config, inputs_embeds, past_key_values_length, attention_mask) + + # embed positions + if input_ids is not None: + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + else: + positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." + ) + use_cache = False + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == (len(self.layers)), ( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + None, # encoder attention mask + head_mask[idx] if head_mask is not None else None, + (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + None, # past_key_value + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) return forward @@ -292,6 +461,7 @@ class WhisperPipelineForwards: all_attentions=None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, + shard_config: Optional[ShardConfig] = None, ): r""" Args: @@ -403,7 +573,9 @@ class WhisperPipelineForwards: if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, ) else: @@ -411,7 +583,7 @@ class WhisperPipelineForwards: @staticmethod def whisper_decoder_forward( - self, + self: WhisperDecoder, input_ids=None, attention_mask=None, encoder_hidden_states=None, @@ -427,6 +599,7 @@ class WhisperPipelineForwards: hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, + shard_config: Optional[ShardConfig] = None, ): r""" Args: @@ -535,8 +708,12 @@ class WhisperPipelineForwards: else: positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + attention_mask = _get_attention_mask( + self, + shard_config, + inputs_embeds, + past_key_values_length, + attention_mask, ) hidden_states = inputs_embeds + positions @@ -556,8 +733,12 @@ class WhisperPipelineForwards: ) input_shape = hidden_states.size()[:-1] - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, input_shape, hidden_states, past_key_values_length + attention_mask = _get_attention_mask( + self, + shard_config, + hidden_states, + past_key_values_length, + attention_mask, ) start_idx, end_idx = stage_index[0], stage_index[1] @@ -590,7 +771,7 @@ class WhisperPipelineForwards: encoder_hidden_states, None, # encoder attention mask head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), None, # past_key_value ) else: @@ -626,7 +807,13 @@ class WhisperPipelineForwards: if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( @@ -666,6 +853,7 @@ class WhisperPipelineForwards: encoder_hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, + shard_config: Optional[ShardConfig] = None, ): r""" Returns: @@ -735,7 +923,7 @@ class WhisperPipelineForwards: elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): encoder_outputs = BaseModelOutput( last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + hidden_states=(encoder_outputs[1] if len(encoder_outputs) > 1 else None), attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) @@ -767,6 +955,7 @@ class WhisperPipelineForwards: hidden_states=hidden_states, stage_index=stage_index, decoder_starting_stage=decoder_starting_stage, + shard_config=shard_config, ) # Directly return outputs of overloaded Whisper forward if not at last stage. @@ -810,6 +999,7 @@ class WhisperPipelineForwards: encoder_hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, + shard_config: Optional[ShardConfig] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -870,6 +1060,7 @@ class WhisperPipelineForwards: encoder_hidden_states=encoder_hidden_states, stage_index=stage_index, decoder_starting_stage=decoder_starting_stage, + shard_config=shard_config, ) if not in_decoder: return outputs @@ -920,6 +1111,7 @@ class WhisperPipelineForwards: all_attentions=None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, + shard_config: Optional[ShardConfig] = None, ): r""" This function is modified on the basis of transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward. diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 6a50d65ba..fcf40fa39 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -8,6 +8,7 @@ import colossalai.shardformer.layer as col_nn from ..modeling.gpt2 import ( GPT2PipelineForwards, get_gpt2_flash_attention_forward, + get_gpt_model_forward_for_flash_attn, get_lm_forward_with_dist_cross_entropy, gpt2_sequence_parallel_forward_fn, ) @@ -75,7 +76,11 @@ class GPT2Policy(Policy): SubModuleReplacementDescription( suffix="attn.c_attn", target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "n_fused": 3, + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="attn.c_proj", @@ -87,7 +92,11 @@ class GPT2Policy(Policy): SubModuleReplacementDescription( suffix="mlp.c_fc", target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "n_fused": 1, + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", @@ -150,6 +159,10 @@ class GPT2Policy(Policy): policy=policy, target_key=GPT2Attention, ) + if not self.shard_config.pipeline_stage_manager: + policy[GPT2Model].method_replacement = { + "forward": get_gpt_model_forward_for_flash_attn(self.shard_config) + } if self.shard_config.enable_sequence_parallelism: policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} @@ -223,14 +236,21 @@ class GPT2Policy(Policy): num_stages=stage_manager.num_stages, ) method_replacement = { - "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + "forward": partial( + new_forward, + stage_manager=stage_manager, + shard_config=self.shard_config, + ) } else: layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( - new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config, ) } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) @@ -245,7 +265,9 @@ class GPT2ModelPolicy(GPT2Policy): if self.pipeline_stage_manager is not None: self.set_pipeline_forward( - model_cls=GPT2Model, new_forward=GPT2PipelineForwards.gpt2_model_forward, policy=policy + model_cls=GPT2Model, + new_forward=GPT2PipelineForwards.gpt2_model_forward, + policy=policy, ) return policy @@ -299,7 +321,12 @@ class GPT2LMHeadModelPolicy(GPT2Policy): if stage_manager is not None: if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): first_stage, last_stage = 0, stage_manager.num_stages - 1 - return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + return [ + { + first_stage: module.transformer.wte.weight, + last_stage: module.lm_head.weight, + } + ] return [] @@ -315,7 +342,9 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): GPT2DoubleHeadsModel: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}, ) ] ) @@ -350,7 +379,12 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): if stage_manager is not None: if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): first_stage, last_stage = 0, stage_manager.num_stages - 1 - return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + return [ + { + first_stage: module.transformer.wte.weight, + last_stage: module.lm_head.weight, + } + ] return [] @@ -392,7 +426,10 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy): addon_module = { GPT2ForTokenClassification: ModulePolicyDescription( sub_module_replacement=[ - SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput) + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) ] ) } diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 9feb826c4..b001a2009 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -6,7 +6,11 @@ from torch import Tensor, nn import colossalai.shardformer.layer as col_nn -from ..modeling.gptj import GPTJPipelineForwards, get_gptj_flash_attention_forward +from ..modeling.gptj import ( + GPTJPipelineForwards, + get_gptj_flash_attention_forward, + gptj_model_forward_for_flash_attention, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -71,17 +75,26 @@ class GPTJPolicy(Policy): SubModuleReplacementDescription( suffix="attn.k_proj", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="attn.q_proj", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="attn.v_proj", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="attn.out_proj", @@ -143,6 +156,12 @@ class GPTJPolicy(Policy): policy=policy, target_key=GPTJAttention, ) + if not self.shard_config.pipeline_stage_manager: + self.append_or_create_method_replacement( + description={"forward": gptj_model_forward_for_flash_attention(self.shard_config)}, + policy=policy, + target_key=GPTJModel, + ) return policy @@ -185,7 +204,10 @@ class GPTJPolicy(Policy): stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( - new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config, ) } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) @@ -203,7 +225,9 @@ class GPTJModelPolicy(GPTJPolicy): if self.pipeline_stage_manager is not None: self.set_pipeline_forward( - model_cls=GPTJModel, new_forward=GPTJPipelineForwards.gptj_model_forward, policy=policy + model_cls=GPTJModel, + new_forward=GPTJPipelineForwards.gptj_model_forward, + policy=policy, ) return policy @@ -230,7 +254,9 @@ class GPTJForCausalLMPolicy(GPTJPolicy): GPTJForCausalLM: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}, ) ] ) @@ -239,7 +265,9 @@ class GPTJForCausalLMPolicy(GPTJPolicy): if self.pipeline_stage_manager is not None: self.set_pipeline_forward( - model_cls=GPTJForCausalLM, new_forward=GPTJPipelineForwards.gptj_causallm_model_forward, policy=policy + model_cls=GPTJForCausalLM, + new_forward=GPTJPipelineForwards.gptj_causallm_model_forward, + policy=policy, ) return policy @@ -256,7 +284,12 @@ class GPTJForCausalLMPolicy(GPTJPolicy): if stage_manager is not None: if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): first_stage, last_stage = 0, stage_manager.num_stages - 1 - return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + return [ + { + first_stage: module.transformer.wte.weight, + last_stage: module.lm_head.weight, + } + ] return [] diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 4c454ac7f..37c2c261b 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -11,6 +11,7 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Ro from ..modeling.llama import ( LlamaPipelineForwards, get_llama_flash_attention_forward, + get_llama_model_forward_for_flash_attn, get_lm_forward_with_dist_cross_entropy, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -135,6 +136,15 @@ class LlamaPolicy(Policy): policy=policy, target_key=LlamaAttention, ) + if self.pipeline_stage_manager is None: + # replace llama model forward method + self.append_or_create_method_replacement( + description={ + "forward": get_llama_model_forward_for_flash_attn(self.shard_config), + }, + policy=policy, + target_key=LlamaModel, + ) return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index a542808ba..9a74da0b8 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -9,7 +9,12 @@ from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col from .._utils import getattr_ from ..modeling.jit import get_jit_fused_dropout_add_func -from ..modeling.opt import OPTPipelineForwards, get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward +from ..modeling.opt import ( + OPTPipelineForwards, + get_jit_fused_opt_decoder_layer_forward, + get_opt_decoder_forward_for_flash_attention, + get_opt_flash_attention_forward, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -27,6 +32,7 @@ class OPTPolicy(Policy): import transformers from packaging.version import Version + # TODO: remove this version check when transformers>=4.36.0 assert Version(transformers.__version__) <= Version( "4.33.0" ), "The OPT model should run on a transformers version not greater than 4.33.0." @@ -111,7 +117,9 @@ class OPTPolicy(Policy): # optimization configuration self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True + suffix="final_layer_norm", + target_module=norm_cls, + ignore_if_not_exist=True, ), policy=policy, target_key=OPTDecoder, @@ -119,10 +127,14 @@ class OPTPolicy(Policy): self.append_or_create_submodule_replacement( description=[ SubModuleReplacementDescription( - suffix="self_attn_layer_norm", target_module=norm_cls, ignore_if_not_exist=True + suffix="self_attn_layer_norm", + target_module=norm_cls, + ignore_if_not_exist=True, ), SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True + suffix="final_layer_norm", + target_module=norm_cls, + ignore_if_not_exist=True, ), ], policy=policy, @@ -133,11 +145,19 @@ class OPTPolicy(Policy): if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_opt_flash_attention_forward(), + "forward": get_opt_flash_attention_forward(self.shard_config), }, policy=policy, target_key=OPTAttention, ) + if not self.shard_config.pipeline_stage_manager: + self.append_or_create_method_replacement( + description={ + "forward": get_opt_decoder_forward_for_flash_attention(self.shard_config), + }, + policy=policy, + target_key=OPTDecoder, + ) # use jit fused operator if self.shard_config.enable_jit_fused: @@ -190,7 +210,14 @@ class OPTPolicy(Policy): layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + "forward": partial( + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config, + ) + } self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls ) @@ -203,7 +230,9 @@ class OPTModelPolicy(OPTPolicy): policy = super().module_policy() if self.pipeline_stage_manager: self.set_pipeline_forward( - model_cls=OPTModel, new_forward=OPTPipelineForwards.opt_model_forward, policy=policy + model_cls=OPTModel, + new_forward=OPTPipelineForwards.opt_model_forward, + policy=policy, ) return policy @@ -223,14 +252,18 @@ class OPTForCausalLMPolicy(OPTPolicy): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True), ), policy=policy, target_key=OPTForCausalLM, ) if self.pipeline_stage_manager: self.set_pipeline_forward( - model_cls=OPTForCausalLM, new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, policy=policy + model_cls=OPTForCausalLM, + new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, + policy=policy, ) return policy @@ -246,7 +279,12 @@ class OPTForCausalLMPolicy(OPTPolicy): if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: num_stages = self.pipeline_stage_manager.num_stages if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight): - return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}] + return [ + { + 0: opt_model.model.decoder.embed_tokens.weight, + num_stages - 1: opt_model.lm_head.weight, + } + ] return [] def postprocess(self): diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index b5b5db79d..14e1e3e0f 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -13,6 +13,7 @@ from ..modeling.whisper import ( WhisperPipelineForwards, get_jit_fused_whisper_decoder_layer_forward, get_jit_fused_whisper_encoder_layer_forward, + get_whisper_decoder_forward_for_flash_attention, get_whisper_flash_attention_forward, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -31,6 +32,7 @@ class WhisperPolicy(Policy): import transformers from packaging.version import Version + # TODO: remove this version check when transformers>=4.36.0 assert Version(transformers.__version__) <= Version( "4.33.0" ), "The Whisper model should run on a transformers version not greater than 4.33.0." @@ -240,6 +242,14 @@ class WhisperPolicy(Policy): policy=policy, target_key=WhisperAttention, ) + if not self.shard_config.pipeline_stage_manager: + self.append_or_create_method_replacement( + description={ + "forward": get_whisper_decoder_forward_for_flash_attention(self.shard_config), + }, + policy=policy, + target_key=WhisperDecoder, + ) # use jit fused operator if self.shard_config.enable_jit_fused: @@ -269,7 +279,9 @@ class WhisperPolicy(Policy): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + suffix="proj_out", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}, ), policy=base_policy, target_key=WhisperForConditionalGeneration, @@ -326,7 +338,10 @@ class WhisperPolicy(Policy): if stage < decoder_starting_stage: return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) else: - return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) + return Policy.get_stage_index( + layers_per_stage[decoder_starting_stage:], + stage - decoder_starting_stage, + ) def get_held_layers(self) -> List[nn.Module]: assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" @@ -422,6 +437,7 @@ class WhisperPolicy(Policy): stage_manager=stage_manager, stage_index=stage_index, decoder_starting_stage=decoder_starting_stage, + shard_config=self.shard_config, ) } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) @@ -436,7 +452,9 @@ class WhisperModelPolicy(WhisperPolicy): if self.pipeline_stage_manager is not None: self.set_pipeline_forward( - model_cls=WhisperModel, new_forward=WhisperPipelineForwards.whisper_model_forward, policy=policy + model_cls=WhisperModel, + new_forward=WhisperPipelineForwards.whisper_model_forward, + policy=policy, ) return policy diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index 4f2a4878e..e415b5fc3 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -40,7 +40,12 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): assert torch.all(a == b), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}" -def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True, ignore_dtype: bool = False): +def check_state_dict_equal( + d1: OrderedDict, + d2: OrderedDict, + ignore_device: bool = True, + ignore_dtype: bool = False, +): assert len(list(d1.keys())) == len( list(d2.keys()) ), f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}" @@ -94,7 +99,12 @@ def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_devic def assert_hf_output_close( - out1: Any, out2: Any, ignore_keys: List[str] = None, track_name: str = "", atol=1e-5, rtol=1e-5 + out1: Any, + out2: Any, + ignore_keys: List[str] = None, + track_name: str = "", + atol=1e-5, + rtol=1e-5, ): """ Check if two outputs from huggingface are equal. @@ -113,7 +123,12 @@ def assert_hf_output_close( if ignore_keys is not None and k in ignore_keys: continue assert_hf_output_close( - out1[k], out2[k], track_name=f"{track_name}.{k}", ignore_keys=ignore_keys, atol=atol, rtol=rtol + out1[k], + out2[k], + track_name=f"{track_name}.{k}", + ignore_keys=ignore_keys, + atol=atol, + rtol=rtol, ) elif isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)): # if two values are list @@ -121,12 +136,17 @@ def assert_hf_output_close( assert len(out1) == len(out2) for i in range(len(out1)): assert_hf_output_close( - out1[i], out2[i], track_name=f"{track_name}.{i}", ignore_keys=ignore_keys, atol=atol, rtol=rtol + out1[i], + out2[i], + track_name=f"{track_name}.{i}", + ignore_keys=ignore_keys, + atol=atol, + rtol=rtol, ) elif isinstance(out1, Tensor) and isinstance(out2, Tensor): if out1.shape != out2.shape: raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}") - assert torch.allclose( + assert_close( out1, out2, atol=atol, rtol=rtol ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}" else: diff --git a/extensions/README.md b/extensions/README.md index 6f5feb55c..b9bde7742 100644 --- a/extensions/README.md +++ b/extensions/README.md @@ -101,13 +101,13 @@ class MyExtension(_Extension): self._support_jit = True self.priority = 10 - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: """ Return if the required hardware can be found. """ ... - def assert_hardware_compatible(self) -> None: + def assert_compatible(self) -> None: """ Check if the hardware required by the kernel is compatible. """ diff --git a/extensions/__init__.py b/extensions/__init__.py index 9343cadda..0dbadba81 100644 --- a/extensions/__init__.py +++ b/extensions/__init__.py @@ -1,9 +1,5 @@ from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension -from .flash_attention import ( - FlashAttentionDaoCudaExtension, - FlashAttentionNpuExtension, - FlashAttentionXformersCudaExtension, -) +from .flash_attention import FlashAttentionDaoCudaExtension, FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension from .layernorm import LayerNormCudaExtension from .moe import MoeCudaExtension from .optimizer import FusedOptimizerCudaExtension @@ -18,7 +14,7 @@ ALL_EXTENSIONS = [ ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension, FlashAttentionDaoCudaExtension, - FlashAttentionXformersCudaExtension, + FlashAttentionSdpaCudaExtension, FlashAttentionNpuExtension, ] @@ -31,6 +27,6 @@ __all__ = [ "ScaledMaskedSoftmaxCudaExtension", "ScaledUpperTriangleMaskedSoftmaxCudaExtension", "FlashAttentionDaoCudaExtension", - "FlashAttentionXformersCudaExtension", + "FlashAttentionSdpaCudaExtension", "FlashAttentionNpuExtension", ] diff --git a/extensions/base_extension.py b/extensions/base_extension.py index c815a7f2a..0c79c0a9e 100644 --- a/extensions/base_extension.py +++ b/extensions/base_extension.py @@ -58,13 +58,13 @@ class _Extension(ABC): return cache_directory @abstractmethod - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: """ Check if the hardware required by the kernel is available. """ @abstractmethod - def assert_hardware_compatible(self) -> None: + def assert_compatible(self) -> None: """ Check if the hardware required by the kernel is compatible. """ diff --git a/extensions/cpu_adam/cpu_adam_arm.py b/extensions/cpu_adam/cpu_adam_arm.py index 35bff3b55..61c4f3ed0 100644 --- a/extensions/cpu_adam/cpu_adam_arm.py +++ b/extensions/cpu_adam/cpu_adam_arm.py @@ -7,11 +7,11 @@ class CpuAdamArmExtension(_CppExtension): def __init__(self): super().__init__(name="cpu_adam_arm") - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: # only arm allowed return platform.machine() == "aarch64" - def assert_hardware_compatible(self) -> None: + def assert_compatible(self) -> None: arch = platform.machine() assert ( arch == "aarch64" diff --git a/extensions/cpu_adam/cpu_adam_x86.py b/extensions/cpu_adam/cpu_adam_x86.py index a38194167..9bbc8d851 100644 --- a/extensions/cpu_adam/cpu_adam_x86.py +++ b/extensions/cpu_adam/cpu_adam_x86.py @@ -8,15 +8,15 @@ class CpuAdamX86Extension(_CudaExtension): def __init__(self): super().__init__(name="cpu_adam_x86") - def is_hardware_available(self) -> bool: - return platform.machine() == "x86_64" and super().is_hardware_available() + def is_available(self) -> bool: + return platform.machine() == "x86_64" and super().is_available() - def assert_hardware_compatible(self) -> None: + def assert_compatible(self) -> None: arch = platform.machine() assert ( arch == "x86_64" ), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}" - super().assert_hardware_compatible() + super().assert_compatible() # necessary 4 functions def sources_files(self): diff --git a/extensions/cuda_extension.py b/extensions/cuda_extension.py index 842cd9713..f1e0095b2 100644 --- a/extensions/cuda_extension.py +++ b/extensions/cuda_extension.py @@ -22,7 +22,7 @@ class _CudaExtension(_CppExtension): This function should return a list of nvcc compilation flags for extensions. """ - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: # cuda extension can only be built if cuda is available try: import torch @@ -32,7 +32,7 @@ class _CudaExtension(_CppExtension): cuda_available = False return cuda_available - def assert_hardware_compatible(self) -> None: + def assert_compatible(self) -> None: from torch.utils.cpp_extension import CUDA_HOME if not CUDA_HOME: diff --git a/extensions/flash_attention/__init__.py b/extensions/flash_attention/__init__.py index 18abb6191..ea5b442aa 100644 --- a/extensions/flash_attention/__init__.py +++ b/extensions/flash_attention/__init__.py @@ -1,20 +1,14 @@ from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension from .flash_attention_npu import FlashAttentionNpuExtension -from .flash_attention_xformers_cuda import FlashAttentionXformersCudaExtension +from .flash_attention_sdpa_cuda import FlashAttentionSdpaCudaExtension try: + # TODO: remove this after updating openmoe example import flash_attention # noqa HAS_FLASH_ATTN = True except: HAS_FLASH_ATTN = False -try: - import xformers # noqa - - HAS_MEM_EFF_ATTN = True -except: - HAS_MEM_EFF_ATTN = False - -__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionXformersCudaExtension", "FlashAttentionNpuExtension"] +__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionSdpaCudaExtension", "FlashAttentionNpuExtension"] diff --git a/extensions/flash_attention/flash_attention_dao_cuda.py b/extensions/flash_attention/flash_attention_dao_cuda.py index 1b7f8ac47..a2f2a52f1 100644 --- a/extensions/flash_attention/flash_attention_dao_cuda.py +++ b/extensions/flash_attention/flash_attention_dao_cuda.py @@ -5,17 +5,20 @@ class FlashAttentionDaoCudaExtension(_Extension): def __init__(self): super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10) - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: # cuda extension can only be built if cuda is available try: import torch + from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func # noqa + from flash_attn.bert_padding import index_first_axis, pad_input # noqa + cuda_available = torch.cuda.is_available() except: cuda_available = False return cuda_available - def assert_hardware_compatible(self) -> bool: + def assert_compatible(self) -> bool: pass def build_aot(self) -> None: @@ -29,65 +32,65 @@ class FlashAttentionDaoCudaExtension(_Extension): ) def load(self): - try: - from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func - except ImportError: - raise ModuleNotFoundError( - ( - "We rely on the third-party flash-attn library for flash attention. Please install flash-attn via 'pip install flash-attn --no-build-isolation'" - ) - ) - from typing import Optional import torch + from einops import rearrange + from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func + from flash_attn.bert_padding import index_first_axis, pad_input + + def _unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor): + return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices) def flash_attention( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - seq_len_info_q: "SeqLenInfo", - seq_len_info_kv: "SeqLenInfo", - origin_attn_mask: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, + scale: Optional[float] = None, + attention_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + q_indices: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, ): - """ - Arguments: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - sm_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - # check if the input is in allowed dtypes - if padded: - if seq_len_info_kv == None: - seq_len_info_kv = seq_len_info_q - - attn_out = flash_attn_varlen_func( + # [B, N, S, D] -> [B, S, N, D] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + b, s_q = q.shape[:2] + if cu_seqlens_q is not None: + # padded / padded causal + # unpad input: [B, S, N, D] -> [T, N, D] + q = _unpad_input(q, q_indices) + kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices) + attn_output = flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + ) + # pad output: [T, N, D] -> [B, S, N, D] + attn_output = pad_input(attn_output, q_indices, b, s_q) + else: + # causal / no attn mask + attn_output = flash_attn_func( q, k, v, - seq_len_info_q.cu_seqlens, - seq_len_info_kv.cu_seqlens, - seq_len_info_q.max_seqlen, - seq_len_info_kv.max_seqlen, - dropout_p, - scale, - causal, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, ) - else: - attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) - return attn_out + # [B, S, N, D] -> [B, N, S, D] + return attn_output.transpose(1, 2) return flash_attention diff --git a/extensions/flash_attention/flash_attention_npu.py b/extensions/flash_attention/flash_attention_npu.py index 58d0f9306..0e01cefa1 100644 --- a/extensions/flash_attention/flash_attention_npu.py +++ b/extensions/flash_attention/flash_attention_npu.py @@ -5,15 +5,15 @@ class FlashAttentionNpuExtension(_Extension): def __init__(self): super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False) - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: try: - import torch_npu # noqa + import torch_npu - return True + return hasattr(torch_npu, "npu_fusion_attention") except: return False - def assert_hardware_compatible(self) -> bool: + def assert_compatible(self) -> bool: pass def build_aot(self) -> None: @@ -27,47 +27,36 @@ class FlashAttentionNpuExtension(_Extension): ) def load(self): + from typing import Optional + import torch - from einops import rearrange + import torch_npu - def npu_sdpa_attention( + def flash_attention( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - seq_len_info_q=None, - seq_len_info_kv=None, - origin_attn_mask: torch.Tensor = None, dropout_p: float = 0.0, - scale: float = 1.0, - causal=None, - padded=None, + scale: Optional[float] = None, + attention_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + q_indices: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, ): - """ - The scaled dot product attention. - - Arguments: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - scale: float. The scaling of QK^T before applying softmax. - Default to 1. - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)] - output = torch.nn.functional.scaled_dot_product_attention( + num_heads = q.size(1) + return torch_npu.npu_fusion_attention( q, k, v, - attn_mask=origin_attn_mask, - dropout_p=dropout_p, - is_causal=origin_attn_mask is None, + num_heads, + "BNSD", + atten_mask=attention_mask.bool(), scale=scale, - ) - output = rearrange(output, "b h s d -> b s (h d)") - return output + keep_prob=1 - dropout_p, + )[0] - return npu_sdpa_attention + return flash_attention diff --git a/extensions/flash_attention/flash_attention_sdpa_cuda.py b/extensions/flash_attention/flash_attention_sdpa_cuda.py new file mode 100644 index 000000000..d3323a6aa --- /dev/null +++ b/extensions/flash_attention/flash_attention_sdpa_cuda.py @@ -0,0 +1,56 @@ +from ..base_extension import _Extension + + +class FlashAttentionSdpaCudaExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_sdpa_cuda", support_aot=False, support_jit=False) + + def is_available(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError("Flash attention SDPA does not require ahead-of-time compilation.") + + def build_jit(self) -> None: + raise NotImplementedError("Flash attention SDPA does not require just-in-time compilation.") + + def load(self): + from typing import Optional + + import torch + + def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + attention_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + q_indices: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + ): + return torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_mask, + dropout_p=dropout_p, + scale=scale, + ) + + return flash_attention diff --git a/extensions/flash_attention/flash_attention_xformers_cuda.py b/extensions/flash_attention/flash_attention_xformers_cuda.py deleted file mode 100644 index 27cd823de..000000000 --- a/extensions/flash_attention/flash_attention_xformers_cuda.py +++ /dev/null @@ -1,94 +0,0 @@ -from ..base_extension import _Extension - - -class FlashAttentionXformersCudaExtension(_Extension): - def __init__(self): - super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False) - - def is_hardware_available(self) -> bool: - # cuda extension can only be built if cuda is available - try: - import torch - - cuda_available = torch.cuda.is_available() - except: - cuda_available = False - return cuda_available - - def assert_hardware_compatible(self) -> bool: - pass - - def build_aot(self) -> None: - raise NotImplementedError( - "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." - ) - - def build_jit(self) -> None: - raise NotImplementedError( - "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." - ) - - def load(self): - try: - from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention - from xformers.ops.fmha.attn_bias import ( - BlockDiagonalCausalMask, - BlockDiagonalMask, - LowerTriangularMask, - LowerTriangularMaskWithTensorBias, - ) - except ImportError: - raise ModuleNotFoundError( - ( - "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." - ) - ) - from typing import Optional - - import torch - - allow_alibi = True - for op in MemoryEfficientAttentionCutlassOp: - allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) - - def mem_eff_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: "SeqLenInfo", - seq_len_info_kv: "SeqLenInfo", - origin_attn_mask: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, - ): - attn_bias = None - if padded: # bert style - if not causal: - attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - else: - attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - elif causal: # gpt style - attn_bias = LowerTriangularMask() - - if bias is not None: # alibi / relative position embedding - assert allow_alibi, "flash attention with bias is not supported in this system." - assert causal, "attention with bias is only supported for causal attention so far." - attn_bias = attn_bias.add_bias(bias) - - if padded: - q = q.unsqueeze(0) - k = k.unsqueeze(0) - v = v.unsqueeze(0) - - out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale) - - # shape: (b*s, n, d) - if padded: - out = out.squeeze(0) - - return out - - return mem_eff_attention diff --git a/setup.py b/setup.py index ef89481e6..c16709ad1 100644 --- a/setup.py +++ b/setup.py @@ -80,8 +80,8 @@ if BUILD_EXT: for ext_cls in ALL_EXTENSIONS: ext = ext_cls() - if ext.support_aot and ext.is_hardware_available(): - ext.assert_hardware_compatible() + if ext.support_aot and ext.is_available(): + ext.assert_compatible() op_names.append(ext.name) ext_modules.append(ext.build_aot()) diff --git a/tests/test_shardformer/test_flash_attention.py b/tests/test_shardformer/test_flash_attention.py new file mode 100644 index 000000000..f9eab132f --- /dev/null +++ b/tests/test_shardformer/test_flash_attention.py @@ -0,0 +1,147 @@ +import math +from copy import copy + +import torch +from torch.testing import assert_close + +from colossalai.kernel.kernel_loader import ( + FlashAttentionLoader, + FlashAttentionWithCustomMaskLoader, + FlashAttentionWithPaddingMaskLoader, +) +from colossalai.shardformer.layer import AttnMaskType, ColoAttention +from colossalai.shardformer.layer.attn import invert_mask +from colossalai.testing import clear_cache_before_run, parameterize +from colossalai.utils import get_current_device, set_seed + +DTYPE = [torch.float16, torch.bfloat16] +B, N, S, D = 2, 8, 256, 32 + +TOL_MAP = { + torch.float16: {"atol": 5e-4, "rtol": 2e-3}, + torch.bfloat16: {}, +} + + +def attention_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask=None, dropout_p=0.0): + head_dim = q.size(-1) + attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim) + if attn_mask is not None: + attn_weights = attn_weights + attn_mask + attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float).to(q.dtype) + attn_weights = torch.dropout(attn_weights, p=dropout_p, train=True) + attn_output = torch.matmul(attn_weights, v) + return attn_output + + +def gen_padded_kwargs(dtype: torch.dtype): + padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device()) + padding_mask[0, : S // 4] = 0 + return ( + ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask), + padding_mask, + ) + + +def gen_padded_causal_kwargs(dtype: torch.dtype): + padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device()) + padding_mask[0, S // 2 :] = 0 + return ( + ColoAttention.prepare_attn_kwargs( + (B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask, is_causal=True + ), + padding_mask, + ) + + +def gen_causal_kwargs(dtype: torch.dtype): + return ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, get_current_device(), is_causal=True), None + + +def gen_custom_kwargs(dtype: torch.dtype): + attn_mask = torch.ones((B, S, S), dtype=dtype, device=get_current_device()) + attn_mask[0, : S // 2, S // 2 :] = 0 + attn_mask[0, S // 2 :, : S // 2] = 0 + attn_mask[1, :, S // 4 :] = 0 + attn_mask = invert_mask(attn_mask).unsqueeze(1) + assert not torch.all(attn_mask != 0, dim=-1).any() + return {"attention_mask": attn_mask}, None + + +def post_process_kwargs_for_raw_attn(attn_kwargs: dict): + if "attention_mask_type" in attn_kwargs: + attn_kwargs = copy(attn_kwargs) + mask_type = attn_kwargs.pop("attention_mask_type") + attn_kwargs["is_causal"] = mask_type in (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL) + return attn_kwargs + + +def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_mask=None): + tols = TOL_MAP[dtype] + q = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True) + k = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True) + v = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True) + q_flash = q.clone().detach().requires_grad_(True) + k_flash = k.clone().detach().requires_grad_(True) + v_flash = v.clone().detach().requires_grad_(True) + attn_mask = attn_kwargs.get("attention_mask", None) + ref_output = attention_ref(q, k, v, attn_mask) + output = attn_func(q_flash, k_flash, v_flash, **attn_kwargs) + if padding_mask is not None: + # [B, Sq] -> [B, 1, Sq, 1] + padding_mask = padding_mask[:, None, :, None].logical_not() + ref_output = ref_output.masked_fill(padding_mask, 0) + output = output.masked_fill(padding_mask, 0) + assert_close(output, ref_output, **tols) + output.mean().backward() + ref_output.mean().backward() + assert_close(q.grad, q_flash.grad, **tols) + assert_close(k.grad, k_flash.grad, **tols) + assert_close(v.grad, v_flash.grad, **tols) + + +@clear_cache_before_run() +@parameterize("dtype", DTYPE) +def test_flash_attn_func(dtype: torch.dtype): + torch.backends.cudnn.deterministic = True + set_seed(0) + # (func, name, need_postprocess) + avail_attn_funcs = [(ColoAttention.attention, "coloattn", False)] + avail_custom_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)] + avail_padding_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)] + for ext_cls in FlashAttentionLoader.REGISTRY: + ext = ext_cls() + if ext.is_available(): + ext.assert_compatible() + avail_attn_funcs.append((ext.load(), ext.name, True)) + for ext_cls in FlashAttentionWithCustomMaskLoader.REGISTRY: + ext = ext_cls() + if ext.is_available(): + ext.assert_compatible() + avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True)) + for ext_cls in FlashAttentionWithPaddingMaskLoader.REGISTRY: + ext = ext_cls() + if ext.is_available(): + ext.assert_compatible() + avail_padding_mask_attn_funcs.append((ext.load(), ext.name, True)) + + test_sets = { + "none": (lambda dtype: ({}, None), avail_attn_funcs), + "padded": (gen_padded_kwargs, avail_padding_mask_attn_funcs), + "padded_causal": (gen_padded_causal_kwargs, avail_padding_mask_attn_funcs), + "causal": (gen_causal_kwargs, avail_attn_funcs), + "custom": (gen_custom_kwargs, avail_custom_mask_attn_funcs), + } + + for mask_type, (gen_kwargs_func, attn_funcs) in test_sets.items(): + attn_kwargs, padding_mask = gen_kwargs_func(dtype) + for attn_func, name, need_postprocess in attn_funcs: + print(f"{dtype}, {name}, {mask_type}") + if need_postprocess: + check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask) + else: + check_attn_func(dtype, attn_func, attn_kwargs, padding_mask) + + +if __name__ == "__main__": + test_flash_attn_func() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 62d4d1bf3..85be9a242 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -31,6 +31,7 @@ def build_model( enable_jit_fused=False, enable_sequence_parallelism=False, use_lazy_init: bool = False, + dtype=torch.float32, ): # create new model ctx = LazyInitContext() if use_lazy_init else nullcontext() @@ -51,7 +52,7 @@ def build_model( model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) - return org_model.cuda(), sharded_model.cuda() + return org_model.cuda().to(dtype), sharded_model.cuda().to(dtype) def build_pipeline_model( @@ -132,7 +133,14 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c booster = Booster(plugin=plugin) sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion) - return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster + return ( + org_model, + org_optimizer, + sharded_model, + sharded_optimizer, + criterion, + booster, + ) def run_forward_backward_with_hybrid_plugin( @@ -173,7 +181,12 @@ def run_forward_backward_with_hybrid_plugin( data_iter = iter([data]) sharded_output = booster.execute_pipeline( - data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=True + data_iter, + sharded_model, + _criterion, + sharded_optimizer, + return_loss=True, + return_outputs=True, ) sharded_loss = sharded_output["loss"] else: @@ -313,7 +326,9 @@ def check_grad( def unwrap_model( - module: Module, base_model_class_name: Optional[str] = None, base_model_attribute_name: Optional[str] = None + module: Module, + base_model_class_name: Optional[str] = None, + base_model_attribute_name: Optional[str] = None, ): if isinstance(module, HybridParallelModule): module = module.unwrap() diff --git a/tests/test_shardformer/test_model/test_shard_blip2.py b/tests/test_shardformer/test_model/test_shard_blip2.py index 02c15460e..2c56b0435 100644 --- a/tests/test_shardformer/test_model/test_shard_blip2.py +++ b/tests/test_shardformer/test_model/test_shard_blip2.py @@ -45,19 +45,51 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo "qformer.encoder.layer[0].attention.output.dense", "language_model.model.decoder.layers[0].self_attn.out_proj", ] - check_grad(blip2, sharded_blip2, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) - check_grad(blip2, sharded_blip2, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) + check_grad( + blip2, + sharded_blip2, + col_layer_for_check, + atol=1e-6, + rtol=1e-5, + dim=0, + verbose=False, + ) + check_grad( + blip2, + sharded_blip2, + row_layer_for_check, + atol=1e-6, + rtol=1e-5, + dim=1, + verbose=False, + ) @parameterize("enable_fused_normalization", [True, False]) @parameterize("enable_tensor_parallelism", [True, False]) @parameterize("enable_flash_attention", [True, False]) @parameterize("enable_jit_fused", [True, False]) -def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): +def run_blip2_test( + enable_fused_normalization, + enable_tensor_parallelism, + enable_flash_attention, + enable_jit_fused, +): sub_model_zoo = model_zoo.get_sub_registry("transformers_blip2") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): org_model, sharded_model = build_model( - model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused + model_fn, + enable_fused_normalization, + enable_tensor_parallelism, + enable_flash_attention, + enable_jit_fused, + dtype=torch.float, ) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) @@ -66,7 +98,14 @@ def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable def check_blip2(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_blip2_test() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 29d3592bf..78d752b69 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -11,7 +11,6 @@ from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, check_all_grad_tensors, check_loss, - check_output_hidden_state, check_weight, get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, @@ -25,7 +24,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster, ) stage_manager = booster.plugin.stage_manager @@ -36,7 +41,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, shard_chatglm_model = unwrap_model(sharded_model, "ChatGLMModel", "transformer") norm_layer_for_check = ["encoder.layers[0].input_layernorm"] - row_layer_for_check = ["encoder.layers[0].self_attention.query_key_value", "embedding.word_embeddings"] + row_layer_for_check = [ + "encoder.layers[0].self_attention.query_key_value", + "embedding.word_embeddings", + ] col_layer_for_check = ["encoder.layers[0].self_attention.dense"] # Save gradient tensors for comparison between the original model and the sharded model. @@ -94,8 +102,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == "ChatGLMModel": - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1) + # TODO: ChatGLMModel output is [S, B, H], merging batch of pipeline is wrong + # if org_model.__class__.__name__ == "ChatGLMModel": + # check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) @@ -143,8 +152,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": False, "precision": "fp32", }, - {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, - {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, { "tp_size": 2, "pp_size": 1, @@ -159,7 +180,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, def run_chatglm_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -193,7 +220,13 @@ def run_chatglm_test(test_config): def run_chatglm_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -202,13 +235,27 @@ def run_chatglm_3d_test(test_config): def check_chatglm(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_chatglm_test() def check_chatglm_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_chatglm_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 3155420f1..d59d7e4ad 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -25,7 +25,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster, ) stage_manager = booster.plugin.stage_manager @@ -47,10 +53,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 col_layer_grads = get_grad_tensors_for_check( - gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + gpt2, + sharded_gpt2, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, ) row_layer_grads = get_grad_tensors_for_check( - gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + gpt2, + sharded_gpt2, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, ) norm_layer_grads = get_grad_tensors_for_check( @@ -90,7 +110,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + check_weight( + gpt2, + sharded_gpt2, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) # check grads check_all_grad_tensors(grads_to_check) @@ -123,14 +152,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 4, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, { "tp_size": 2, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, @@ -138,7 +167,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": True, "precision": "fp32", }, @@ -167,7 +196,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, def run_gpt2_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -202,7 +237,13 @@ def run_gpt2_test(test_config): def run_gpt2_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -211,13 +252,27 @@ def run_gpt2_3d_test(test_config): def check_gpt2(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_gpt2_test() def check_gpt2_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_gpt2_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_gptj.py b/tests/test_shardformer/test_model/test_shard_gptj.py index c83eaaa09..009202a0d 100644 --- a/tests/test_shardformer/test_model/test_shard_gptj.py +++ b/tests/test_shardformer/test_model/test_shard_gptj.py @@ -25,7 +25,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster, ) stage_manager = booster.plugin.stage_manager @@ -46,11 +52,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 col_layer_grads = get_grad_tensors_for_check( - gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + gptj, + sharded_gptj, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, ) row_layer_grads = get_grad_tensors_for_check( - gptj, sharded_gptj, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + gptj, + sharded_gptj, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -77,7 +97,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight(gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) + check_weight( + gptj, + sharded_gptj, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, + ) # check grads check_all_grad_tensors(grads_to_check) @@ -110,14 +139,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 4, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, { "tp_size": 2, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, @@ -125,7 +154,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": True, + "enable_all_optimization": False, #'use_lazy_init': True, "precision": "fp32", }, @@ -154,7 +183,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, def run_gptj_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -189,7 +224,13 @@ def run_gptj_test(test_config): def run_gptj_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -198,15 +239,30 @@ def run_gptj_3d_test(test_config): def check_gptj(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_gptj_test() def check_gptj_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_gptj_3d_test() + @pytest.mark.skip("TODO check_gptj has something wrong.") @pytest.mark.dist @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index c7edcfb35..126ff23a9 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -112,7 +112,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 4, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, @@ -124,7 +124,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": False, "precision": "fp32", }, - {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32"}, { "tp_size": 2, "pp_size": 1, diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index d21ab264d..523ed879b 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -29,7 +29,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster, ) stage_manager = booster.plugin.stage_manager @@ -39,7 +45,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, opt_model = unwrap_model(org_model, "OPTModel", "model") shard_opt_model = unwrap_model(sharded_model, "OPTModel", "model") - row_layer_for_check = ["decoder.layers[0].self_attn.q_proj", "decoder.embed_tokens"] # 'decoder.embed_tokens' + row_layer_for_check = [ + "decoder.layers[0].self_attn.q_proj", + "decoder.embed_tokens", + ] # 'decoder.embed_tokens' col_layer_for_check = ["decoder.layers[0].self_attn.out_proj"] # Save gradient tensors for comparison between the original model and the sharded model. @@ -50,10 +59,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 4e-2, 4e-2 row_layer_grads = get_grad_tensors_for_check( - opt_model, shard_opt_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + opt_model, + shard_opt_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, ) col_layer_grads = get_grad_tensors_for_check( - opt_model, shard_opt_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + opt_model, + shard_opt_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -80,7 +103,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 check_weight( - opt_model, shard_opt_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + opt_model, + shard_opt_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, ) # check grads @@ -110,8 +140,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": False, "precision": "fp32", }, - {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, - {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, { "tp_size": 2, "pp_size": 1, @@ -135,7 +177,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) def run_opt_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_opt") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -169,7 +217,13 @@ def run_opt_test(test_config): def run_opt_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_opt") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -178,13 +232,27 @@ def run_opt_3d_test(test_config): def check_OPTModel(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_opt_test() def check_opt_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_opt_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 22c201458..9b22d54d7 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -25,7 +25,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster, ) stage_manager = booster.plugin.stage_manager @@ -71,7 +77,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) + check_weight( + t5, + sharded_t5, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, + ) # check grads check_all_grad_tensors(grads_to_check) @@ -104,7 +119,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 4, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, @@ -117,7 +132,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": False, "precision": "fp32", }, - {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, { "tp_size": 2, "pp_size": 1, @@ -144,7 +158,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, def run_t5_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_t5") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): # skip 4-stage pp test for t5_encoder if test_config["pp_size"] > 2 and name == "transformers_t5_encoder_model": continue @@ -185,7 +205,13 @@ def run_t5_test(test_config): def run_t5_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_t5") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -194,13 +220,27 @@ def run_t5_3d_test(test_config): def check_t5(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_t5_test() def check_t5_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_t5_3d_test() diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py deleted file mode 100644 index 3ec170004..000000000 --- a/tests/test_utils/test_flash_attention.py +++ /dev/null @@ -1,167 +0,0 @@ -import math - -import pytest -import torch -from einops import rearrange - -from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN -from colossalai.testing import clear_cache_before_run, parameterize - -if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN: - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - -DTYPE = [torch.float16, torch.bfloat16, torch.float32] - - -def attention_ref(q, k, v, attn_mask=None, causal=False): - """ - attention output of the control group - """ - dtype_og = q.dtype - seqlen_q, seqlen_k = q.shape[1], k.shape[1] - d = q.shape[-1] - scale = 1.0 / math.sqrt(d) - scores = torch.einsum("bthd,bshd->bhts", q * scale, k) - - if attn_mask is not None: - scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf")) - if causal: - causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1) - scores.masked_fill_(causal_mask, float("-inf")) - attention = torch.softmax(scores, dim=-1) - - output = torch.einsum("bhts,bshd->bthd", attention, v) - output = rearrange(output, "b s h d -> b s (h d)") - - # Modify the data at the positions of the mask to 0 - if attn_mask is not None: - output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1"), 0.0) - - return output.to(dtype=dtype_og) - - -@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") -@clear_cache_before_run() -@parameterize("proj_shape", [(6, 8, 4, 16)]) -@parameterize("dtype", DTYPE) -@parameterize("dropout", [0.0]) -def test_attention_gpt(proj_shape, dtype, dropout): - (B, S, H, D_HEAD) = proj_shape - D = H * D_HEAD - - q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - - mask = [torch.ones(S - i, dtype=torch.bool, device="cuda") for i in range(B)] - mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True) - - attn = ColoAttention(D, H, dropout=dropout) - y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal) - - assert list(y.shape) == [B, S, D] - - out_ref = attention_ref(q, k, v, mask, causal=True) - - # check gradients - dy = torch.rand_like(y) - grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) - grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) - - torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}" - torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" - torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" - torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" - - -@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") -@clear_cache_before_run() -@parameterize("proj_shape", [(6, 8, 4, 16)]) -@parameterize("dtype", DTYPE) -@parameterize("dropout", [0.0]) -def test_attention_bert(proj_shape, dtype, dropout): - (B, S, H, D_HEAD) = proj_shape - D = H * D_HEAD - - q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - - # attention mask of shape [B, S] with zero padding to max length S - mask = torch.randint(0, 2, (B, S), dtype=torch.bool, device="cuda") - - attn = ColoAttention(D, H, dropout=dropout) - y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding) - - assert list(y.shape) == [B, S, D] - - out_ref = attention_ref(q, k, v, mask, causal=False) - - dy = torch.rand_like(y) - grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) - grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) - - torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}" - torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" - torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" - torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" - - -@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") -@clear_cache_before_run() -@parameterize("proj_shape", [(6, 8, 4, 16)]) -@parameterize("dtype", DTYPE) -@parameterize("dropout", [0.0]) -def test_attention_no_mask(proj_shape, dtype, dropout): - (B, S, H, D_HEAD) = proj_shape - D = H * D_HEAD - - q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - - attn = ColoAttention(D, H, dropout=dropout) - y = attn(q, k, v) - - assert list(y.shape) == [B, S, D] - - out_ref = attention_ref(q, k, v, None, causal=False) - - dy = torch.rand_like(y) - grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) - grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) - - torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}" - torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" - torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" - torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" - - -@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") -@clear_cache_before_run() -@parameterize("proj_shape", [(6, 24, 8, 4, 16)]) -@parameterize("dtype", DTYPE) -@parameterize("dropout", [0.0]) -def test_cross_attention(proj_shape, dtype, dropout): - (B, S, T, H, D_HEAD) = proj_shape - D = H * D_HEAD - - q = torch.randn((B, T, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - - attn = ColoAttention(D, H, dropout=dropout) - y = attn(q, k, v, attn_mask_type=AttnMaskType.causal) - - assert list(y.shape) == [B, T, D] - - out_ref = attention_ref(q, k, v, None, causal=True) - - dy = torch.rand_like(y) - grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) - grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) - - torch.allclose(y, out_ref, atol=1e-18), f"{(y - out_ref).abs().max()}" - torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" - torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" - torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"