[shardformer] fix llama error when transformers upgraded. (#5055)

* fix-llama

* Update llama.py
flybird11111 2023-11-16 21:34:04 +08:00 committed by GitHub
parent 3e02154710
commit 97cd0cd559
1 changed file with 14 additions and 4 deletions
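Background (a sketch added for context, not part of the original commit): newer transformers releases (around v4.35+; the exact cutoff is an assumption) expose the module-level helper _prepare_4d_causal_attention_mask in modeling_llama and appear to drop the private LlamaModel._prepare_decoder_attention_mask method that the shardformer forward used to call, which is what breaks after an upgrade. The snippet below only illustrates how the new helper behaves on such a release; the tensor sizes are made up for the example.

    # Illustrative only: calling the new helper directly on a transformers release
    # that ships it (assumed ~v4.35+). It expands a 2D padding mask of shape
    # (batch_size, seq_len) into the 4D additive causal mask Llama attention expects.
    import torch
    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask

    batch_size, seq_length, hidden_size = 2, 8, 16  # toy sizes for the example
    attention_mask = torch.ones(batch_size, seq_length, dtype=torch.bool)
    inputs_embeds = torch.randn(batch_size, seq_length, hidden_size)

    causal_mask = _prepare_4d_causal_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length=0
    )
    print(causal_mask.shape)  # torch.Size([2, 1, 8, 8])

On older releases the import raises ImportError, which is exactly the signal the LATEST_VERSION flag in the diff below keys off.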

@@ -1,5 +1,5 @@
 import warnings
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -13,6 +13,11 @@ from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
+try:
+    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
+    LATEST_VERSION = True
+except ImportError:
+    LATEST_VERSION = False
 
 class LlamaPipelineForwards:
     """
@@ -97,9 +102,14 @@ class LlamaPipelineForwards:
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-        )
+        if LATEST_VERSION:
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            )
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            )
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
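A quick way to check which branch a given environment will take (again a sketch, not part of the commit): run the same guarded import standalone.

    # Standalone check of the code path the patched forward will pick.
    import transformers

    try:
        from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask  # noqa: F401
        path = "_prepare_4d_causal_attention_mask (newer transformers)"
    except ImportError:
        path = "self._prepare_decoder_attention_mask (older transformers)"
    print(f"transformers {transformers.__version__} -> {path}")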