mirror of https://github.com/hpcaitech/ColossalAI
parent 25c57b9fb4
commit 38b792aab2
@@ -0,0 +1,3 @@
+from .mha import ColoAttention
+
+__all__ = ['ColoAttention']
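The new three-line file above (presumably a package `__init__.py` for the mha directory; the file path is not shown in this view) re-exports ColoAttention, so callers can import it from the package instead of reaching into the `.mha.mha` submodule. The test-side hunk below switches to exactly that shorter path.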
@@ -9,7 +9,7 @@ from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
 from colossalai.testing import clear_cache_before_run, parameterize
 
 if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
-    from colossalai.kernel.cuda_native.mha.mha import ColoAttention
+    from colossalai.kernel.cuda_native import ColoAttention
     from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
 
 DTYPE = [torch.float16, torch.bfloat16, torch.float32]
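For context, a minimal sketch of how the updated import path would be exercised. Everything past the two imports is an assumption: the shapes are illustrative, and the ColoAttention constructor and call signatures are inferred from this test file rather than confirmed by the diff.

import torch

from colossalai.kernel.cuda_native import ColoAttention   # import path introduced by this commit
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType

# Illustrative shapes: batch, sequence length, heads, per-head dim (assumptions).
B, S, H, D_HEAD = 2, 128, 4, 16

q = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Assumed signatures: ColoAttention(embed_dim, num_heads, dropout=...) with a
# forward taking q/k/v plus an AttnMaskType; check the repo before relying on this.
attn = ColoAttention(H * D_HEAD, H, dropout=0.0)
out = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
print(out.shape)   # (B, S, H * D_HEAD) under the assumed head-merging behavior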