Remove 4D attention mask

pull/5842/head
wangbluo 2024-06-19 09:28:08 +00:00
parent df612434c9
commit 52ea64824e
1 changed file with 0 additions and 8 deletions

View File

@@ -116,14 +116,6 @@ class LlamaPipelineForwards:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
if LATEST_VERSION:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
else:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
if self.gradient_checkpointing and self.training:
if use_cache: