From d6df19bae7cdb9e116c1f218a4465855623c80b1 Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Date: Thu, 30 Nov 2023 14:21:30 +0800
Subject: [PATCH] [npu] support triangle attention for llama (#5130)

* update fused attn

* update spda

* tri attn

* update triangle

* import

* fix

* fix
---
 .../kernel/cuda_native/mha/flash_attn_2.py |   1 -
 colossalai/kernel/cuda_native/mha/mha.py   |   1 +
 colossalai/kernel/npu/__init__.py          |   3 +
 colossalai/kernel/npu/mha/__init__.py      |   3 +
 colossalai/kernel/npu/mha/mha.py           |  80 ++++++++++++
 colossalai/kernel/npu/mha/sdpa_attn.py     |  41 +++++++
 colossalai/kernel/npu/mha/triangle_attn.py | 115 ++++++++++++++++++
 colossalai/shardformer/layer/utils.py      |  18 +++
 colossalai/shardformer/modeling/llama.py   |   5 +-
 9 files changed, 264 insertions(+), 3 deletions(-)
 create mode 100644 colossalai/kernel/npu/__init__.py
 create mode 100644 colossalai/kernel/npu/mha/__init__.py
 create mode 100644 colossalai/kernel/npu/mha/mha.py
 create mode 100644 colossalai/kernel/npu/mha/sdpa_attn.py
 create mode 100644 colossalai/kernel/npu/mha/triangle_attn.py

diff --git a/colossalai/kernel/cuda_native/mha/flash_attn_2.py b/colossalai/kernel/cuda_native/mha/flash_attn_2.py
index 9ee83915b..de2ccaa49 100644
--- a/colossalai/kernel/cuda_native/mha/flash_attn_2.py
+++ b/colossalai/kernel/cuda_native/mha/flash_attn_2.py
@@ -29,7 +29,6 @@ except ImportError:
     HAS_FLASH_ATTN = False
 
 if HAS_FLASH_ATTN:
-    pass
 
     from .utils import SeqLenInfo
 
diff --git a/colossalai/kernel/cuda_native/mha/mha.py b/colossalai/kernel/cuda_native/mha/mha.py
index 1c778439d..b56d37cf0 100644
--- a/colossalai/kernel/cuda_native/mha/mha.py
+++ b/colossalai/kernel/cuda_native/mha/mha.py
@@ -44,6 +44,7 @@ class ColoAttention(torch.nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
+        origin_attn_mask: Optional[torch.Tensor] = None,
         attn_mask_type: Optional[AttnMaskType] = None,
         bias: Optional[torch.Tensor] = None,
     ):
diff --git a/colossalai/kernel/npu/__init__.py b/colossalai/kernel/npu/__init__.py
new file mode 100644
index 000000000..6a02c7055
--- /dev/null
+++ b/colossalai/kernel/npu/__init__.py
@@ -0,0 +1,3 @@
+from .mha import NPUColoAttention
+
+__all__ = ["NPUColoAttention"]
diff --git a/colossalai/kernel/npu/mha/__init__.py b/colossalai/kernel/npu/mha/__init__.py
new file mode 100644
index 000000000..6a02c7055
--- /dev/null
+++ b/colossalai/kernel/npu/mha/__init__.py
@@ -0,0 +1,3 @@
+from .mha import NPUColoAttention
+
+__all__ = ["NPUColoAttention"]
diff --git a/colossalai/kernel/npu/mha/mha.py b/colossalai/kernel/npu/mha/mha.py
new file mode 100644
index 000000000..ac982384e
--- /dev/null
+++ b/colossalai/kernel/npu/mha/mha.py
@@ -0,0 +1,80 @@
+import math
+from typing import Optional
+
+import torch
+
+from .sdpa_attn import npu_sdpa_attention
+from .triangle_attn import HAS_NPU_TRIANGLE_ATTENTION
+
+
+class NPUColoAttention(torch.nn.Module):
+    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: float = None):
+        super().__init__()
+
+        try:
+            import torch_npu  # noqa
+        except ImportError:
+            raise Exception("torch_npu is not installed.")
+
+        assert (
+            embed_dim % num_heads == 0
+        ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
+        if scale is not None:
+            self.scale = scale
+        else:
+            self.scale = 1 / math.sqrt(embed_dim // num_heads)
+        self.dropout = dropout
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        origin_attn_mask: Optional[torch.Tensor] = None,
+        attn_mask_type: int = None,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        """
+        Scaled dot-product attention with softmax, dispatched to the NPU kernels.
+
+        Arguments:
+            query: (batch, q_seqlen, nheads, headdim)
+            key: (batch, kv_seqlen, nheads, headdim)
+            value: (batch, kv_seqlen, nheads, headdim)
+            attn_mask: optional mask prepared for the attention kernel.
+            origin_attn_mask: optional unexpanded attention mask, consumed by the NPU kernels.
+            attn_mask_type: mask type; a causal type (enum value > 1) enables causal attention.
+            bias: attention bias, must be None (not supported by the NPU kernels).
+        Dropout and scaling are taken from self.dropout and self.scale.
+        Return:
+            attn_out: (batch, q_seqlen, nheads * headdim).
+        """
+        assert (
+            len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4
+        ), f"query, key, value should be 4D tensors, but got {query.shape}, {key.shape}, {value.shape}"
+        assert (
+            query.device.type == "npu" and key.device.type == "npu" and value.device.type == "npu"
+        ), f"query, key, value should be on npu device, but got {query.device}, {key.device}, {value.device}"
+        assert bias is None, "bias is not supported in npu colo attention"
+
+        causal = attn_mask_type is not None and attn_mask_type.value > 1
+
+        if HAS_NPU_TRIANGLE_ATTENTION:
+            from .triangle_attn import npu_triangle_attention
+
+            attn_fn = npu_triangle_attention
+        else:
+            attn_fn = npu_sdpa_attention
+
+        out = attn_fn(
+            query,
+            key,
+            value,
+            attn_mask=attn_mask,
+            origin_attn_mask=origin_attn_mask,
+            dropout_p=self.dropout,
+            scale=self.scale,
+            is_causal=causal,
+        )
+        return out
diff --git a/colossalai/kernel/npu/mha/sdpa_attn.py b/colossalai/kernel/npu/mha/sdpa_attn.py
new file mode 100644
index 000000000..2af1dbae2
--- /dev/null
+++ b/colossalai/kernel/npu/mha/sdpa_attn.py
@@ -0,0 +1,41 @@
+import torch
+from einops import rearrange
+
+
+def npu_sdpa_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    attn_mask: torch.Tensor = None,
+    origin_attn_mask: torch.Tensor = None,
+    scale: float = 1.0,
+    dropout_p: float = 0.0,
+    is_causal: bool = True,
+):
+    """
+    The scaled dot product attention.
+
+    Arguments:
+        q: (batch, q_seqlen, nheads, headdim)
+        k: (batch, kv_seqlen, nheads, headdim)
+        v: (batch, kv_seqlen, nheads, headdim)
+        attn_mask: unused; kept for interface compatibility.
+        origin_attn_mask: optional mask passed to scaled_dot_product_attention;
+            if None, causal masking is applied instead.
+        dropout_p: float. Dropout probability.
+        scale: float. The scaling of QK^T before applying softmax.
+    Return:
+        attn_out: (batch, q_seqlen, nheads * headdim).
+    """
+    q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]
+    output = torch.nn.functional.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=origin_attn_mask,
+        dropout_p=dropout_p,
+        is_causal=origin_attn_mask is None,
+        scale=scale,
+    )
+    output = rearrange(output, "b h s d -> b s (h d)")
+    return output
diff --git a/colossalai/kernel/npu/mha/triangle_attn.py b/colossalai/kernel/npu/mha/triangle_attn.py
new file mode 100644
index 000000000..619076d5f
--- /dev/null
+++ b/colossalai/kernel/npu/mha/triangle_attn.py
@@ -0,0 +1,115 @@
+# coding=utf-8
+# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import torch
+from einops import rearrange
+
+HAS_NPU_TRIANGLE_ATTENTION = False
+try:
+    from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax
+
+    HAS_NPU_TRIANGLE_ATTENTION = True
+except ImportError:
+    logging.warning("Failed to import the triangle attention kernels from torch_npu.")
+
+
+if HAS_NPU_TRIANGLE_ATTENTION:
+
+    def npu_triangle_attention(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attn_mask: torch.Tensor = None,
+        origin_attn_mask: torch.Tensor = None,
+        scale: float = 1.0,
+        dropout_p: float = 0.0,
+        is_causal: bool = True,
+        block_size=512,
+    ):
+        """
+        Triangle attention reduces computation on the masked (upper-triangular)
+        part of causal attention by splitting q, k and v into blocks.
+
+        Arguments:
+            block_size: size of each triangle block, 512 by default. A smaller
+                block_size skips more of the masked region and saves more
+                computation, but launches a larger number of small operators.
+            scale: scaling applied to QK^T before the softmax.
+            dropout_p: attention dropout probability.
+        """
+
+        def compute_attn(q_layer, k_layer, v_layer, mask_tmp):
+            # [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size]
+            cur_sim = torch.matmul(q_layer, k_layer)
+            attention_probs = npu_scaled_masked_softmax(cur_sim, mask_tmp, scale)
+            # attention dropout
+            if dropout_p > 0:
+                attention_probs = torch.nn.functional.dropout(
+                    attention_probs, p=dropout_p, training=attention_probs.requires_grad
+                )
+            # [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd]
+            context_layer_tmp = torch.matmul(attention_probs, v_layer)
+            return context_layer_tmp
+
+        q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)]
+        origin_attn_mask = origin_attn_mask.to(torch.bool)
+        # input shape: [b, hn, sq, hd]
+        bsz, head_num, sequence_len, head_dim = k.shape
+        sparse_groups = sequence_len // block_size
+        # Determine whether the sequence length is evenly divisible by the block size
+        divisible_flag = sequence_len == block_size * sparse_groups
+        k = k.transpose(2, 3).contiguous()
+        if divisible_flag:
+            q_tmp_layers = torch.chunk(q, sparse_groups, 2)
+            k_tmp_layers = torch.chunk(k, sparse_groups, 3)
+            v_tmp_layers = torch.chunk(v, sparse_groups, 2)
+        else:
+            seq_tmp = block_size * sparse_groups
+            q_last = q[:, :, seq_tmp:, :].contiguous()
+            mask_last = origin_attn_mask[:, :, seq_tmp:, :].contiguous()
+            q_tmp_layers = torch.chunk(q[:, :, :seq_tmp, :], sparse_groups, 2)
+            k_tmp_layers = torch.chunk(k[:, :, :, :seq_tmp], sparse_groups, 3)
+            v_tmp_layers = torch.chunk(v[:, :, :seq_tmp, :], sparse_groups, 2)
+        context_list_tmp, k_tmp, v_tmp = [], (), ()
+        for i in range(sparse_groups):
+            # compute slice shape of q k v for each loop
+            q_begin, q_end = i * block_size, (i + 1) * block_size
+            kv_begin, kv_end = 0, (i + 1) * block_size
+            q_tmp = q_tmp_layers[i]
+            # slice k and v
+            if i == 0:
+                k_tmp = k_tmp_layers[i].contiguous()
+                v_tmp = v_tmp_layers[i].contiguous()
+            else:
+                k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous()
+                v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous()
+
+            mask_tmp = origin_attn_mask[:, :, q_begin:q_end, kv_begin:kv_end].contiguous()
+            context_layer_tmp = compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp)
+            context_list_tmp.append(context_layer_tmp)
+
+        if not divisible_flag:
+            # handle the tail block left over when the sequence length is not divisible by block_size
+            context_layer_tmp = compute_attn(q_last, k, v, mask_last)
+            context_list_tmp.append(context_layer_tmp)
+        context_layer = torch.cat(context_list_tmp, 2)
+        new_context_layer_shape = (bsz, sequence_len, head_num * head_dim)
+        context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True)
+        # =========================
+        # Context layer. [b, sq, hp]
+        # =========================
+        return context_layer
diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py
index 4b6343adc..55683b227 100644
--- a/colossalai/shardformer/layer/utils.py
+++ b/colossalai/shardformer/layer/utils.py
@@ -280,3 +280,21 @@ def create_randomizer_with_offset(
     Randomizer.increment_index()
 
     return Randomizer(seed=base_seed)
+
+
+def get_attention_kernel():
+    """
+    Get the attention kernel based on the device type.
+    """
+    from colossalai.kernel.cuda_native import AttnMaskType
+
+    if torch.cuda.is_available():
+        from colossalai.kernel.cuda_native import ColoAttention as AttentionKernel
+    else:
+        try:
+            torch.npu.is_available()  # raises AttributeError when NPU support is absent
+            from colossalai.kernel.npu import NPUColoAttention as AttentionKernel
+        except (ImportError, AttributeError):
+            raise Exception("No available device for attention kernel!")
+
+    return AttnMaskType, AttentionKernel
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 616c9220f..c3de197c4 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -12,6 +12,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForS
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer.utils import get_attention_kernel
 
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -404,7 +405,7 @@ class LlamaPipelineForwards:
 def get_llama_flash_attention_forward():
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
 
-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    AttnMaskType, ColoAttention = get_attention_kernel()
 
     llama_version = 2
     try:
@@ -468,7 +469,7 @@ def get_llama_flash_attention_forward():
         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
 
         attn_output = attention(
-            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
+            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask,
        )
 
         attn_output = self.o_proj(attn_output)
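For reference, a minimal sketch of how the get_attention_kernel() dispatch added in colossalai/shardformer/layer/utils.py is meant to be consumed, mirroring the llama.py change above. This is illustrative only: the tensor shapes, the float16 dtype, the mask construction and the AttnMaskType.causal member are assumptions based on the existing CUDA ColoAttention interface, and running it requires either a CUDA GPU or an Ascend NPU with torch_npu installed.

    import torch

    from colossalai.shardformer.layer.utils import get_attention_kernel

    # Picks the CUDA ColoAttention or the NPUColoAttention depending on the device.
    AttnMaskType, ColoAttention = get_attention_kernel()

    batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
    device = "cuda" if torch.cuda.is_available() else "npu"

    attention = ColoAttention(embed_dim=num_heads * head_dim, num_heads=num_heads)

    # [batch, seq_len, num_heads, head_dim], as documented in npu/mha/mha.py
    q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Additive causal mask (0 for visible positions, -inf for future positions),
    # matching the HF-style attention_mask that llama.py forwards as origin_attn_mask.
    causal_mask = torch.full((seq_len, seq_len), float("-inf"), dtype=torch.float16, device=device)
    causal_mask = torch.triu(causal_mask, diagonal=1)[None, None].expand(batch, 1, seq_len, seq_len)

    out = attention(
        q, k, v,
        attn_mask=None,
        attn_mask_type=AttnMaskType.causal,
        origin_attn_mask=causal_mask,
    )
    print(out.shape)  # [batch, seq_len, num_heads * head_dim]

On a CUDA device origin_attn_mask is simply accepted and ignored by ColoAttention (it was added to the signature in this patch), while on an NPU it is the mask actually consumed by the SDPA and triangle kernels.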
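The blocking scheme that npu_triangle_attention implements with the fused torch_npu ops can also be written with plain PyTorch, which may help as a device-agnostic reference when checking the NPU path. A minimal sketch, assuming the sequence length is a multiple of block_size (the tail-block handling of the patch is omitted) and purely causal masking:

    import torch


    def blockwise_causal_attention(q, k, v, block_size=512, scale=None):
        """Reference for the blocked causal attention used by npu_triangle_attention.

        q, k, v: [batch, heads, seq, head_dim]; query block i only attends to
        kv blocks 0..i, so fully masked (future) blocks are never computed.
        """
        bsz, heads, seq, head_dim = q.shape
        assert seq % block_size == 0, "tail-block handling omitted in this sketch"
        scale = head_dim**-0.5 if scale is None else scale
        outputs = []
        for i in range(seq // block_size):
            q_blk = q[:, :, i * block_size : (i + 1) * block_size]  # [b, h, B, d]
            k_ctx = k[:, :, : (i + 1) * block_size]                 # [b, h, (i+1)B, d]
            v_ctx = v[:, :, : (i + 1) * block_size]
            scores = torch.matmul(q_blk, k_ctx.transpose(-1, -2)) * scale
            # the causal mask only bites inside the diagonal block of each slice
            q_pos = torch.arange(i * block_size, (i + 1) * block_size, device=q.device)
            k_pos = torch.arange((i + 1) * block_size, device=q.device)
            scores = scores.masked_fill(k_pos[None, :] > q_pos[:, None], float("-inf"))
            outputs.append(torch.matmul(scores.softmax(dim=-1), v_ctx))
        return torch.cat(outputs, dim=2)                            # [b, h, seq, d]

On the [batch, seq, heads, head_dim] inputs accepted by NPUColoAttention this corresponds to rearranging to [batch, heads, seq, head_dim] first, exactly as npu_triangle_attention does before chunking.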