[shardformer] update shardformer to use flash attention 2 (#4392)

* cherry-pick flash attention 2

* [shardformer] update shardformer to use flash attention 2, fix
Branch: pull/4445/head
Authored by flybird1111 on 2023-08-09 14:32:19 +08:00; committed by Hongxin Liu
Parent: ed4c448488
Commit: 7a3dfd0c64
9 changed files with 10 additions and 11 deletions
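
In essence, the modeling files below all apply the same one-line migration: ColoAttention (and, where used, AttnMaskType) is now imported from the colossalai.kernel.cuda_native package, which re-exports both names, instead of from its flash_attention submodule. A minimal sketch of the pattern, taken directly from the hunks below:

# Old import path (removed in this commit):
# from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention

# New import path (the package now re-exports both names):
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention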

@@ -1,8 +1,9 @@
 from .layer_norm import MixedFusedLayerNorm as LayerNorm
 from .mha.mha import ColoAttention
 from .multihead_attention import MultiHeadAttention
-from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
+from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax

 __all__ = [
-    'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention'
+    'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention',
+    'AttnMaskType'
 ]
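
With AttnMaskType added to the package's __all__, call sites can fetch both names in a single import. A rough usage sketch follows; the constructor and call arguments shown (embed_dim/num_heads/dropout and a q/k/v call with an attn_mask_type flag) are assumptions for illustration only, not the verified signature, which lives under colossalai/kernel/cuda_native/mha.

import torch

from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

# Shapes and argument names below are assumed for illustration; check the
# mha module for the actual ColoAttention interface.
batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

attn = ColoAttention(embed_dim=num_heads * head_dim, num_heads=num_heads, dropout=0.0)
out = attn(q, k, v, attn_mask_type=AttnMaskType.causal)  # assumed call pattern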

@@ -65,7 +65,7 @@ def get_blip2_flash_attention_forward():
     from transformers.models.blip_2.modeling_blip_2 import Blip2Attention

-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import ColoAttention

     def forward(
         self: Blip2Attention,

@@ -19,7 +19,7 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
 def get_flash_core_attention_forward():
-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

     from .chatglm2_6b.modeling_chatglm import CoreAttention

@@ -126,7 +126,6 @@ def get_jit_fused_glm_block_forward():
     return forward


-
 class ChatGLMPipelineForwards:
     '''
     This class serves as a micro library for ChatGLM model forwards under pipeline parallelism.

@@ -674,7 +674,7 @@ def get_gpt2_flash_attention_forward():
     from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

     def split_heads(tensor, num_heads, attn_head_size):
         """

@@ -392,7 +392,7 @@ def get_llama_flash_attention_forward():
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

     def forward(
         self: LlamaAttention,

@@ -8,7 +8,7 @@ def get_opt_flash_attention_forward():
     from transformers.models.opt.modeling_opt import OPTAttention

-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

     def forward(
         self: OPTAttention,

@@ -342,7 +342,7 @@ def get_vit_flash_self_attention_forward():
     from transformers.models.vit.modeling_vit import ViTSelfAttention

-    from colossalai.kernel.cuda_native.flash_attention import ColoAttention
+    from colossalai.kernel.cuda_native import ColoAttention

     def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)

@@ -8,7 +8,7 @@ def get_whisper_flash_attention_forward():
     from transformers.models.whisper.modeling_whisper import WhisperAttention

-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

     def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
         return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()

@@ -13,7 +13,6 @@ if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
-    from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType

 DTYPE = [torch.float16, torch.bfloat16, torch.float32]
 FLASH_DTYPE = [torch.float16, torch.bfloat16]


 def attention_ref(q, k, v, attn_mask=None, causal=False):
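
The attention_ref helper visible in the last hunk's context appears to be the eager PyTorch baseline the fused kernels are compared against; its body is not part of this diff. As a purely generic sketch of what such a reference computes (scaled dot-product attention with optional padding and causal masks), assuming (batch, seq_len, num_heads, head_dim) inputs and an attn_mask broadcastable to the score shape, and not the repository's actual implementation:

import math

import torch


def naive_attention_ref(q, k, v, attn_mask=None, causal=False):
    """Generic reference attention: softmax(q @ k^T / sqrt(d)) @ v.

    Illustrative sketch only; shapes assumed to be (batch, seq_len, num_heads, head_dim),
    attn_mask assumed broadcastable to (batch, num_heads, seq_len, seq_len).
    """
    orig_dtype = q.dtype
    # Move heads in front and compute in float32 for a stable reference.
    q, k, v = (t.transpose(1, 2).float() for t in (q, k, v))
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if attn_mask is not None:
        scores = scores.masked_fill(attn_mask == 0, float("-inf"))
    if causal:
        seq_len = scores.size(-1)
        future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device).triu(1)
        scores = scores.masked_fill(future, float("-inf"))
    out = torch.softmax(scores, dim=-1) @ v
    # Back to (batch, seq_len, num_heads, head_dim) in the original dtype.
    return out.transpose(1, 2).to(orig_dtype)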