[coloattention] fix import error (#4380)

fixed an import error
pull/4377/head
flybird1111 2023-08-04 16:28:41 +08:00 committed by GitHub
parent 25c57b9fb4
commit 38b792aab2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 1 deletions

View File

@ -0,0 +1,3 @@
from .mha import ColoAttention
__all__ = ['ColoAttention']

View File

@ -9,7 +9,7 @@ from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize
if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
from colossalai.kernel.cuda_native.mha.mha import ColoAttention
from colossalai.kernel.cuda_native import ColoAttention
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
DTYPE = [torch.float16, torch.bfloat16, torch.float32]