[npu] support triangle attention for llama (#5130)

* update fused attn

* update sdpa

* tri attn

* update triangle

* import

* fix

* fix
Xuanlei Zhao 2023-11-30 14:21:30 +08:00 committed by GitHub
parent f4e72c9992
commit d6df19bae7
9 changed files with 264 additions and 3 deletions


@ -29,7 +29,6 @@ except ImportError:
HAS_FLASH_ATTN = False
if HAS_FLASH_ATTN:
pass
from .utils import SeqLenInfo


@ -44,6 +44,7 @@ class ColoAttention(torch.nn.Module):
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
origin_attn_mask: Optional[torch.Tensor] = None,
attn_mask_type: Optional[AttnMaskType] = None,
bias: Optional[torch.Tensor] = None,
):


@ -0,0 +1,3 @@
from .mha import NPUColoAttention
__all__ = ["NPUColoAttention"]


@ -0,0 +1,3 @@
from .mha import NPUColoAttention
__all__ = ["NPUColoAttention"]


@ -0,0 +1,80 @@
import math
from typing import Optional
import torch
from .sdpa_attn import npu_sdpa_attention
from .triangle_attn import HAS_NPU_TRIANGLE_ATTENTION
class NPUColoAttention(torch.nn.Module):
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: float = None):
super().__init__()
try:
import torch_npu # noqa
except ImportError:
raise Exception("torch_npu is not installed.")
assert (
embed_dim % num_heads == 0
), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
if scale is not None:
self.scale = scale
else:
self.scale = 1 / math.sqrt(embed_dim // num_heads)
self.dropout = dropout
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
origin_attn_mask: Optional[torch.Tensor] = None,
attn_mask_type: int = None,
bias: Optional[torch.Tensor] = None,
):
"""
Implement the scaled dot product attention with softmax.
Arguments:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
scale: float. The scaling of QK^T before applying softmax.
Default to 1.
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
assert (
len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4
), f"query, key, value should be 4D tensors, but got {query.shape}, {key.shape}, {value.shape}"
assert (
query.device.type == "npu" and key.device.type == "npu" and value.device.type == "npu"
), f"query, key, value should be on npu device, but got {query.device}, {key.device}, {value.device}"
assert bias is None, "bias is not supported in npu colo attention"
causal = attn_mask_type is not None and attn_mask_type.value > 1
if HAS_NPU_TRIANGLE_ATTENTION:
from .triangle_attn import npu_triangle_attention
attn_fn = npu_triangle_attention
else:
attn_fn = npu_sdpa_attention
out = attn_fn(
query,
key,
value,
attn_mask=attn_mask,
origin_attn_mask=origin_attn_mask,
dropout_p=self.dropout,
scale=self.scale,
is_causal=causal,
)
return out
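
For context, here is a minimal usage sketch of the new module. It assumes an Ascend device with torch_npu installed; the shapes, dtype, and the convention that True marks masked-out positions in origin_attn_mask are illustrative assumptions, not part of the patch.

import torch
from colossalai.kernel.npu import NPUColoAttention

batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
attn = NPUColoAttention(embed_dim=num_heads * head_dim, num_heads=num_heads)

# inputs follow the (batch, seqlen, nheads, headdim) layout expected by forward()
q = torch.randn(batch, seq_len, num_heads, head_dim, device="npu", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# boolean causal mask; True marks positions to be masked out (assumed kernel convention)
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device="npu"), diagonal=1)
causal_mask = causal_mask[None, None].expand(batch, 1, seq_len, seq_len)

out = attn(q, k, v, origin_attn_mask=causal_mask)  # (batch, seq_len, num_heads * head_dim)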


@ -0,0 +1,41 @@
import torch
from einops import rearrange
def npu_sdpa_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_mask: torch.Tensor = None,
origin_attn_mask: torch.Tensor = None,
scale: float = 1.0,
dropout_p: float = 0.0,
is_causal: bool = True,
):
"""
The scaled dot product attention.
Arguments:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
scale: float. The scaling of QK^T before applying softmax.
Default to 1.
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]
output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=origin_attn_mask,
dropout_p=dropout_p,
is_causal=origin_attn_mask is None,
scale=scale,
)
output = rearrange(output, "b h s d -> b s (h d)")
return output
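
The layout handling above can be sanity-checked on CPU with stock PyTorch. This is a sketch, not part of the patch, and assumes PyTorch >= 2.1, where scaled_dot_product_attention accepts a scale argument.

import torch
from einops import rearrange

def sdpa_reference(q, k, v, scale):
    # mirror npu_sdpa_attention: move heads in front, run causal SDPA, merge heads back
    q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)]
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)
    return rearrange(out, "b h s d -> b s (h d)")

q = torch.randn(2, 16, 4, 8)
print(sdpa_reference(q, q, q, scale=8 ** -0.5).shape)  # torch.Size([2, 16, 32])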


@ -0,0 +1,115 @@
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import torch
from einops import rearrange
HAS_NPU_TRIANGLE_ATTENTION = False
try:
from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax
HAS_NPU_TRIANGLE_ATTENTION = True
except ImportError:
logging.warning("Import torch_npu Error.")
if HAS_NPU_TRIANGLE_ATTENTION:
def npu_triangle_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_mask: torch.Tensor = None,
origin_attn_mask: torch.Tensor = None,
scale: float = 1.0,
dropout_p: float = 0.0,
is_causal: bool = True,
block_size=512,
):
"""
The triangle attention reduces the attention calculation of the mask
part by dividing the q, k, and v matrices into blocks
Arguments:
block_size: The size of the inverted triangle block, the default is 512,
the smaller the block_size, the more calculations will be reduced,
but the number of small operators will be increased
masked_softmax_func: mask function to be applied.
dropout_func: dropout function to be applied.
"""
def compute_attn(q_layer, k_layer, v_layer, mask_tmp):
# [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size]
cur_sim = torch.matmul(q_layer, k_layer)
attention_probs = npu_scaled_masked_softmax(cur_sim, mask_tmp)
# attention dropout
if dropout_p > 0:
attention_probs = torch.nn.functional.dropout(
attention_probs, p=dropout_p, training=attention_probs.requires_grad
)
# [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd]
context_layer_tmp = torch.matmul(attention_probs, v_layer)
return context_layer_tmp
q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)]
origin_attn_mask = origin_attn_mask.to(torch.bool)
# input shape: [b, hn, sq, hd]
bsz, head_num, sequence_len, head_dim = k.shape
sparse_groups = sequence_len // block_size
# determine whether sequence_len is evenly divisible by block_size
divisible_flag = sequence_len == block_size * sparse_groups
k = k.transpose(2, 3).contiguous()
if divisible_flag:
q_tmp_layers = torch.chunk(q, sparse_groups, 2)
k_tmp_layers = torch.chunk(k, sparse_groups, 3)
v_tmp_layers = torch.chunk(v, sparse_groups, 2)
else:
seq_tmp = block_size * sparse_groups
q_last = q[:, :, seq_tmp:, :].contiguous()
mask_last = origin_attn_mask[:, :, seq_tmp:, :].contiguous()
q_tmp_layers = torch.chunk(q[:, :, :seq_tmp, :], sparse_groups, 2)
k_tmp_layers = torch.chunk(k[:, :, :, :seq_tmp], sparse_groups, 3)
v_tmp_layers = torch.chunk(v[:, :, :seq_tmp, :], sparse_groups, 2)
context_list_tmp, k_tmp, v_tmp = [], (), ()
for i in range(sparse_groups):
# compute slice shape of q k v for each loop
q_begin, q_end = i * block_size, (i + 1) * block_size
kv_begin, kv_end = 0, (i + 1) * block_size
q_tmp = q_tmp_layers[i]
# slice k and v
if i == 0:
k_tmp = k_tmp_layers[i].contiguous()
v_tmp = v_tmp_layers[i].contiguous()
else:
k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous()
v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous()
mask_tmp = origin_attn_mask[:, :, q_begin:q_end, kv_begin:kv_end].contiguous()
context_layer_tmp = compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp)
context_list_tmp.append(context_layer_tmp)
if not divisible_flag:
# circumstances that cannot be divisible
context_layer_tmp = compute_attn(q_last, k, v, mask_last)
context_list_tmp.append(context_layer_tmp)
context_layer = torch.cat(context_list_tmp, 2)
new_context_layer_shape = (bsz, sequence_len, head_num * head_dim)
context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True)
# =========================
# Context layer. [b, sq, hp]
# =========================
return context_layer
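
The blockwise scheme above can be compared against full causal attention with a pure-PyTorch sketch (no torch_npu required). The 1/sqrt(head_dim) scaling and the True-means-masked convention are assumptions made only for this comparison.

import torch

def blockwise_causal_attention(q, k, v, block_size):
    # q, k, v: [b, hn, sq, hd]; query block i attends to key/value blocks 0..i, as in the triangle scheme
    b, hn, sq, hd = q.shape
    mask = torch.triu(torch.ones(sq, sq, dtype=torch.bool), diagonal=1)  # True above the diagonal = masked
    outputs = []
    for start in range(0, sq, block_size):
        end = min(start + block_size, sq)
        scores = q[:, :, start:end] @ k[:, :, :end].transpose(-1, -2) / hd ** 0.5
        scores = scores.masked_fill(mask[start:end, :end], float("-inf"))
        outputs.append(scores.softmax(dim=-1) @ v[:, :, :end])
    return torch.cat(outputs, dim=2)

q = torch.randn(1, 2, 128, 16)
full = torch.nn.functional.scaled_dot_product_attention(q, q, q, is_causal=True)
print(torch.allclose(blockwise_causal_attention(q, q, q, block_size=32), full, atol=1e-5))  # True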


@ -280,3 +280,21 @@ def create_randomizer_with_offset(
Randomizer.increment_index()
return Randomizer(seed=base_seed)
def get_attention_kernel():
"""
Get the attention kernel based on the device type.
"""
from colossalai.kernel.cuda_native import AttnMaskType
if torch.cuda.is_available():
from colossalai.kernel.cuda_native import ColoAttention as AttentionKernel
else:
try:
# torch.npu is only present when torch_npu is installed; also verify that an NPU device is available
assert torch.npu.is_available()
from colossalai.kernel.npu import NPUColoAttention as AttentionKernel
except (AttributeError, AssertionError, ImportError):
raise Exception("No available device for attention kernel!")
return AttnMaskType, AttentionKernel
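
A short sketch of how callers are expected to consume the dispatcher; the hidden size and head count below are illustrative.

from colossalai.shardformer.layer.utils import get_attention_kernel

AttnMaskType, AttentionKernel = get_attention_kernel()
# resolves to ColoAttention on CUDA and to NPUColoAttention on Ascend; both expose the same forward signature
attention = AttentionKernel(embed_dim=4096, num_heads=32)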


@ -12,6 +12,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForS
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.utils import get_attention_kernel
try:
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@ -404,7 +405,7 @@ class LlamaPipelineForwards:
def get_llama_flash_attention_forward():
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
AttnMaskType, ColoAttention = get_attention_kernel()
llama_version = 2
try:
@ -468,7 +469,7 @@ def get_llama_flash_attention_forward():
attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
attn_output = attention(
query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask,
)
attn_output = self.o_proj(attn_output)