mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix llama flash attention forward (#5777)
parent b45000f839
commit 50b4c8e8cf
@@ -494,7 +494,6 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
        if sp_mode in ["split_gather", "ring"]:
            q_len *= sp_size
            assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
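The hunk only exposes the local context of get_llama_flash_attention_forward. As a rough, self-contained sketch of what the touched lines do under sequence parallelism, the following toy module reproduces the q_len adjustment and divisibility check; the class name, hidden size, and standalone structure are hypothetical stand-ins, not ColossalAI's actual implementation.

import torch
import torch.nn as nn


class TinySelfAttention(nn.Module):
    """Toy attention front-end illustrating the q_len adjustment in the hunk above."""

    def __init__(self, hidden_size=64, sp_mode="split_gather", sp_size=2):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.sp_mode = sp_mode
        self.sp_size = sp_size

    def forward(self, hidden_states):
        bsz, q_len, _ = hidden_states.size()

        # With "split_gather"/"ring" sequence parallelism each rank holds only
        # 1/sp_size of the sequence, so the full length is q_len * sp_size.
        if self.sp_mode in ["split_gather", "ring"]:
            q_len *= self.sp_size
            assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        # A real forward would reshape into heads and call a flash-attention kernel here.
        return query_states, key_states


if __name__ == "__main__":
    attn = TinySelfAttention(sp_size=2)
    q, k = attn(torch.randn(1, 8, 64))  # local q_len=8 -> full length 16, divisible by 4
    print(q.shape, k.shape)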