
remove 4d attention mask

Branch: pull/5842/head
Author: wangbluo, 5 months ago
Parent commit: 52ea64824e
1 changed file with 0 additions and 8 deletions:

colossalai/shardformer/modeling/llama.py
@@ -116,14 +116,6 @@ class LlamaPipelineForwards:
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-            if LATEST_VERSION:
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-                )
-            else:
-                attention_mask = self._prepare_decoder_attention_mask(
-                    attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-                )
         if self.gradient_checkpointing and self.training:
             if use_cache:
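
For reference, a minimal standalone sketch of what the removed LATEST_VERSION branch computed, assuming the _prepare_4d_causal_attention_mask helper from transformers.modeling_attn_mask_utils (available in recent transformers releases). The shapes and values below are illustrative only and are not part of this commit.

# Illustrative sketch (not part of this commit): how the removed branch expanded
# a 2D padding mask into a 4D additive causal mask, assuming the
# _prepare_4d_causal_attention_mask helper from transformers.modeling_attn_mask_utils.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, hidden_size = 2, 5, 8
past_key_values_length = 0
seq_length_with_past = seq_length + past_key_values_length

# 2D mask of shape (batch_size, seq_length_with_past); True means "attend".
attention_mask = torch.ones(batch_size, seq_length_with_past, dtype=torch.bool)
hidden_states = torch.randn(batch_size, seq_length, hidden_size)

# Returns an additive float mask of shape
# (batch_size, 1, seq_length, seq_length_with_past), with large negative
# values at masked (future) positions.
mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
print(mask_4d.shape)  # torch.Size([2, 1, 5, 5])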
