From 73e88a5553235897dc92f10fa9704b531f1e2959 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Thu, 6 Jun 2024 19:09:50 +0800
Subject: [PATCH] [shardformer] fix import (#5788)

---
 colossalai/shardformer/modeling/llama.py | 6 ++++--
 colossalai/shardformer/modeling/qwen2.py | 6 ++++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 2b30074a5..01d10c8dc 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -8,6 +8,10 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.cache_utils import Cache
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -17,8 +21,6 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaForSequenceClassification,
     LlamaModel,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
     apply_rotary_pos_emb,
     repeat_kv,
 )
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 8f8ab25a5..e0aa5fba4 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -9,13 +9,15 @@ from transformers.modeling_outputs import (
 )
 
 try:
+    from transformers.modeling_attn_mask_utils import (
+        _prepare_4d_causal_attention_mask,
+        _prepare_4d_causal_attention_mask_for_sdpa,
+    )
     from transformers.models.qwen2.modeling_qwen2 import (
         Qwen2Attention,
         Qwen2ForCausalLM,
         Qwen2ForSequenceClassification,
         Qwen2Model,
-        _prepare_4d_causal_attention_mask,
-        _prepare_4d_causal_attention_mask_for_sdpa,
         apply_rotary_pos_emb,
         repeat_kv,
     )
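
Note (not part of the patch itself): the import was breaking because newer `transformers` releases no longer re-export these private mask helpers from `transformers.models.llama.modeling_llama`; their canonical home is `transformers.modeling_attn_mask_utils`, which is where the patch now imports them from. A minimal sketch of a version-tolerant import, assuming only that the helper names are unchanged across releases:

    # Prefer the canonical module (available in recent transformers releases);
    # fall back to the old re-export from the Llama model module otherwise.
    try:
        from transformers.modeling_attn_mask_utils import (
            _prepare_4d_causal_attention_mask,
            _prepare_4d_causal_attention_mask_for_sdpa,
        )
    except ImportError:
        from transformers.models.llama.modeling_llama import (
            _prepare_4d_causal_attention_mask,
            _prepare_4d_causal_attention_mask_for_sdpa,
        )

The same pattern is visible in `qwen2.py`, where the patch keeps the new import inside the existing `try:` guard so that environments whose `transformers` lacks Qwen2 support degrade gracefully.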