mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] update shardformer to use flash attention 2 (#4392)
* cherry-pick flash attention 2
* [shardformer] update shardformer to use flash attention 2
* [shardformer] update shardformer to use flash attention 2, fix

pull/4445/head
parent ed4c448488
commit 7a3dfd0c64
@@ -1,8 +1,9 @@
 from .layer_norm import MixedFusedLayerNorm as LayerNorm
+from .mha.mha import ColoAttention
 from .multihead_attention import MultiHeadAttention
-from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
+from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax

 __all__ = [
-    'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention'
+    'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention',
+    'AttnMaskType'
 ]
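The practical effect of this hunk is that ColoAttention and AttnMaskType are re-exported from the package root, so callers no longer reach into the flash_attention submodule. A minimal sketch of the old versus new import path, assuming the AttnMaskType members (padding, causal) defined in scaled_softmax:

# Old import path, removed throughout this commit:
#   from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
# New import path used by the patched shardformer forwards:
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

# AttnMaskType tells the fused attention path how the mask should be treated;
# the decoder forwards patched below pick the causal variant (member names
# assumed from colossalai.kernel.cuda_native.scaled_softmax).
mask_type = AttnMaskType.causal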
@@ -65,7 +65,7 @@ def get_blip2_flash_attention_forward():

     from transformers.models.blip_2.modeling_blip_2 import Blip2Attention

-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import ColoAttention

     def forward(
         self: Blip2Attention,
@@ -19,7 +19,7 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (


 def get_flash_core_attention_forward():

-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

     from .chatglm2_6b.modeling_chatglm import CoreAttention
@@ -126,7 +126,6 @@ def get_jit_fused_glm_block_forward():
     return forward


-
 class ChatGLMPipelineForwards:
     '''
     This class serves as a micro library for ChatGLM model forwards under pipeline parallelism.
@@ -674,7 +674,7 @@ def get_gpt2_flash_attention_forward():

     from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

     def split_heads(tensor, num_heads, attn_head_size):
         """
@@ -392,7 +392,7 @@ def get_llama_flash_attention_forward():

     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

     def forward(
         self: LlamaAttention,
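The modeling hunks above all follow the same pattern: the replacement forward keeps the model's own projections (and, for LLaMA, the rotary embedding), then hands query/key/value to ColoAttention instead of computing softmax(QK^T)V by hand. The sketch below illustrates only that call pattern; the ColoAttention constructor and call arguments shown here (embed_dim, num_heads, attn_mask_type) and the [batch, seq_len, num_heads, head_dim] layout are assumptions for illustration, not lines copied from the patched files.

import torch

from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

# Hypothetical sizes, for illustration only.
bsz, q_len, num_heads, head_dim = 2, 128, 32, 128

# Assumed constructor arguments: hidden size plus head count.
attention = ColoAttention(embed_dim=num_heads * head_dim, num_heads=num_heads)

# q/k/v as a patched forward would produce them after the usual projections,
# assumed here to be laid out as [batch, seq_len, num_heads, head_dim] on GPU.
query_states = torch.randn(bsz, q_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
key_states = torch.randn_like(query_states)
value_states = torch.randn_like(query_states)

# Assumed call: a causal mask type for decoder self-attention; padded batches
# would instead pass AttnMaskType.padding together with an explicit mask.
attn_output = attention(query_states, key_states, value_states, attn_mask_type=AttnMaskType.causal)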
@@ -8,7 +8,7 @@ def get_opt_flash_attention_forward():

     from transformers.models.opt.modeling_opt import OPTAttention

-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

     def forward(
         self: OPTAttention,
@@ -342,7 +342,7 @@ def get_vit_flash_self_attention_forward():

     from transformers.models.vit.modeling_vit import ViTSelfAttention

-    from colossalai.kernel.cuda_native.flash_attention import ColoAttention
+    from colossalai.kernel.cuda_native import ColoAttention

     def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
@@ -8,7 +8,7 @@ def get_whisper_flash_attention_forward():

     from transformers.models.whisper.modeling_whisper import WhisperAttention

-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

     def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
         return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
@@ -13,7 +13,6 @@ if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
     from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType

 DTYPE = [torch.float16, torch.bfloat16, torch.float32]
 FLASH_DTYPE = [torch.float16, torch.bfloat16]


 def attention_ref(q, k, v, attn_mask=None, causal=False):
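The test hunk ends at the signature of attention_ref, the plain-PyTorch baseline that the fused kernels are compared against. Its body lies outside this hunk; the following is only a sketch of what such a reference typically computes (naive scaled dot-product attention with an optional additive mask and causal masking), under an assumed [batch, seq_len, num_heads, head_dim] layout, not the repository's implementation.

import math

import torch


def attention_ref(q, k, v, attn_mask=None, causal=False):
    # q, k, v: [batch, seq_len, num_heads, head_dim] (assumed layout);
    # compute in fp32 for a numerically stable baseline.
    q, k, v = (t.float().transpose(1, 2) for t in (q, k, v))  # -> [batch, heads, seq, dim]
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if attn_mask is not None:
        # Assumes an additive mask broadcastable to [batch, heads, seq_q, seq_k].
        scores = scores + attn_mask
    if causal:
        seq_q, seq_k = scores.shape[-2:]
        causal_mask = torch.ones(seq_q, seq_k, dtype=torch.bool, device=scores.device).triu(1)
        scores = scores.masked_fill(causal_mask, float("-inf"))
    out = torch.softmax(scores, dim=-1) @ v
    return out.transpose(1, 2)  # back to [batch, seq_len, num_heads, head_dim]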