diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 6197be9d1..20870a3c2 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -314,7 +314,7 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
         use_safetensors (bool): whether to use safetensors to save the checkpoint.
     """
     # Move all tensors in the state_dict to CPU before saving to avoid serialization issues
-    state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
+    state_dict_cpu = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state_dict)
 
     if use_safetensors:
         assert is_safetensors_available(), "safetensors is not available."
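
Note: using .data.cpu() rather than .cpu() copies the raw tensor data without
carrying autograd history, so graph-attached parameters serialize cleanly. A
minimal sketch of the pattern, assuming tree_map comes from
torch.utils._pytree and using a toy state dict:

    import torch
    from torch.utils._pytree import tree_map

    device = "cuda" if torch.cuda.is_available() else "cpu"
    state_dict = {"weight": torch.randn(2, 2, device=device, requires_grad=True)}
    # .data.cpu() detaches from the autograd graph and moves the copy to CPU
    state_dict_cpu = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state_dict)
    torch.save(state_dict_cpu, "checkpoint.pt")
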
diff --git a/colossalai/inference/modeling/models/glide_llama.py b/colossalai/inference/modeling/models/glide_llama.py
index 7b25f3e74..013b0f061 100644
--- a/colossalai/inference/modeling/models/glide_llama.py
+++ b/colossalai/inference/modeling/models/glide_llama.py
@@ -6,11 +6,7 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
@@ -137,6 +133,7 @@ def glide_llama_model_forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
@@ -147,57 +144,43 @@ def glide_llama_model_forward(
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # retrieve input_ids and inputs_embeds
-    if input_ids is not None and inputs_embeds is not None:
-        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-    elif input_ids is not None:
-        batch_size, seq_length = input_ids.shape[:2]
-    elif inputs_embeds is not None:
-        batch_size, seq_length = inputs_embeds.shape[:2]
-    else:
-        raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-    past_key_values_length = 0
-    if use_cache:
-        use_legacy_cache = not isinstance(past_key_values, Cache)
-        if use_legacy_cache:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-        past_key_values_length = past_key_values.get_usable_length(seq_length)
-
-    if position_ids is None:
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-        position_ids = torch.arange(
-            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError(
+            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
         )
-        position_ids = position_ids.unsqueeze(0)
+
+    if self.gradient_checkpointing and self.training and use_cache:
+        logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
+        use_cache = False
 
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
 
-    if self._use_flash_attention_2:
-        # 2d mask is passed through the layers
-        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-    elif self._use_sdpa and not output_attentions:
-        # output_attentions=True can not be supported when using SDPA, and we fall back on
-        # the manual implementation that requires a 4D causal mask in all cases.
-        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-        )
-    else:
-        # 4d mask is passed through the layers
-        attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+    past_seen_tokens = 0
+    if use_cache:  # kept for BC (cache positions)
+        if not isinstance(past_key_values, StaticCache):
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            past_seen_tokens = past_key_values.get_seq_length()
+
+    if cache_position is None:
+        if isinstance(past_key_values, StaticCache):
+            raise ValueError("cache_position is a required argument when using StaticCache.")
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
         )
 
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+
     # embed positions
     hidden_states = inputs_embeds
 
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None
-    next_decoder_cache = () if use_cache else None
+    next_decoder_cache = None
 
     for decoder_layer in self.layers:
         if output_hidden_states:
@@ -212,6 +195,7 @@ def glide_llama_model_forward(
             past_key_value=past_key_values,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=cache_position,
         )
 
         hidden_states = layer_outputs[0]
@@ -230,7 +214,9 @@ def glide_llama_model_forward(
 
     next_cache = None
     if use_cache:
-        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = (
+            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
+        )
     if not return_dict:
         return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
     return BaseModelOutputWithPast(
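
Note: glide_llama_model_forward is ported to the Cache API that transformers
4.38+ uses internally. A minimal sketch (toy sequence length, not ColossalAI
code) of the round trip between the legacy tuple format and DynamicCache, plus
the cache_position bookkeeping the new forward relies on:

    import torch
    from transformers.cache_utils import DynamicCache

    legacy = None  # legacy tuple-of-tuples cache, or None on the first step
    cache = DynamicCache.from_legacy_cache(legacy)
    past_seen_tokens = cache.get_seq_length()  # 0 when the cache is empty
    seq_len = 5
    cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len)
    position_ids = cache_position.unsqueeze(0)  # default position ids
    # after running the layers, convert back for callers expecting tuples
    legacy_again = cache.to_legacy_cache()
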
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index c49458dbd..aa75bab11 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -738,7 +738,10 @@ class GPT2PipelineForwards:
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning_once(
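
Note: the argmax/modulo form finds the last non-pad token without negative
indexing, which ONNX export does not handle; like the upstream transformers
code it assumes right padding. Toy check of the arithmetic:

    import torch

    pad_token_id = 0
    input_ids = torch.tensor([[5, 6, 7, 0, 0],   # right-padded row
                              [1, 2, 3, 4, 5]])  # row with no padding
    sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
    sequence_lengths = sequence_lengths % input_ids.shape[-1]
    print(sequence_lengths)  # tensor([2, 4]): index of the last real token per row
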
diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py
index 4f4cec8bc..facd2fcaf 100644
--- a/colossalai/shardformer/modeling/gptj.py
+++ b/colossalai/shardformer/modeling/gptj.py
@@ -32,6 +32,7 @@ def _get_attention_mask(
     hidden_states: torch.Tensor,
     past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
     attention_mask: Optional[torch.FloatTensor],
+    use_flash_attention_2: bool = False,
 ) -> Optional[Union[torch.Tensor, dict]]:
     batch_size, seq_len = hidden_states.shape[:2]
     past_key_values_length = 0
@@ -47,7 +48,7 @@ def _get_attention_mask(
             attention_mask,
             is_causal=True,
         )
-    elif attention_mask is not None:
+    elif use_flash_attention_2 and attention_mask is not None:
         if batch_size <= 0:
             raise ValueError("batch_size has to be defined and > 0")
         attention_mask = attention_mask.view(batch_size, -1)
@@ -162,7 +163,9 @@ class GPTJPipelineForwards:
 
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
+        attention_mask = _get_attention_mask(
+            self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
+        )
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
@@ -419,7 +422,10 @@ class GPTJPipelineForwards:
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning_once(
@@ -712,7 +718,9 @@ def gptj_model_forward_for_flash_attention(shard_config: ShardConfig):
 
         hidden_states = self.drop(hidden_states)
 
-        attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
+        attention_mask = _get_attention_mask(
+            self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
+        )
 
         output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
 
@@ -886,7 +894,9 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
         hidden_states = self.drop(hidden_states)
 
         output_shape = input_shape + (hidden_states.size(-1),)
-        attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
+        attention_mask = _get_attention_mask(
+            self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
+        )
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
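
Note: the new use_flash_attention_2 flag keeps _get_attention_mask from
expanding the mask unless flash-attention-2 is active: FA2 kernels consume the
raw 2D padding mask, while the non-FA2 paths are left to build their own
additive bias. A toy illustration of the two conventions (not the actual
helper):

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0]])  # (batch, seq): what FA2 consumes
    neg = torch.finfo(torch.float32).min
    # eager-style additive mask: 0 where attended, large negative where masked
    mask_4d = (1.0 - attention_mask[:, None, None, :].float()) * neg  # (batch, 1, 1, seq)
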
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 01d10c8dc..f47be48ee 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -7,11 +7,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.cache_utils import Cache
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
+from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -21,7 +17,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaForSequenceClassification,
     LlamaModel,
+    StaticCache,
     apply_rotary_pos_emb,
     repeat_kv,
 )
 from transformers.utils import logging
@@ -55,6 +51,7 @@ class LlamaPipelineForwards:
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
@@ -67,6 +64,11 @@ class LlamaPipelineForwards:
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
+        if use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..."
+            )
+            use_cache = False
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -83,14 +85,24 @@ class LlamaPipelineForwards:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             if inputs_embeds is None:
                 inputs_embeds = self.embed_tokens(input_ids)
+
             hidden_states = inputs_embeds
         else:
             input_shape = hidden_states.shape[:-1]
             batch_size, seq_length = input_shape
             device = hidden_states.device
 
-        seq_length_with_past = seq_length
-        past_key_values_length = 0
+        past_seen_tokens = 0
+        if use_cache:  # kept for BC (cache positions)
+            if not isinstance(past_key_values, StaticCache):
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                past_seen_tokens = past_key_values.get_seq_length()
+        if cache_position is None:
+            if isinstance(past_key_values, StaticCache):
+                raise ValueError("cache_position is a required argument when using StaticCache.")
+            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
+
+        seq_length_with_past = seq_length + past_seen_tokens
 
         # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
         if output_attentions:
@@ -103,18 +115,8 @@ class LlamaPipelineForwards:
             logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False
 
-        if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
-            seq_length_with_past = seq_length_with_past + past_key_values_length
-
         if position_ids is None:
-            position_ids = torch.arange(
-                past_key_values_length,
-                seq_length + past_key_values_length,
-                dtype=torch.long,
-                device=device,
-            )
-            position_ids = position_ids.unsqueeze(0)
+            position_ids = cache_position.unsqueeze(0)
 
         # embed positions, for the first stage, hidden_states is the input embeddings,
         # for the other stages, hidden_states is the output of the previous stage
@@ -129,28 +131,9 @@ class LlamaPipelineForwards:
                 is_causal=True,
             )
         else:
-            if self._use_flash_attention_2:
-                # 2d mask is passed through the layers
-                attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-            elif self._use_sdpa and not output_attentions:
-                # output_attentions=True can not be supported when using SDPA, and we fall back on
-                # the manual implementation that requires a 4D causal mask in all cases.
-                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    inputs_embeds,
-                    past_key_values_length,
-                )
-            else:
-                # 4d mask is passed through the layers
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    hidden_states,
-                    past_key_values_length,
-                )
+            attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
 
-        if self.gradient_checkpointing and self.training:
+        if self.gradient_checkpointing and self.training and use_cache:
             if use_cache:
                 logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -190,6 +173,7 @@ class LlamaPipelineForwards:
                     past_key_values,
                     output_attentions,
                     use_cache,
+                    cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -199,6 +183,7 @@ class LlamaPipelineForwards:
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
+                    cache_position=cache_position,
                 )
 
             hidden_states = layer_outputs[0]
@@ -249,6 +234,7 @@ class LlamaPipelineForwards:
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
@@ -306,6 +292,7 @@ class LlamaPipelineForwards:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
             stage_manager=stage_manager,
             hidden_states=hidden_states,
             stage_index=stage_index,
@@ -368,6 +355,7 @@ class LlamaPipelineForwards:
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
@@ -401,6 +389,7 @@ class LlamaPipelineForwards:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
             stage_manager=stage_manager,
             hidden_states=hidden_states,
             stage_index=stage_index,
@@ -486,6 +475,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if "padding_mask" in kwargs:
@@ -520,13 +510,14 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
+
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -562,6 +553,7 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -572,41 +564,40 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # retrieve input_ids and inputs_embeds
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif input_ids is not None:
-            batch_size, seq_length = input_ids.shape
-        elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape
-        else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
-        seq_length_with_past = seq_length
-        past_key_values_length = 0
-
-        if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
-            seq_length_with_past = seq_length_with_past + past_key_values_length
-
-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length,
-                seq_length + past_key_values_length,
-                dtype=torch.long,
-                device=device,
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+            )
+            use_cache = False
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
+
+        past_seen_tokens = 0
+        if use_cache:  # kept for BC (cache positions)
+            if not isinstance(past_key_values, StaticCache):
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                past_seen_tokens = past_key_values.get_seq_length()
+        if cache_position is None:
+            if isinstance(past_key_values, StaticCache):
+                raise ValueError("cache_position is a required argument when using StaticCache.")
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
         # embed positions
         hidden_states = inputs_embeds
 
         # in this case, attention_mask is a dict rather than a tensor
-        mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+        seq_len_with_past = past_seen_tokens + hidden_states.shape[1]
+        mask_shape = (hidden_states.shape[0], 1, seq_len_with_past, seq_len_with_past)
         attention_mask = ColoAttention.prepare_attn_kwargs(
             mask_shape,
             hidden_states.dtype,
@@ -625,43 +616,38 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
-        next_decoder_cache = () if use_cache else None
+        next_decoder_cache = None
 
-        for idx, decoder_layer in enumerate(self.layers):
+        for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            past_key_value = past_key_values[idx] if past_key_values is not None else None
-
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, past_key_value, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=attention_mask,
                     position_ids=position_ids,
-                    past_key_value=past_key_value,
+                    past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
+                    cache_position=cache_position,
                 )
 
             hidden_states = layer_outputs[0]
 
             if use_cache:
-                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
@@ -672,7 +658,11 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = next_decoder_cache if use_cache else None
+        next_cache = None
+        if use_cache:
+            next_cache = (
+                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
+            )
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
@@ -700,6 +690,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -744,6 +735,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
         )
 
         hidden_states = outputs[0]
@@ -789,6 +781,8 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
 
 
 def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
+    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -797,6 +791,7 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
         # sp: modify sp_len when sequence parallel mode is ring
@@ -835,18 +830,14 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        past_key_value = getattr(self, "past_key_value", past_key_value)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        past_key_value = (key_states, value_states) if use_cache else None
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -854,18 +845,9 @@ def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -903,7 +885,7 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
     logger = logging.get_logger(__name__)
 
     def forward(
-        self,
+        self: LlamaModel,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -913,6 +895,7 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -924,56 +907,13 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
 
         # retrieve input_ids and inputs_embeds
-        if input_ids is not None and inputs_embeds is not None:
+        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif input_ids is not None:
-            batch_size, seq_length = input_ids.shape
-        elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape
-        else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
-        seq_length_with_past = seq_length
-        past_key_values_length = 0
-
-        if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
-            # modify past_key_values_length when using sequence parallel
-            past_key_values_length *= sp_size
-            seq_length_with_past = seq_length_with_past + past_key_values_length
-
-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length,
-                seq_length + past_key_values_length,
-                dtype=torch.long,
-                device=device,
+            raise ValueError(
+                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time, and must specify either one"
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
-        elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
-
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past),
-                dtype=torch.bool,
-                device=inputs_embeds.device,
-            )
-
-        attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length
-        )
-
-        hidden_states = inputs_embeds
-
         if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
             if use_cache:
                 logger.warning_once(
@@ -981,6 +921,29 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
                 )
                 use_cache = False
 
+        past_seen_tokens = 0
+        if use_cache:  # kept for BC (cache positions)
+            if not isinstance(past_key_values, StaticCache):
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                past_seen_tokens = past_key_values.get_seq_length()
+        if cache_position is None:
+            if isinstance(past_key_values, StaticCache):
+                raise ValueError("cache_position is a required argument when using StaticCache.")
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+
+        if sp_mode in ["ring", "split_gather"]:
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+        elif sp_mode == "all_to_all":
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+
+        hidden_states = inputs_embeds
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
@@ -990,14 +953,12 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            past_key_value = past_key_values[idx] if past_key_values is not None else None
-
             if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for past_key_value
-                        return module(*inputs, past_key_value, output_attentions)
+                        return module(*inputs, past_key_value=past_key_values, output_attentions=output_attentions)
 
                     return custom_forward
 
@@ -1013,15 +974,20 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
                     hidden_states,
                     attention_mask=attention_mask,
                     position_ids=position_ids,
-                    past_key_value=past_key_value,
+                    past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
+                    cache_position=cache_position,
                 )
 
             hidden_states = layer_outputs[0]
 
             if use_cache:
-                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
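
Note: the rotary-embedding call sites change with the pinned transformers
4.39.x: LlamaRotaryEmbedding.forward takes position_ids instead of seq_len,
and apply_rotary_pos_emb no longer takes position_ids. A minimal sketch with
toy shapes:

    import torch
    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

    B, H, S, D = 1, 2, 4, 8
    q, k = torch.randn(B, H, S, D), torch.randn(B, H, S, D)
    position_ids = torch.arange(S).unsqueeze(0)          # (B, S)
    rotary = LlamaRotaryEmbedding(D)
    cos, sin = rotary(q, position_ids)                   # 4.36-style was rotary(q, seq_len=S)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
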
diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py
index 5f96ebe3d..310c2d8e2 100644
--- a/colossalai/shardformer/modeling/mistral.py
+++ b/colossalai/shardformer/modeling/mistral.py
@@ -4,7 +4,10 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -77,7 +80,7 @@ class MistralForwards:
         else:
             position_ids = position_ids.view(-1, seq_length).long()
 
-        if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -97,9 +100,18 @@ class MistralForwards:
                 is_causal=True,
             )
         else:
-            if self._use_flash_attention_2:
+            if self._attn_implementation == "flash_attention_2":
                 # 2d mask is passed through the layers
                 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+            elif self._attn_implementation == "sdpa" and not output_attentions:
+                # output_attentions=True can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
             else:
                 # 4d mask is passed through the layers
                 attention_mask = _prepare_4d_causal_attention_mask(
@@ -462,7 +474,7 @@ def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -481,9 +493,18 @@ def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
                 is_causal=True,
             )
         else:
-            if self._use_flash_attention_2:
+            if self._attn_implementation == "flash_attention_2":
                 # 2d mask is passed through the layers
                 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+            elif self._attn_implementation == "sdpa" and not output_attentions:
+                # output_attentions=True can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
             else:
                 # 4d mask is passed through the layers
                 attention_mask = _prepare_4d_causal_attention_mask(
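
Note: the boolean _use_flash_attention_2 attribute is gone in 4.39; models
carry a string-valued _attn_implementation instead, which is what makes the
new sdpa branch above possible. Sketch of the dispatch, using a stand-in
config object:

    from types import SimpleNamespace

    config = SimpleNamespace(_attn_implementation="sdpa")
    attn_impl = getattr(config, "_attn_implementation", "eager")
    if attn_impl == "flash_attention_2":
        pass  # keep the raw 2D padding mask
    elif attn_impl == "sdpa":
        pass  # _prepare_4d_causal_attention_mask_for_sdpa(...)
    else:
        pass  # _prepare_4d_causal_attention_mask(...)
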
diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index 6d7df963a..cf925983b 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -17,6 +17,7 @@ from transformers.modeling_outputs import (
     SequenceClassifierOutput,
 )
 from transformers.models.whisper.modeling_whisper import (
+    _HIDDEN_STATES_START_POSITION,
     WhisperDecoder,
     WhisperEncoder,
     WhisperForAudioClassification,
@@ -166,6 +167,7 @@ def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
         cross_attn_head_mask=None,
         past_key_values=None,
         inputs_embeds=None,
+        position_ids=None,
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -199,9 +201,13 @@ def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
 
         # embed positions
         if input_ids is not None:
-            positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+            positions = self.embed_positions(
+                input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
+            )
         else:
-            positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
+            positions = self.embed_positions(
+                inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
+            )
 
         hidden_states = inputs_embeds + positions
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -599,6 +605,7 @@ class WhisperPipelineForwards:
         cross_attn_head_mask=None,
         past_key_values=None,
         inputs_embeds=None,
+        position_ids=None,
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -716,9 +723,13 @@ class WhisperPipelineForwards:
 
             # embed positions
             if input_ids is not None:
-                positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+                positions = self.embed_positions(
+                    input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
+                )
             else:
-                positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
+                positions = self.embed_positions(
+                    inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
+                )
 
             hidden_states = inputs_embeds + positions
             hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -841,6 +852,7 @@ class WhisperPipelineForwards:
         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -944,6 +956,7 @@ class WhisperPipelineForwards:
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
             inputs_embeds=decoder_inputs_embeds,
+            position_ids=decoder_position_ids,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -986,6 +999,7 @@ class WhisperPipelineForwards:
         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -1048,6 +1062,7 @@ class WhisperPipelineForwards:
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
             decoder_inputs_embeds=decoder_inputs_embeds,
+            decoder_position_ids=decoder_position_ids,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1118,6 +1133,12 @@ class WhisperPipelineForwards:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+
+        if self.config.use_weighted_layer_sum:
+            output_hidden_states = True
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # audio_classification only holds encoder
@@ -1138,7 +1159,8 @@ class WhisperPipelineForwards:
             return encoder_outputs
 
         if self.config.use_weighted_layer_sum:
-            hidden_states = torch.stack(encoder_outputs, dim=1)
+            hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
+            hidden_states = torch.stack(hidden_states, dim=1)
             norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
             hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
         else:
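
Note: with use_weighted_layer_sum, the classifier pools a softmax-weighted sum
over all encoder layer outputs; the fix indexes
encoder_outputs[_HIDDEN_STATES_START_POSITION] to get the tuple of per-layer
hidden states rather than stacking the outputs tuple itself. Toy version of
the pooling:

    import torch

    num_layers, B, S, D = 3, 2, 5, 8
    hidden_states_tuple = tuple(torch.randn(B, S, D) for _ in range(num_layers))
    layer_weights = torch.zeros(num_layers)                      # learned in the real model
    stacked = torch.stack(hidden_states_tuple, dim=1)            # (B, num_layers, S, D)
    norm_weights = torch.softmax(layer_weights, dim=-1)
    pooled = (stacked * norm_weights.view(-1, 1, 1)).sum(dim=1)  # (B, S, D)
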
diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py
index 3315eb1e9..c394d911e 100644
--- a/colossalai/shardformer/policies/gptj.py
+++ b/colossalai/shardformer/policies/gptj.py
@@ -34,15 +34,11 @@ class GPTJPolicy(Policy):
         return self.model
 
     def module_policy(self):
-        from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel
-
-        ATTN_IMPLEMENTATION = {
-            "eager": GPTJAttention,
-        }
+        from transformers.models.gptj.modeling_gptj import GPTJ_ATTENTION_CLASSES, GPTJBlock, GPTJModel
 
         policy = {}
 
-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+        attn_cls = GPTJ_ATTENTION_CLASSES[self.origin_attn_implement]
 
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
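
Note: the policy now reads the registry transformers itself maintains instead
of a hand-rolled dict; in 4.39.x GPTJ_ATTENTION_CLASSES maps "eager" and
"flash_attention_2" to their attention classes (the exact keys vary by
version):

    from transformers.models.gptj.modeling_gptj import GPTJ_ATTENTION_CLASSES

    attn_cls = GPTJ_ATTENTION_CLASSES["eager"]
    print(attn_cls.__name__)  # GPTJAttention
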
diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py
index 621982f29..c5a0277a5 100644
--- a/colossalai/shardformer/policies/mistral.py
+++ b/colossalai/shardformer/policies/mistral.py
@@ -42,11 +42,13 @@ class MistralPolicy(Policy):
             MistralDecoderLayer,
             MistralFlashAttention2,
             MistralModel,
+            MistralSdpaAttention,
         )
 
         ATTN_IMPLEMENTATION = {
             "eager": MistralAttention,
             "flash_attention_2": MistralFlashAttention2,
+            "sdpa": MistralSdpaAttention,
         }
 
         policy = {}
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index fa88501ef..27bbc3769 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -16,7 +16,7 @@ ray
 sentencepiece
 google
 protobuf
-transformers>=4.36.2,<4.40.0
+transformers==4.39.3
 peft>=0.7.1
 bitsandbytes>=0.39.0
 rpyc==6.0.0
diff --git a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
index 8237384c0..57a82647d 100644
--- a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
@@ -28,15 +28,22 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
     torch.manual_seed(10)
     TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
     # our crafted op equals to Transformers
-    x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
-    x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
+    x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
+    x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
+
+    position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
+
     emb = LlamaRotaryEmbedding(D)
-    cos, sin = emb(x0, TOTAL_TOKENS)
+
+    cos, sin = emb(x0, position_ids)
+    embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
+    cos = cos.reshape((TOTAL_TOKENS, -1))
+    sin = sin.reshape((TOTAL_TOKENS, -1))
     cos_2 = cos[:, : D // 2]
     sin_2 = sin[:, : D // 2]
-    position_ids = torch.arange(TOTAL_TOKENS)
-    embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
-    embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
+    x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D)
+    embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2)
+    embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2)
     assert torch.allclose(embd_x0, embd_stimulated_x)
 
     # create data
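
Note: both rotary tests (this one and the Triton variant below) now feed the
HF layout (batch, heads, seq, head_dim) to apply_rotary_pos_emb, then flatten
to (total_tokens, heads, head_dim) for the fused-kernel reference. Round trip
between the two layouts:

    import torch

    B, H, S, D = 2, 4, 3, 8
    x = torch.randn(B, H, S, D)
    flat = x.transpose(1, 2).reshape(B * S, H, D)    # kernel layout
    back = flat.reshape(B, S, H, D).transpose(1, 2)  # HF attention layout again
    assert torch.equal(x, back)
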
diff --git a/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
index 570093693..78b7ba81c 100644
--- a/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
@@ -43,15 +43,19 @@ def torch_rotary_emb(x, cos, sin):
 def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
     TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
     # our crafted op equals to Transformers
-    x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
-    x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
+    x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
+    x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
     emb = LlamaRotaryEmbedding(D)
-    cos, sin = emb(x0, TOTAL_TOKENS)
+    position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
+    cos, sin = emb(x0, position_ids)
+    embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
+    cos = cos.reshape((TOTAL_TOKENS, -1))
+    sin = sin.reshape((TOTAL_TOKENS, -1))
     cos_2 = cos[:, :32]
     sin_2 = sin[:, :32]
-    position_ids = torch.arange(TOTAL_TOKENS)
-    embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
-    embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
+    x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D)
+    embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2)
+    embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2)
     assert torch.allclose(embd_x0, embd_stimulated_x)
 
     # create data