mirror of https://github.com/hpcaitech/ColossalAI
parent
25c57b9fb4
commit
38b792aab2
|
@ -0,0 +1,3 @@
|
|||
from .mha import ColoAttention
|
||||
|
||||
__all__ = ['ColoAttention']
|
|
@ -9,7 +9,7 @@ from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
|
|||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
|
||||
if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
|
||||
from colossalai.kernel.cuda_native.mha.mha import ColoAttention
|
||||
from colossalai.kernel.cuda_native import ColoAttention
|
||||
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
|
||||
|
||||
DTYPE = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
|
Loading…
Reference in New Issue