mirror of https://github.com/hpcaitech/ColossalAI
[npu] support triangle attention for llama (#5130)
* update fused attn
* update spda
* tri attn
* update triangle
* import
* fix
* fix

pull/5237/head
parent f4e72c9992
commit d6df19bae7

@ -29,7 +29,6 @@ except ImportError:
    HAS_FLASH_ATTN = False

if HAS_FLASH_ATTN:
    pass

from .utils import SeqLenInfo

@ -44,6 +44,7 @@ class ColoAttention(torch.nn.Module):
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        origin_attn_mask: Optional[torch.Tensor] = None,
        attn_mask_type: Optional[AttnMaskType] = None,
        bias: Optional[torch.Tensor] = None,
    ):

@ -0,0 +1,3 @@
from .mha import NPUColoAttention

__all__ = ["NPUColoAttention"]

@ -0,0 +1,3 @@
from .mha import NPUColoAttention

__all__ = ["NPUColoAttention"]

@ -0,0 +1,80 @@
import math
from typing import Optional

import torch

from .sdpa_attn import npu_sdpa_attention
from .triangle_attn import HAS_NPU_TRIANGLE_ATTENTION


class NPUColoAttention(torch.nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: float = None):
        super().__init__()

        try:
            import torch_npu  # noqa
        except ImportError:
            raise Exception("torch_npu is not installed.")

        assert (
            embed_dim % num_heads == 0
        ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
        if scale is not None:
            self.scale = scale
        else:
            self.scale = 1 / math.sqrt(embed_dim // num_heads)
        self.dropout = dropout

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        origin_attn_mask: Optional[torch.Tensor] = None,
        attn_mask_type: int = None,
        bias: Optional[torch.Tensor] = None,
    ):
        """
        Implement the scaled dot product attention with softmax.

        Arguments:
            q: (batch, q_seqlen, nheads, headdim)
            k: (batch, kv_seqlen, nheads, headdim)
            v: (batch, kv_seqlen, nheads, headdim)
            batch_size: int.
            seq_len: int.
            dropout_p: float. Dropout probability.
            scale: float. The scaling of QK^T before applying softmax.
                Default to 1.
        Return:
            attn_out: (batch, q_seqlen, nheads, headdim).
        """
        assert (
            len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4
        ), f"query, key, value should be 4D tensors, but got {query.shape}, {key.shape}, {value.shape}"
        assert (
            query.device.type == "npu" and key.device.type == "npu" and value.device.type == "npu"
        ), f"query, key, value should be on npu device, but got {query.device}, {key.device}, {value.device}"
        assert bias is None, "bias is not supported in npu colo attention"

        causal = attn_mask_type is not None and attn_mask_type.value > 1

        if HAS_NPU_TRIANGLE_ATTENTION:
            from .triangle_attn import npu_triangle_attention

            attn_fn = npu_triangle_attention
        else:
            attn_fn = npu_sdpa_attention

        out = attn_fn(
            query,
            key,
            value,
            attn_mask=attn_mask,
            origin_attn_mask=origin_attn_mask,
            dropout_p=self.dropout,
            scale=self.scale,
            is_causal=causal,
        )
        return out
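
For orientation, a minimal usage sketch of the module added above. It is not part of the patch: it assumes an Ascend machine with torch_npu installed (which provides the .npu() tensor method), and it builds the same kind of additive float causal mask that the llama forward below passes in as origin_attn_mask. Shapes and sizes are illustrative.

import torch
import torch_npu  # noqa: enables the .npu() tensor method on Ascend builds

from colossalai.kernel.npu import NPUColoAttention

# illustrative sizes: batch 2, sequence 1024, 32 heads of dim 128
batch, seq_len, num_heads, head_dim = 2, 1024, 32, 128

attn = NPUColoAttention(embed_dim=num_heads * head_dim, num_heads=num_heads)

# q/k/v laid out as (batch, seqlen, nheads, headdim), matching the docstring above
q, k, v = (torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16).npu() for _ in range(3))

# additive float causal mask (0 = visible, -inf = masked); the triangle path converts it
# to bool, the SDPA fallback adds it to the attention scores
mask = torch.zeros(seq_len, seq_len, dtype=torch.float16)
mask = mask.masked_fill(torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1), float("-inf"))
origin_attn_mask = mask[None, None].expand(batch, num_heads, seq_len, seq_len).npu()

out = attn(q, k, v, origin_attn_mask=origin_attn_mask)  # (batch, seqlen, nheads * headdim)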

@ -0,0 +1,41 @@
import torch
from einops import rearrange


def npu_sdpa_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: torch.Tensor = None,
    origin_attn_mask: torch.Tensor = None,
    scale: float = 1.0,
    dropout_p: float = 0.0,
    is_causal: bool = True,
):
    """
    The scaled dot product attention.

    Arguments:
        q: (batch, q_seqlen, nheads, headdim)
        k: (batch, kv_seqlen, nheads, headdim)
        v: (batch, kv_seqlen, nheads, headdim)
        batch_size: int.
        seq_len: int.
        dropout_p: float. Dropout probability.
        scale: float. The scaling of QK^T before applying softmax.
            Default to 1.
    Return:
        attn_out: (batch, q_seqlen, nheads, headdim).
    """
    q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]
    output = torch.nn.functional.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=origin_attn_mask,
        dropout_p=dropout_p,
        is_causal=origin_attn_mask is None,
        scale=scale,
    )
    output = rearrange(output, "b h s d -> b s (h d)")
    return output
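
As a sanity check on the layout handling above (not part of the patch), a CPU-only sketch comparing the same rearrange + scaled_dot_product_attention path against an explicit softmax(QK^T * scale)V reference. It assumes a PyTorch build whose scaled_dot_product_attention accepts the scale keyword, as the patch itself does.

import math

import torch
from einops import rearrange

b, s, h, d = 2, 16, 4, 8
q, k, v = (torch.randn(b, s, h, d) for _ in range(3))
scale = 1 / math.sqrt(d)

# same path as npu_sdpa_attention: move the head dim in front of the sequence dim
qh, kh, vh = (rearrange(x, "b s h d -> b h s d") for x in (q, k, v))
out = torch.nn.functional.scaled_dot_product_attention(qh, kh, vh, dropout_p=0.0, is_causal=False, scale=scale)
out = rearrange(out, "b h s d -> b s (h d)")

# explicit reference: softmax(Q K^T * scale) V, flattened the same way
scores = torch.matmul(qh, kh.transpose(-1, -2)) * scale
ref = rearrange(torch.matmul(torch.softmax(scores, dim=-1), vh), "b h s d -> b s (h d)")

assert torch.allclose(out, ref, atol=1e-4)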

@ -0,0 +1,115 @@
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import torch
from einops import rearrange

HAS_NPU_TRIANGLE_ATTENTION = False
try:
    from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax

    HAS_NPU_TRIANGLE_ATTENTION = True
except ImportError:
    logging.warning("Import torch_npu Error.")


if HAS_NPU_TRIANGLE_ATTENTION:

    def npu_triangle_attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: torch.Tensor = None,
        origin_attn_mask: torch.Tensor = None,
        scale: float = 1.0,
        dropout_p: float = 0.0,
        is_causal: bool = True,
        block_size=512,
    ):
        """
        Triangle attention reduces the attention computation over the masked
        part by dividing the q, k, and v matrices into blocks.

        Arguments:
            block_size: The size of the inverted-triangle block (default 512).
                A smaller block_size skips more masked computation, but
                launches more small operators.
            masked_softmax_func: mask function to be applied.
            dropout_func: dropout function to be applied.
        """

        def compute_attn(q_layer, k_layer, v_layer, mask_tmp):
            # [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size]
            cur_sim = torch.matmul(q_layer, k_layer)
            attention_probs = npu_scaled_masked_softmax(cur_sim, mask_tmp)
            # attention dropout
            if dropout_p > 0:
                attention_probs = torch.nn.functional.dropout(
                    attention_probs, p=dropout_p, training=attention_probs.requires_grad
                )
            # [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd]
            context_layer_tmp = torch.matmul(attention_probs, v_layer)
            return context_layer_tmp

        q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)]
        origin_attn_mask = origin_attn_mask.to(torch.bool)
        # input shape: [b, hn, sq, hd]
        bsz, head_num, sequence_len, head_dim = k.shape
        sparse_groups = sequence_len // block_size
        # determine whether the sequence length is divisible by the block size
        divisible_flag = sequence_len == block_size * sparse_groups
        k = k.transpose(2, 3).contiguous()
        if divisible_flag:
            q_tmp_layers = torch.chunk(q, sparse_groups, 2)
            k_tmp_layers = torch.chunk(k, sparse_groups, 3)
            v_tmp_layers = torch.chunk(v, sparse_groups, 2)
        else:
            seq_tmp = block_size * sparse_groups
            q_last = q[:, :, seq_tmp:, :].contiguous()
            mask_last = origin_attn_mask[:, :, seq_tmp:, :].contiguous()
            q_tmp_layers = torch.chunk(q[:, :, :seq_tmp, :], sparse_groups, 2)
            k_tmp_layers = torch.chunk(k[:, :, :, :seq_tmp], sparse_groups, 3)
            v_tmp_layers = torch.chunk(v[:, :, :seq_tmp, :], sparse_groups, 2)
        context_list_tmp, k_tmp, v_tmp = [], (), ()
        for i in range(sparse_groups):
            # compute the q, k, v slices for this block
            q_begin, q_end = i * block_size, (i + 1) * block_size
            kv_begin, kv_end = 0, (i + 1) * block_size
            q_tmp = q_tmp_layers[i]
            # slice k and v
            if i == 0:
                k_tmp = k_tmp_layers[i].contiguous()
                v_tmp = v_tmp_layers[i].contiguous()
            else:
                k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous()
                v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous()

            mask_tmp = origin_attn_mask[:, :, q_begin:q_end, kv_begin:kv_end].contiguous()
            context_layer_tmp = compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp)
            context_list_tmp.append(context_layer_tmp)

        if not divisible_flag:
            # handle the tail that block_size does not evenly cover
            context_layer_tmp = compute_attn(q_last, k, v, mask_last)
            context_list_tmp.append(context_layer_tmp)
        context_layer = torch.cat(context_list_tmp, 2)
        new_context_layer_shape = (bsz, sequence_len, head_num * head_dim)
        context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True)
        # =========================
        # Context layer. [b, sq, hp]
        # =========================
        return context_layer
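
The control flow above is easier to follow without the fused NPU ops. Below is a CPU-only sketch (not part of the patch) of the same blockwise causal scheme: it replaces npu_scaled_masked_softmax with a plain masked softmax, ignores dropout and the non-divisible tail, and checks the result against full causal attention.

import torch


def blockwise_causal_attention(q, k, v, block_size):
    # q, k, v: [batch, heads, seq, head_dim]; seq is assumed divisible by block_size here
    b, h, s, d = q.shape
    groups = s // block_size
    outs = []
    k_acc, v_acc = None, None
    for i in range(groups):
        q_blk = q[:, :, i * block_size:(i + 1) * block_size]
        k_blk = k[:, :, i * block_size:(i + 1) * block_size]
        v_blk = v[:, :, i * block_size:(i + 1) * block_size]
        # grow the key/value prefix, mirroring the torch.cat calls in the patch
        k_acc = k_blk if k_acc is None else torch.cat((k_acc, k_blk), dim=2)
        v_acc = v_blk if v_acc is None else torch.cat((v_acc, v_blk), dim=2)
        scores = torch.matmul(q_blk, k_acc.transpose(-1, -2)) / d ** 0.5
        # causal mask restricted to this [block_size, kv_len] tile
        q_idx = torch.arange(i * block_size, (i + 1) * block_size)[:, None]
        kv_idx = torch.arange(k_acc.shape[2])[None, :]
        scores = scores.masked_fill(kv_idx > q_idx, float("-inf"))
        outs.append(torch.matmul(torch.softmax(scores, dim=-1), v_acc))
    return torch.cat(outs, dim=2)


b, h, s, d = 1, 2, 64, 16
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
full = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(blockwise_causal_attention(q, k, v, block_size=16), full, atol=1e-5)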

@ -280,3 +280,21 @@ def create_randomizer_with_offset(
    Randomizer.increment_index()

    return Randomizer(seed=base_seed)


def get_attention_kernel():
    """
    Get the attention kernel based on the device type.
    """
    from colossalai.kernel.cuda_native import AttnMaskType

    if torch.cuda.is_available():
        from colossalai.kernel.cuda_native import ColoAttention as AttentionKernel
    else:
        try:
            torch.npu.is_available()
            from colossalai.kernel.npu import NPUColoAttention as AttentionKernel
        except:
            raise Exception("No available device for attention kernel!")

    return AttnMaskType, AttentionKernel
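
A short sketch (not part of the patch) of how the new helper is meant to be consumed, mirroring its use in get_llama_flash_attention_forward below; the sizes are illustrative.

from colossalai.shardformer.layer.utils import get_attention_kernel

# resolves to the cuda_native ColoAttention on CUDA machines, NPUColoAttention on Ascend
AttnMaskType, AttentionKernel = get_attention_kernel()

# illustrative llama-7b-like sizes; the returned AttnMaskType is what feeds the
# `attn_mask_type.value > 1` causal check in NPUColoAttention.forward above
attention = AttentionKernel(embed_dim=4096, num_heads=32)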

@ -12,6 +12,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForS
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.utils import get_attention_kernel

try:
    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask

@ -404,7 +405,7 @@ class LlamaPipelineForwards:
def get_llama_flash_attention_forward():
    from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    AttnMaskType, ColoAttention = get_attention_kernel()

    llama_version = 2
    try:

@ -468,7 +469,7 @@ def get_llama_flash_attention_forward():
        attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
        attn_output = attention(
-            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
+            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask,
        )

        attn_output = self.o_proj(attn_output)