From 38b792aab2cf6e33f1693489eecbff622dff2c35 Mon Sep 17 00:00:00 2001
From: flybird1111 <1829166702@qq.com>
Date: Fri, 4 Aug 2023 16:28:41 +0800
Subject: [PATCH] [coloattention] fix import error (#4380)

fixed an import error
---
 colossalai/kernel/cuda_native/mha/__init__.py | 3 +++
 tests/test_utils/test_flash_attention.py      | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)
 create mode 100644 colossalai/kernel/cuda_native/mha/__init__.py

diff --git a/colossalai/kernel/cuda_native/mha/__init__.py b/colossalai/kernel/cuda_native/mha/__init__.py
new file mode 100644
index 000000000..21fddd512
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/__init__.py
@@ -0,0 +1,3 @@
+from .mha import ColoAttention
+
+__all__ = ['ColoAttention']
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index d41ccd832..fbcc45265 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -9,7 +9,7 @@ from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
 from colossalai.testing import clear_cache_before_run, parameterize
 
 if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
-    from colossalai.kernel.cuda_native.mha.mha import ColoAttention
+    from colossalai.kernel.cuda_native import ColoAttention
     from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
 
 DTYPE = [torch.float16, torch.bfloat16, torch.float32]
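
Reviewer note (not part of the patch): a minimal sketch of how the fixed import can be smoke-checked. It assumes colossalai/kernel/cuda_native/__init__.py already re-exports ColoAttention from the mha subpackage (that file is not touched by this patch) and that the kernels guarded by HAS_MEM_EFF_ATTN are available, mirroring the guard used in the test file.

    from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN

    if HAS_MEM_EFF_ATTN:
        # New import path used by the updated test; assumes the parent package's
        # __init__.py re-exports ColoAttention (not shown in this patch).
        from colossalai.kernel.cuda_native import ColoAttention
        # Path exported by the __init__.py added in this patch.
        from colossalai.kernel.cuda_native.mha import ColoAttention as _ColoAttention

        # Both spellings should resolve to the same class object.
        assert ColoAttention is _ColoAttention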