From 698c8b98046a7603d9c2842ae494fa4cffa81b5c Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Wed, 21 Aug 2024 03:58:21 +0000
Subject: [PATCH] fix

---
 colossalai/shardformer/modeling/llama.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index f18d757ab..4e8f67407 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -538,6 +538,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
 
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
             key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
             value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
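
Note on the change: in the "all_to_all" sequence-parallel mode, every attention projection (query, key, and value) must be redistributed from a sequence-sharded layout to a head-sharded layout before attention is computed; the patch adds the missing all-to-all for query_states so it matches key_states and value_states. The snippet below is a minimal sketch of that redistribution, not ColossalAI's actual all_to_all_comm implementation: the helper name all_to_all_seq is hypothetical, FP8 communication handling is omitted, and an already-initialized torch.distributed process group is assumed.

    import torch
    import torch.distributed as dist

    def all_to_all_seq(x: torch.Tensor, group: dist.ProcessGroup = None,
                       scatter_dim: int = 2, gather_dim: int = 1) -> torch.Tensor:
        # Split x along scatter_dim, exchange one chunk with every rank in the
        # sequence-parallel group, then concatenate the received chunks along
        # gather_dim. With the defaults, a [bsz, seq_len/sp, hidden] shard becomes
        # [bsz, seq_len, hidden/sp]: full sequence length, hidden (heads) sharded.
        # Assumes both dims are divisible by the group size.
        world_size = dist.get_world_size(group)
        inputs = [t.contiguous() for t in torch.tensor_split(x, world_size, dim=scatter_dim)]
        outputs = [torch.empty_like(inputs[0]) for _ in inputs]
        dist.all_to_all(outputs, inputs, group=group)
        return torch.cat(outputs, dim=gather_dim)

    # Usage mirroring the patched code path (tensors and sp_group are placeholders):
    # query_states = all_to_all_seq(query_states, sp_group)
    # key_states = all_to_all_seq(key_states, sp_group)
    # value_states = all_to_all_seq(value_states, sp_group)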