[hotfix] fix llama flash attention forward (#5777)

pull/5782/head
flybird11111 2024-06-05 10:56:47 +08:00 committed by GitHub
parent b45000f839
commit 50b4c8e8cf
1 changed file with 0 additions and 1 deletion

@@ -494,7 +494,6 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
         if sp_mode in ["split_gather", "ring"]:
             q_len *= sp_size
-            assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
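
For context, a minimal runnable sketch of the logic around the removed line. This is illustrative only: the toy module name, hidden_size, and the explicit sp_mode/sp_size arguments are assumptions, not ColossalAI's real forward signature. It mirrors the hunk above: under "split_gather" or "ring" sequence parallelism the local token count is scaled by sp_size to recover the full sequence length, and the hotfix drops the assertion that this length be a multiple of 4.

    # Illustrative sketch only, not ColossalAI's actual attention module.
    import torch
    import torch.nn as nn

    class TinyLlamaAttentionSketch(nn.Module):
        def __init__(self, hidden_size: int = 64):
            super().__init__()
            self.q_proj = nn.Linear(hidden_size, hidden_size)
            self.k_proj = nn.Linear(hidden_size, hidden_size)

        def forward(self, hidden_states: torch.Tensor, sp_mode: str, sp_size: int):
            bsz, q_len, _ = hidden_states.size()
            if sp_mode in ["split_gather", "ring"]:
                # Each rank holds only q_len local tokens; the full sequence
                # length is q_len * sp_size. The assert removed by this hotfix
                # required that full length to be a multiple of 4.
                q_len *= sp_size
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            return q_len, query_states, key_states

    # A local chunk of 5 tokens with sp_size=3 gives a full length of 15,
    # which the removed assert would have rejected (15 % 4 != 0) but now runs.
    attn = TinyLlamaAttentionSketch()
    full_len, q, k = attn(torch.randn(2, 5, 64), sp_mode="ring", sp_size=3)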