[shardformer] fix import (#5788)

pull/5790/head
Hongxin Liu 6 months ago committed by GitHub
parent 5ead00ffc5
commit 73e88a5553
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -8,6 +8,10 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
@ -17,8 +21,6 @@ from transformers.models.llama.modeling_llama import (
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaModel,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
    apply_rotary_pos_emb,
    repeat_kv,
)

@ -9,13 +9,15 @@ from transformers.modeling_outputs import (
)
try:
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
    from transformers.models.qwen2.modeling_qwen2 import (
        Qwen2Attention,
        Qwen2ForCausalLM,
        Qwen2ForSequenceClassification,
        Qwen2Model,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
        apply_rotary_pos_emb,
        repeat_kv,
    )

Loading…
Cancel
Save