mirror of https://github.com/InternLM/InternLM
Fit to flash attention 1.0
parent a35ce4c888
commit a3580acb6c
@@ -8,7 +8,12 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from flash_attn import flash_attn_unpadded_kvpacked_func
+
+try:
+    from flash_attn import flash_attn_unpadded_kvpacked_func
+except ImportError:
+    from flash_attn import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func
+
 from flash_attn.modules.mha import (
     CrossAttention,
     FlashCrossAttention,
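The change wraps the import in a try/except so the rest of the code can keep using the flash-attn 1.0 name flash_attn_unpadded_kvpacked_func even when a newer flash-attn release, which renamed the unpadded_* functions to varlen_*, is installed. Below is a minimal sketch of calling the function through that alias; the tensor shapes, sequence lengths, and the causal=True setting are illustrative assumptions and not part of this commit. It requires flash-attn installed, a CUDA device, and fp16/bf16 inputs.

import torch

# Same compatibility shim as in the commit: prefer the flash-attn 1.x name,
# fall back to the flash-attn 2.x "varlen" name under the old alias.
try:
    from flash_attn import flash_attn_unpadded_kvpacked_func
except ImportError:
    from flash_attn import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func

# Illustrative sizes (assumptions, not taken from the commit).
batch, seqlen, nheads, headdim = 2, 128, 8, 64
total = batch * seqlen

# Unpadded ("varlen") layout: all sequences in the batch are concatenated
# along the first dimension; cu_seqlens marks where each sequence starts.
q = torch.randn(total, nheads, headdim, dtype=torch.float16, device="cuda")
kv = torch.randn(total, 2, nheads, headdim, dtype=torch.float16, device="cuda")
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen, dtype=torch.int32, device="cuda")  # [0, 128, 256]

out = flash_attn_unpadded_kvpacked_func(
    q, kv,
    cu_seqlens, cu_seqlens,   # cu_seqlens_q, cu_seqlens_k
    seqlen, seqlen,           # max_seqlen_q, max_seqlen_k
    0.0,                      # dropout_p
    softmax_scale=None,
    causal=True,
)
print(out.shape)  # (total, nheads, headdim)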