diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index d6f10ffaf..2b30074a5 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -494,7 +494,6 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
         if sp_mode in ["split_gather", "ring"]:
             q_len *= sp_size
 
-        assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)