
[hotfix] fix llama flash attention forward (#5777)

pull/5782/head
flybird11111 authored 6 months ago, committed by GitHub
commit 50b4c8e8cf
1 changed file: colossalai/shardformer/modeling/llama.py

@@ -494,7 +494,6 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
     if sp_mode in ["split_gather", "ring"]:
         q_len *= sp_size
-    assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
     query_states = self.q_proj(hidden_states)
     key_states = self.k_proj(hidden_states)
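
For context, the q_len adjustment in this hunk exists because under "split_gather" or "ring" sequence parallelism each rank holds only seq_len / sp_size tokens, so the local length is scaled back to the global sequence length; the removed assert additionally required that global length to be a multiple of 4, a restriction this hotfix drops. A minimal sketch of that bookkeeping, assuming a hypothetical helper name and shapes that are not part of ColossalAI's API:

import torch

def global_q_len(hidden_states: torch.Tensor, sp_mode: str, sp_size: int) -> int:
    # Each rank holds only seq_len / sp_size tokens under "split_gather" or
    # "ring" sequence parallelism, so the local length is scaled back up to
    # the full sequence length. (Hypothetical helper, not ColossalAI code.)
    bsz, q_len, _ = hidden_states.size()
    if sp_mode in ["split_gather", "ring"]:
        q_len *= sp_size
    # The hotfix removes the old `assert q_len % 4 == 0` restriction here,
    # so sequences whose length is not a multiple of 4 no longer error out.
    return q_len

# Example: a 3-token shard on each of 4 ranks describes a 12-token sequence.
shard = torch.randn(2, 3, 64)   # (batch, local_seq_len, hidden_size)
print(global_q_len(shard, sp_mode="split_gather", sp_size=4))  # -> 12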
