[shardformer] fix llama error when transformers upgraded. (#5055)

* fix-llama

* Update llama.py
pull/5060/head
flybird11111 2023-11-16 21:34:04 +08:00 committed by GitHub
parent 3e02154710
commit 97cd0cd559
1 changed file with 14 additions and 4 deletions


@@ -1,5 +1,5 @@
 import warnings
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -13,6 +13,11 @@ from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
+try:
+    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
+    LATEST_VERSION = True
+except ImportError:
+    LATEST_VERSION = False
 
 class LlamaPipelineForwards:
     """
@@ -97,9 +102,14 @@ class LlamaPipelineForwards:
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-        )
+        if LATEST_VERSION:
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            )
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            )
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
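
For readers who want the compatibility pattern in isolation, below is a minimal standalone sketch of the version guard introduced by this commit. The `build_causal_mask` wrapper name is hypothetical and used only for illustration; the commit itself inlines the same branch directly inside LlamaPipelineForwards.

# Minimal standalone sketch (not part of the commit) of the version guard above.
try:
    # Newer transformers releases expose the mask helper as a module-level
    # function instead of a method on LlamaModel.
    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask

    LATEST_VERSION = True
except ImportError:
    # Older transformers releases: fall back to the legacy model method below.
    LATEST_VERSION = False


def build_causal_mask(model, attention_mask, batch_size, seq_length, hidden_states, past_key_values_length):
    # Hypothetical wrapper for illustration only; dispatches to whichever mask
    # helper the installed transformers version provides.
    if LATEST_VERSION:
        return _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )
    return model._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
    )

Guarding on ImportError rather than parsing transformers.__version__ ties the check to the actual API surface, so any release that already ships _prepare_4d_causal_attention_mask takes the new code path automatically.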