diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index d6f10ffaf..2b30074a5 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -494,7 +494,6 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
 
         if sp_mode in ["split_gather", "ring"]:
             q_len *= sp_size
-        assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
 
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)