mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix llama error when transformers upgraded. (#5055)
* fix-llama
* Update llama.py

pull/5060/head
parent 3e02154710
commit 97cd0cd559
@@ -1,5 +1,5 @@
 import warnings
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -13,6 +13,11 @@ from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager

+try:
+    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
+    LATEST_VERSION = True
+except ImportError:
+    LATEST_VERSION = False

 class LlamaPipelineForwards:
     """
@@ -97,9 +102,14 @@ class LlamaPipelineForwards:
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-        )
+        if LATEST_VERSION:
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            )
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            )

         if self.gradient_checkpointing and self.training:
             if use_cache:
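For context, a minimal self-contained sketch of the compatibility pattern this diff introduces, assuming only that newer transformers releases expose `_prepare_4d_causal_attention_mask` at module level in `modeling_llama`, while older releases still provide the model-level `_prepare_decoder_attention_mask` method. The `build_causal_mask` wrapper below is a hypothetical helper added for illustration; it is not part of the commit.

# Sketch only (not part of the commit): detect the helper at import time,
# then branch at call time, mirroring the change inside LlamaPipelineForwards.
try:
    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask

    LATEST_VERSION = True
except ImportError:
    LATEST_VERSION = False


def build_causal_mask(model, attention_mask, input_shape, inputs_embeds, past_key_values_length):
    """Hypothetical wrapper: dispatch to whichever mask helper the installed transformers provides."""
    if LATEST_VERSION:
        # Newer transformers: module-level helper imported above.
        return _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )
    # Older transformers: the helper still exists as a method on the model.
    return model._prepare_decoder_attention_mask(
        attention_mask, input_shape, inputs_embeds, past_key_values_length
    )

Guarding the import once and branching on a module-level flag keeps the forward path free of repeated version checks, which is why the diff sets LATEST_VERSION at import time rather than probing inside the loop.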