mirror of https://github.com/hpcaitech/ColossalAI
[npu] support triangle attention for llama (#5130)
* update fused attn
* update spda
* tri attn
* update triangle
* import
* fix
* fix

pull/5237/head
parent f4e72c9992
commit d6df19bae7
@@ -29,7 +29,6 @@ except ImportError:
     HAS_FLASH_ATTN = False

 if HAS_FLASH_ATTN:
-    pass

     from .utils import SeqLenInfo
@@ -44,6 +44,7 @@ class ColoAttention(torch.nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
+        origin_attn_mask: Optional[torch.Tensor] = None,
         attn_mask_type: Optional[AttnMaskType] = None,
         bias: Optional[torch.Tensor] = None,
     ):
@@ -0,0 +1,3 @@ (new file)
from .mha import NPUColoAttention

__all__ = ["NPUColoAttention"]
@@ -0,0 +1,3 @@ (new file)
from .mha import NPUColoAttention

__all__ = ["NPUColoAttention"]
@@ -0,0 +1,80 @@ (new file)
import math
from typing import Optional

import torch

from .sdpa_attn import npu_sdpa_attention
from .triangle_attn import HAS_NPU_TRIANGLE_ATTENTION


class NPUColoAttention(torch.nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: float = None):
        super().__init__()

        try:
            import torch_npu  # noqa
        except ImportError:
            raise Exception("torch_npu is not installed.")

        assert (
            embed_dim % num_heads == 0
        ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
        if scale is not None:
            self.scale = scale
        else:
            self.scale = 1 / math.sqrt(embed_dim // num_heads)
        self.dropout = dropout

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        origin_attn_mask: Optional[torch.Tensor] = None,
        attn_mask_type: int = None,
        bias: Optional[torch.Tensor] = None,
    ):
        """
        Implement the scaled dot product attention with softmax.

        Arguments:
            q: (batch, q_seqlen, nheads, headdim)
            k: (batch, kv_seqlen, nheads, headdim)
            v: (batch, kv_seqlen, nheads, headdim)
            batch_size: int.
            seq_len: int.
            dropout_p: float. Dropout probability.
            scale: float. The scaling of QK^T before applying softmax.
                Defaults to 1.
        Return:
            attn_out: (batch, q_seqlen, nheads * headdim).
        """
        assert (
            len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4
        ), f"query, key, value should be 4D tensors, but got {query.shape}, {key.shape}, {value.shape}"
        assert (
            query.device.type == "npu" and key.device.type == "npu" and value.device.type == "npu"
        ), f"query, key, value should be on npu device, but got {query.device}, {key.device}, {value.device}"
        assert bias is None, "bias is not supported in npu colo attention"

        causal = attn_mask_type is not None and attn_mask_type.value > 1

        if HAS_NPU_TRIANGLE_ATTENTION:
            from .triangle_attn import npu_triangle_attention

            attn_fn = npu_triangle_attention
        else:
            attn_fn = npu_sdpa_attention

        out = attn_fn(
            query,
            key,
            value,
            attn_mask=attn_mask,
            origin_attn_mask=origin_attn_mask,
            dropout_p=self.dropout,
            scale=self.scale,
            is_causal=causal,
        )
        return out
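For orientation, a hypothetical usage sketch of the new module (not part of the commit). It assumes torch_npu is installed so the "npu" device type exists, and it mirrors how the llama forward further down feeds the HuggingFace-style additive causal mask through origin_attn_mask; the sizes are illustrative only.

import torch

from colossalai.kernel.npu import NPUColoAttention

batch, seqlen, nheads, headdim = 2, 1024, 8, 64  # illustrative sizes

attn = NPUColoAttention(embed_dim=nheads * headdim, num_heads=nheads, dropout=0.0)
q = torch.randn(batch, seqlen, nheads, headdim, device="npu", dtype=torch.float16)
k = torch.randn(batch, seqlen, nheads, headdim, device="npu", dtype=torch.float16)
v = torch.randn(batch, seqlen, nheads, headdim, device="npu", dtype=torch.float16)

# additive causal mask in the HuggingFace convention: 0 where attention is allowed,
# a large negative value above the diagonal where it is not
neg_inf = torch.finfo(torch.float16).min
mask = torch.full((1, 1, seqlen, seqlen), neg_inf, device="npu", dtype=torch.float16).triu(diagonal=1)

out = attn(q, k, v, origin_attn_mask=mask)  # -> (batch, seqlen, nheads * headdim)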
@@ -0,0 +1,41 @@ (new file)
import torch
from einops import rearrange


def npu_sdpa_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: torch.Tensor = None,
    origin_attn_mask: torch.Tensor = None,
    scale: float = 1.0,
    dropout_p: float = 0.0,
    is_causal: bool = True,
):
    """
    The scaled dot product attention.

    Arguments:
        q: (batch, q_seqlen, nheads, headdim)
        k: (batch, kv_seqlen, nheads, headdim)
        v: (batch, kv_seqlen, nheads, headdim)
        batch_size: int.
        seq_len: int.
        dropout_p: float. Dropout probability.
        scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1.
    Return:
        attn_out: (batch, q_seqlen, nheads * headdim).
    """
    q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]
    output = torch.nn.functional.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=origin_attn_mask,
        dropout_p=dropout_p,
        is_causal=origin_attn_mask is None,
        scale=scale,
    )
    output = rearrange(output, "b h s d -> b s (h d)")
    return output
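The only NPU-specific part of this fallback is the device the tensors live on, so the layout handling can be sanity-checked on CPU. A minimal sketch (not from the commit; it assumes PyTorch >= 2.1 for the scale argument of scaled_dot_product_attention, and the helper name is illustrative):

import torch
from einops import rearrange

def sdpa_bshd(q, k, v, scale):
    # same layout dance as npu_sdpa_attention: (b, s, h, d) in, (b, s, h*d) out
    q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)
    return rearrange(out, "b h s d -> b s (h d)")

q = torch.randn(2, 16, 4, 8)
k = torch.randn(2, 16, 4, 8)
v = torch.randn(2, 16, 4, 8)
print(sdpa_bshd(q, k, v, scale=8 ** -0.5).shape)  # torch.Size([2, 16, 32])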
@@ -0,0 +1,115 @@ (new file)
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import torch
from einops import rearrange

HAS_NPU_TRIANGLE_ATTENTION = False
try:
    from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax

    HAS_NPU_TRIANGLE_ATTENTION = True
except ImportError:
    logging.warning("Import torch_npu Error.")


if HAS_NPU_TRIANGLE_ATTENTION:

    def npu_triangle_attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: torch.Tensor = None,
        origin_attn_mask: torch.Tensor = None,
        scale: float = 1.0,
        dropout_p: float = 0.0,
        is_causal: bool = True,
        block_size=512,
    ):
        """
        Triangle attention reduces the attention computation over the masked
        part by dividing the q, k, and v matrices into blocks.

        Arguments:
            block_size: The size of the inverted-triangle block; the default is 512.
                The smaller the block_size, the more computation is skipped,
                but the more small operators are launched.
            masked_softmax_func: mask function to be applied.
            dropout_func: dropout function to be applied.
        """

        def compute_attn(q_layer, k_layer, v_layer, mask_tmp):
            # [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size]
            cur_sim = torch.matmul(q_layer, k_layer)
            attention_probs = npu_scaled_masked_softmax(cur_sim, mask_tmp)
            # attention dropout
            if dropout_p > 0:
                attention_probs = torch.nn.functional.dropout(
                    attention_probs, p=dropout_p, training=attention_probs.requires_grad
                )
            # [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd]
            context_layer_tmp = torch.matmul(attention_probs, v_layer)
            return context_layer_tmp

        q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)]
        origin_attn_mask = origin_attn_mask.to(torch.bool)
        # input shape: [b, hn, sq, hd]
        bsz, head_num, sequence_len, head_dim = k.shape
        sparse_groups = sequence_len // block_size
        # determine whether the sequence length is divisible by the block size
        divisible_flag = sequence_len == block_size * sparse_groups
        k = k.transpose(2, 3).contiguous()
        if divisible_flag:
            q_tmp_layers = torch.chunk(q, sparse_groups, 2)
            k_tmp_layers = torch.chunk(k, sparse_groups, 3)
            v_tmp_layers = torch.chunk(v, sparse_groups, 2)
        else:
            seq_tmp = block_size * sparse_groups
            q_last = q[:, :, seq_tmp:, :].contiguous()
            mask_last = origin_attn_mask[:, :, seq_tmp:, :].contiguous()
            q_tmp_layers = torch.chunk(q[:, :, :seq_tmp, :], sparse_groups, 2)
            k_tmp_layers = torch.chunk(k[:, :, :, :seq_tmp], sparse_groups, 3)
            v_tmp_layers = torch.chunk(v[:, :, :seq_tmp, :], sparse_groups, 2)
        context_list_tmp, k_tmp, v_tmp = [], (), ()
        for i in range(sparse_groups):
            # compute the slice boundaries of q, k, v for this block
            q_begin, q_end = i * block_size, (i + 1) * block_size
            kv_begin, kv_end = 0, (i + 1) * block_size
            q_tmp = q_tmp_layers[i]
            # slice k and v
            if i == 0:
                k_tmp = k_tmp_layers[i].contiguous()
                v_tmp = v_tmp_layers[i].contiguous()
            else:
                k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous()
                v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous()

            mask_tmp = origin_attn_mask[:, :, q_begin:q_end, kv_begin:kv_end].contiguous()
            context_layer_tmp = compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp)
            context_list_tmp.append(context_layer_tmp)

        if not divisible_flag:
            # handle the tail that is not divisible by block_size
            context_layer_tmp = compute_attn(q_last, k, v, mask_last)
            context_list_tmp.append(context_layer_tmp)
        context_layer = torch.cat(context_list_tmp, 2)
        new_context_layer_shape = (bsz, sequence_len, head_num * head_dim)
        context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True)
        # =========================
        # Context layer. [b, sq, hp]
        # =========================
        return context_layer
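The blocking idea itself is independent of the NPU kernels. Below is a plain-PyTorch illustration (not from the commit): an ordinary masked softmax stands in for npu_scaled_masked_softmax, the sequence length is assumed divisible by block_size, and each query block only multiplies against the key/value blocks at or before it, which is where the savings over a full masked matmul come from.

import torch

def blocked_causal_attention(q, k, v, block_size, scale):
    # q, k, v: (b, h, s, d) with s divisible by block_size
    b, h, s, d = q.shape
    groups = s // block_size
    mask = torch.triu(torch.ones(s, s, dtype=torch.bool), diagonal=1)  # True = masked out
    outs = []
    for i in range(groups):
        q_blk = q[:, :, i * block_size:(i + 1) * block_size]   # current query block
        k_blk = k[:, :, : (i + 1) * block_size]                 # keys seen so far
        v_blk = v[:, :, : (i + 1) * block_size]
        sim = torch.matmul(q_blk, k_blk.transpose(-1, -2)) * scale
        blk_mask = mask[i * block_size:(i + 1) * block_size, : (i + 1) * block_size]
        sim = sim.masked_fill(blk_mask, float("-inf"))
        outs.append(torch.matmul(sim.softmax(dim=-1), v_blk))
    return torch.cat(outs, dim=2)                               # (b, h, s, d)

q = torch.randn(1, 2, 8, 4)
k = torch.randn(1, 2, 8, 4)
v = torch.randn(1, 2, 8, 4)
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
out = blocked_causal_attention(q, k, v, block_size=4, scale=4 ** -0.5)
print(torch.allclose(out, ref, atol=1e-5))  # True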
@@ -280,3 +280,21 @@ def create_randomizer_with_offset(
     Randomizer.increment_index()

     return Randomizer(seed=base_seed)
+
+
+def get_attention_kernel():
+    """
+    Get the attention kernel based on the device type.
+    """
+    from colossalai.kernel.cuda_native import AttnMaskType
+
+    if torch.cuda.is_available():
+        from colossalai.kernel.cuda_native import ColoAttention as AttentionKernel
+    else:
+        try:
+            torch.npu.is_available()
+            from colossalai.kernel.npu import NPUColoAttention as AttentionKernel
+        except:
+            raise Exception("No available device for attention kernel!")
+
+    return AttnMaskType, AttentionKernel
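The helper lets model-specific code stay device-agnostic: on a CUDA machine it resolves to the cuda_native ColoAttention, and on an Ascend machine to NPUColoAttention. A short usage sketch (the embed_dim/num_heads values are illustrative, not from the commit):

from colossalai.shardformer.layer.utils import get_attention_kernel

AttnMaskType, AttentionKernel = get_attention_kernel()
# AttentionKernel is ColoAttention on CUDA, NPUColoAttention on an Ascend NPU
attention = AttentionKernel(embed_dim=4096, num_heads=32)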
@@ -12,6 +12,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForS
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer.utils import get_attention_kernel

 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -404,7 +405,7 @@ class LlamaPipelineForwards:
 def get_llama_flash_attention_forward():
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    AttnMaskType, ColoAttention = get_attention_kernel()

     llama_version = 2
     try:
@@ -468,7 +469,7 @@ def get_llama_flash_attention_forward():

         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
         attn_output = attention(
-            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
+            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask,
         )

         attn_output = self.o_proj(attn_output)