diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index eb421c92b..efcf37987 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -116,14 +116,6 @@ class LlamaPipelineForwards:
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        if LATEST_VERSION:
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-            )
-        else:
-            attention_mask = self._prepare_decoder_attention_mask(
-                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-            )
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
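
For context, the deleted block gated the causal-mask construction on the installed transformers version: newer releases expose a standalone `_prepare_4d_causal_attention_mask` helper, while older ones only provide the `_prepare_decoder_attention_mask` method on the model. Below is a minimal sketch of that version-gated logic, kept only for reference; the helper name `build_causal_mask` and the import-based `LATEST_VERSION` detection are illustrative assumptions, not part of the patch.

```python
# Sketch of the version-gated mask preparation the removed block performed.
# Assumes recent transformers releases expose _prepare_4d_causal_attention_mask
# in transformers.modeling_attn_mask_utils; older releases fall back to the
# model's own _prepare_decoder_attention_mask method.
try:
    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

    LATEST_VERSION = True
except ImportError:
    LATEST_VERSION = False


def build_causal_mask(model, attention_mask, batch_size, seq_length, hidden_states, past_key_values_length):
    """Expand a 2D padding mask into the 4D causal mask expected by LLaMA attention."""
    if LATEST_VERSION:
        # Newer transformers: standalone helper builds the (batch, 1, q_len, kv_len) mask.
        return _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )
    # Older transformers: equivalent logic lives on the decoder model itself.
    return model._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
    )
```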