From 193030f696f941b5ad36bb37cc4763f5f36dbc2e Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Wed, 21 Aug 2024 03:21:49 +0000
Subject: [PATCH] fix

---
 colossalai/shardformer/layer/_operation.py |  6 ++--
 colossalai/shardformer/modeling/llama.py   | 32 +++++++---------------
 2 files changed, 14 insertions(+), 24 deletions(-)

diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index bb403224f..f58a21df3 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -1097,11 +1097,13 @@ def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8
     return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
 
 
-def gather_sp_output(hidden_states, sp_group, sp_mode):
+def gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=False):
     """
     Gather the output of the last layer for cross entropy computation
     """
     # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
     scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
-    hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale)
+    hidden_states = gather_forward_split_backward(
+        hidden_states, 1, sp_group, grad_scale=scale, fp8_communication=fp8_communication
+    )
     return hidden_states
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 3f05d0428..f18d757ab 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -26,11 +26,7 @@ from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import AttnMaskType
-from colossalai.shardformer.layer._operation import (
-    all_to_all_comm,
-    gather_forward_split_backward,
-    split_forward_gather_backward,
-)
+from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig
 
@@ -235,12 +231,8 @@ class LlamaPipelineForwards:
         if stage_manager.is_last_stage():
             hidden_states = self.norm(hidden_states)
             if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
-                hidden_states = gather_forward_split_backward(
-                    hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
-                )
-            elif sp_mode == "all_to_all":
-                hidden_states = gather_forward_split_backward(
-                    hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+                hidden_states = gather_sp_output(
+                    hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
                 )
 
         # add hidden states from the last decoder layer
@@ -546,7 +538,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
 
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
             key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
             value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
@@ -565,6 +556,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -683,7 +675,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
         past_seen_tokens = 0
         seq_len = inputs_embeds.shape[1]
-        inputs_embeds.shape[0]
+        batch_size = inputs_embeds.shape[0]
         if use_cache:  # kept for BC (cache positions)
             if not isinstance(past_key_values, StaticCache):
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
@@ -697,7 +689,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
             position_ids = cache_position.unsqueeze(0)
 
         if shard_config.enable_flash_attention:
-            mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
+            mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
             attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
                 mask_shape,
                 inputs_embeds.dtype,
@@ -771,14 +763,10 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
                 all_self_attns += (layer_outputs[1],)
 
         hidden_states = self.norm(hidden_states)
-
-        if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(
-                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
-            )
-        elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(
-                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+        # Cases that don't support parallelizing cross entropy computation along sequence
+        if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
+            hidden_states = gather_sp_output(
+                hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
            )
 
         # add hidden states from the last decoder layer
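
Usage illustration (not part of the patch): a minimal sketch of how callers are expected to invoke the consolidated gather_sp_output helper after this change, assuming an initialized sequence-parallel process group and a ShardConfig exposing parallel_output and fp8_communication. The wrapper name maybe_gather_last_layer_output is hypothetical; the condition and call mirror the ones the patch introduces in llama.py.

# Hypothetical helper for illustration only; gather_sp_output and is_share_sp_tp
# are the real ColossalAI functions touched by this patch.
from colossalai.shardformer.layer._operation import gather_sp_output
from colossalai.shardformer.layer.utils import is_share_sp_tp


def maybe_gather_last_layer_output(hidden_states, sp_group, sp_mode, shard_config, force_sp_output_gather=True):
    # Gather along the sequence dim only when cross entropy cannot be computed on
    # sequence-sharded outputs; gather_sp_output handles grad rescaling and forwards
    # fp8_communication to gather_forward_split_backward.
    if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
        hidden_states = gather_sp_output(
            hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
        )
    return hidden_states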