From 50b4c8e8cfe9ebb326433ea45d6ad76dbf9d4a54 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 5 Jun 2024 10:56:47 +0800
Subject: [PATCH] [hotfix] fix llama flash attention forward (#5777)

---
 colossalai/shardformer/modeling/llama.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index d6f10ffaf..2b30074a5 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -494,7 +494,6 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
 
         if sp_mode in ["split_gather", "ring"]:
             q_len *= sp_size
-            assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
 
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
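
For context, here is a minimal, single-process sketch of the sequence-parallel length bookkeeping this hunk touches. The tensor shapes, module sizes, and the use of torch.nn.functional.scaled_dot_product_attention are illustrative assumptions rather than the actual ColossalAI forward, and the distributed gather/communication steps are omitted.

import torch
import torch.nn.functional as F

# Illustrative sizes only; local_len * sp_size is deliberately not a multiple of 4.
bsz, local_len, hidden, heads = 2, 5, 64, 4
head_dim = hidden // heads
sp_mode, sp_size = "split_gather", 2

# As in the patched code path, the effective sequence length is the local shard
# length scaled by the sequence-parallel world size.
q_len = local_len
if sp_mode in ["split_gather", "ring"]:
    q_len *= sp_size

hidden_states = torch.randn(bsz, q_len, hidden)
q_proj = torch.nn.Linear(hidden, hidden)
k_proj = torch.nn.Linear(hidden, hidden)
v_proj = torch.nn.Linear(hidden, hidden)

def to_heads(x):
    # (bsz, q_len, hidden) -> (bsz, heads, q_len, head_dim)
    return x.view(bsz, q_len, heads, head_dim).transpose(1, 2)

q = to_heads(q_proj(hidden_states))
k = to_heads(k_proj(hidden_states))
v = to_heads(v_proj(hidden_states))

# PyTorch's fused attention accepts q_len = 10 (not a multiple of 4) without error,
# consistent with the patch dropping the hard q_len % 4 assertion.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 10, 16])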