From 97cd0cd559f61de6dc1b4fdf945787f93deed330 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 16 Nov 2023 21:34:04 +0800
Subject: [PATCH] [shardformer] fix llama error when transformers upgraded.
 (#5055)

* fix-llama

* Update llama.py
---
 colossalai/shardformer/modeling/llama.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 0f911be48..4bfef4529 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -13,6 +13,11 @@ from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
+try:
+    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
+    LATEST_VERSION = True
+except ImportError:
+    LATEST_VERSION = False
 
 class LlamaPipelineForwards:
     """
@@ -97,9 +102,14 @@ class LlamaPipelineForwards:
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-        )
+        if LATEST_VERSION:
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            )
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            )
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
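
The dispatch introduced above keys off whether the newer transformers helper can be imported and falls back to the older private method otherwise. A minimal standalone sketch of that pattern follows; build_causal_mask is a hypothetical wrapper name used only for illustration, and the remark about newer releases exporting the helper is an assumption, since the patch itself only checks that the import succeeds.

# Illustrative sketch of the version-gated dispatch applied in the patch above.
# Assumption: newer transformers releases expose _prepare_4d_causal_attention_mask
# at module level in modeling_llama, while older ones only provide the private
# LlamaModel._prepare_decoder_attention_mask method.
try:
    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask

    LATEST_VERSION = True
except ImportError:
    # Older transformers (or transformers missing entirely): use the legacy path.
    LATEST_VERSION = False


def build_causal_mask(model, attention_mask, batch_size, seq_length, inputs_embeds, past_key_values_length):
    """Hypothetical wrapper: build the causal attention mask with whichever helper is available."""
    if LATEST_VERSION:
        # Module-level helper available in newer transformers.
        return _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )
    # Legacy path: the mask construction lives on the model instance as a private method.
    return model._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )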