diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py
index 2b86763f1..7bd646d39 100644
--- a/colossalai/kernel/cuda_native/flash_attention.py
+++ b/colossalai/kernel/cuda_native/flash_attention.py
@@ -48,6 +48,13 @@ except ImportError:
     HAS_FLASH_ATTN = False
     print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
 
+try:
+    from xformers.ops.fmha import memory_efficient_attention
+    HAS_MEM_EFF_ATTN = True
+except ImportError:
+    HAS_MEM_EFF_ATTN = False
+    print('please install xformers from https://github.com/facebookresearch/xformers')
+
 if HAS_TRITON:
 
     @triton.jit
@@ -497,3 +504,25 @@ if HAS_FLASH_ATTN:
                                      device=k.device)
         return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale,
                                         causal)
+
+
+if HAS_MEM_EFF_ATTN:
+
+    from einops import rearrange
+    from xformers.ops.fmha import LowerTriangularMask
+
+    class MemoryEfficientAttention(torch.nn.Module):
+
+        def __init__(self, hidden_size: int, num_attention_heads: int, attention_dropout: float = 0.0):
+            super().__init__()
+            attention_head_size = hidden_size // num_attention_heads
+            self.scale = 1 / attention_head_size**0.5
+            self.dropout = attention_dropout
+
+        def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor):
+            # query/key/value are expected in (batch, seq_len, num_heads, head_size) layout,
+            # which is the layout xformers' memory_efficient_attention consumes directly
+            context = memory_efficient_attention(query, key, value, attention_mask, self.dropout, self.scale)
+            # merge the head dimension back into the hidden dimension: (b, s, h, d) -> (b, s, h * d)
+            context = rearrange(context, 'b s h d -> b s (h d)')
+            return context
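
A minimal usage sketch of the MemoryEfficientAttention wrapper above, assuming xformers is installed and a CUDA device is available; the rearrange-based head split, the tensor sizes, and the zero dropout below are illustrative assumptions. The wrapper expects query/key/value already split into heads as (batch, seq_len, num_heads, head_size) and merges the heads back into hidden_size on output:

    import torch
    from einops import rearrange

    from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN

    if HAS_MEM_EFF_ATTN:
        from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention

        # illustrative sizes, chosen only for this sketch
        batch, seq_len, num_heads, head_size = 2, 128, 8, 64
        hidden_size = num_heads * head_size

        attn = MemoryEfficientAttention(hidden_size, num_heads, attention_dropout=0.0)

        # stand-ins for the usual q/k/v linear projections, shape (batch, seq_len, hidden_size)
        q = torch.randn(batch, seq_len, hidden_size, dtype=torch.float16, device='cuda')
        k = torch.randn(batch, seq_len, hidden_size, dtype=torch.float16, device='cuda')
        v = torch.randn(batch, seq_len, hidden_size, dtype=torch.float16, device='cuda')

        # split hidden_size into heads: (b, s, h * d) -> (b, s, h, d)
        q, k, v = (rearrange(t, 'b s (h d) -> b s h d', h=num_heads) for t in (q, k, v))

        # causal attention via xformers' LowerTriangularMask; output shape is (batch, seq_len, hidden_size)
        out = attn(q, k, v, attention_mask=LowerTriangularMask())
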
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index 9d2ee8a18..58e3b21d9 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -2,7 +2,7 @@ import pytest
 import torch
 from einops import rearrange
 
-from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON
+from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN, HAS_TRITON
 
 if HAS_FLASH_ATTN:
     from colossalai.kernel.cuda_native.flash_attention import (
@@ -15,6 +15,9 @@ if HAS_FLASH_ATTN:
 if HAS_TRITON:
     from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
 
+if HAS_MEM_EFF_ATTN:
+    from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention
+
 
 def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
@@ -124,5 +127,20 @@ def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     out.backward(dout)
 
 
+@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 8, 4, 16)])
+def test_memory_efficient_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
+    attn = MemoryEfficientAttention(H * D_HEAD, H, 0.1)
+
+    q = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+    k = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+    v = torch.empty((Z, N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+
+    out = attn(q, k, v, attention_mask=LowerTriangularMask())
+
+    dout = torch.rand_like(out)
+    out.backward(dout)
+
+
 if __name__ == '__main__':
     test_flash_attention(3, 4, 2, 16)