pull/6023/head
wangbluo 3 months ago
parent 6aface9316
commit 698c8b9804

@ -538,6 +538,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()

Loading…
Cancel
Save