mirror of https://github.com/InternLM/InternLM
Fit to flash attention 1.0.5.
parent a3580acb6c
commit b38ba5dad2
@@ -10,9 +10,19 @@ import torch.nn.functional as F
 from einops import rearrange

 try:
-    from flash_attn import flash_attn_unpadded_kvpacked_func
+    from flash_attn import flash_attn_unpadded_func
 except ImportError:
-    from flash_attn import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func
+    try:
+        from flash_attn import (
+            flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func,
+        )
+    except ImportError:
+        try:
+            from flash_attn import (
+                flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func,
+            )
+        except ImportError:
+            raise ImportError("Please check your flash_attn version >= 1.0.5.")

 from flash_attn.modules.mha import (
     CrossAttention,
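Note on the hunk above (not part of the diff): flash-attn 2.x renamed the flash_attn_unpadded_* family to flash_attn_varlen_*, so the commit binds whichever export the installed wheel provides to the single name flash_attn_unpadded_func used in the rest of the module. A minimal sketch of the same probing logic, written as an illustration rather than the repository's code:

# Illustration only (not from this commit): resolve whichever export the
# installed flash_attn package provides, in the same order as the fallback above.
import importlib

_CANDIDATE_NAMES = (
    "flash_attn_unpadded_func",           # tried first by this commit
    "flash_attn_unpadded_kvpacked_func",  # older flash-attn 1.x export
    "flash_attn_varlen_kvpacked_func",    # flash-attn 2.x renamed export
)

def resolve_flash_attn_unpadded():
    mod = importlib.import_module("flash_attn")
    for name in _CANDIDATE_NAMES:
        fn = getattr(mod, name, None)
        if fn is not None:
            return fn
    raise ImportError("Please check your flash_attn version >= 1.0.5.")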
@@ -285,7 +295,7 @@ class MHA(nn.Module):
         if total_kv.dtype not in [torch.float16, torch.bfloat16]:
             total_kv = total_kv.to(torch.bfloat16)

-        output = flash_attn_unpadded_kvpacked_func(
+        output = flash_attn_unpadded_func(
             total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
         ).to(x.dtype)

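For context on the call above: in the flash-attn 1.x unpadded interface the trailing positional arguments 0.0, None, True, False correspond to dropout_p, softmax_scale, causal and return_attn_probs, and cu_seqlens is the int32 prefix sum of per-sample sequence lengths in the packed batch. A minimal sketch of how those auxiliary inputs are typically built, using hypothetical lengths rather than this repository's code:

# Illustration only: build cu_seqlens / max_seqlen for the unpadded (varlen)
# interface from hypothetical per-sample sequence lengths.
import torch

seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)   # hypothetical batch of 3 samples
cu_seqlens = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(seqlens, dim=0, dtype=torch.int32),
])                                                      # tensor([0, 5, 8, 15])
max_seqlen = int(seqlens.max())                         # 7

# The packed tensors are then laid out along a single token axis, e.g.
# total_q: (cu_seqlens[-1], num_heads, head_dim) and
# total_kv: (cu_seqlens[-1], 2, num_heads, head_dim) for the kv-packed variant.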