mirror of https://github.com/InternLM/InternLM
Fit to flash attention 1.0
parent a35ce4c888
commit a3580acb6c
@@ -8,7 +8,12 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from flash_attn import flash_attn_unpadded_kvpacked_func
+
+try:
+    from flash_attn import flash_attn_unpadded_kvpacked_func
+except ImportError:
+    from flash_attn import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func
+
 from flash_attn.modules.mha import (
     CrossAttention,
     FlashCrossAttention,
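For reference, a minimal usage sketch of the compatibility shim introduced above: the try/except aliases the flash-attn 2.x name flash_attn_varlen_kvpacked_func back to the 1.x name flash_attn_unpadded_kvpacked_func, so downstream code can call a single name regardless of which major version is installed. The tensor shapes, the cu_seqlens construction, and the keyword arguments below follow the packed (unpadded/varlen) kvpacked API and are not taken from this commit; treat them as assumptions. Running it needs a CUDA GPU with flash-attn installed.

# Hypothetical usage sketch (not from this commit): call the aliased name so the
# same code path works on flash-attn 1.x and 2.x. Requires a CUDA device.
import torch

try:
    from flash_attn import flash_attn_unpadded_kvpacked_func  # flash-attn 1.x name
except ImportError:
    # flash-attn 2.x renamed the "unpadded" kernels to "varlen"; alias it back.
    from flash_attn import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func

batch, seqlen, nheads, headdim = 2, 16, 4, 64
device, dtype = "cuda", torch.float16

# Packed ("unpadded") layout: tokens from every sequence concatenated along dim 0.
q = torch.randn(batch * seqlen, nheads, headdim, device=device, dtype=dtype)
kv = torch.randn(batch * seqlen, 2, nheads, headdim, device=device, dtype=dtype)

# Cumulative sequence lengths delimit each sample inside the packed dimension.
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, step=seqlen,
                          device=device, dtype=torch.int32)

out = flash_attn_unpadded_kvpacked_func(
    q, kv,
    cu_seqlens, cu_seqlens,   # cu_seqlens_q, cu_seqlens_k
    seqlen, seqlen,           # max_seqlen_q, max_seqlen_k
    dropout_p=0.0,
    causal=True,
)
print(out.shape)  # torch.Size([batch * seqlen, nheads, headdim])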