From 9664b1bc190c57518fd76f4a3740feea3dc38ffd Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Mon, 29 Jul 2024 13:58:27 +0800
Subject: [PATCH] [shardformer] hotfix attn mask (#5945)

---
 colossalai/shardformer/modeling/command.py | 2 +-
 colossalai/shardformer/modeling/llama.py   | 2 +-
 colossalai/shardformer/modeling/mistral.py | 2 +-
 colossalai/shardformer/modeling/qwen2.py   | 8 ++++++--
 4 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py
index 759c8d7b8..5b36fc7db 100644
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -116,7 +116,7 @@ class CommandPipelineForwards:
         # for the other stages, hidden_states is the output of the previous stage
         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
-            mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+            mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
             attention_mask = ColoAttention.prepare_attn_kwargs(
                 mask_shape,
                 hidden_states.dtype,
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 54ff8e321..9ffbca517 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -643,7 +643,7 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=

         # in this case, attention_mask is a dict rather than a tensor
         if shard_config.enable_flash_attention:
-            mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
+            mask_shape = (inputs_embeds.shape[0], 1, seq_len, past_seen_tokens + seq_len)
             attention_mask = ColoAttention.prepare_attn_kwargs(
                 mask_shape,
                 inputs_embeds.dtype,
diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py
index 82e8ef5f9..ec1a8a00a 100644
--- a/colossalai/shardformer/modeling/mistral.py
+++ b/colossalai/shardformer/modeling/mistral.py
@@ -91,7 +91,7 @@ class MistralForwards:

         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
-            mask_shape = (batch_size, 1, seq_length, seq_length)
+            mask_shape = (batch_size, 1, seq_length, seq_length + past_key_values_length)
             attention_mask = ColoAttention.prepare_attn_kwargs(
                 mask_shape,
                 hidden_states.dtype,
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 55822b150..538e96c32 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -136,7 +136,7 @@ class Qwen2PipelineForwards:
         # for the other stages, hidden_states is the output of the previous stage
         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
-            mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+            mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
             attention_mask = ColoAttention.prepare_attn_kwargs(
                 mask_shape,
                 hidden_states.dtype,
@@ -651,6 +651,10 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         seq_length_with_past = seq_length
         past_key_values_length = 0

+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             position_ids = torch.arange(
@@ -668,7 +672,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No

         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
-            mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+            mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
             attention_mask = ColoAttention.prepare_attn_kwargs(
                 mask_shape,
                 hidden_states.dtype,
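
Context for the fix: every hunk applies the same correction. With flash attention and a KV cache, the query axis of the attention mask spans only the newly fed tokens (seq_length / seq_len), while the key/value axis spans cached plus new tokens (seq_length_with_past / past_seen_tokens + seq_len), so the mask must have shape (batch, 1, q_len, kv_len) rather than the square (kv_len, kv_len). Below is a minimal PyTorch sketch of that mask geometry, using a hypothetical build_causal_mask helper purely for illustration; it is not ColoAttention.prepare_attn_kwargs and makes no claim about its implementation.

import torch

def build_causal_mask(batch_size: int, q_len: int, past_len: int) -> torch.Tensor:
    """Boolean mask of shape (batch, 1, q_len, kv_len); True = may attend."""
    kv_len = past_len + q_len
    # Query i corresponds to absolute position past_len + i; it may attend to
    # every cached token and to new tokens up to and including itself.
    q_pos = torch.arange(past_len, kv_len).unsqueeze(-1)  # (q_len, 1)
    kv_pos = torch.arange(kv_len).unsqueeze(0)            # (1, kv_len)
    mask = kv_pos <= q_pos                                # (q_len, kv_len)
    return mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, q_len, kv_len)

mask = build_causal_mask(batch_size=2, q_len=3, past_len=5)
print(mask.shape)  # torch.Size([2, 1, 3, 8]) i.e. (batch, 1, seq_length, seq_length_with_past)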