mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix import (#5788)
parent 5ead00ffc5
commit 73e88a5553
@@ -8,6 +8,10 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.cache_utils import Cache
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -17,8 +21,6 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaForSequenceClassification,
     LlamaModel,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
     apply_rotary_pos_emb,
     repeat_kv,
 )
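
Both helpers are defined in transformers.modeling_attn_mask_utils; older transformers releases merely re-exported them through modeling_llama, and newer releases dropped that re-export, breaking the old import. The two hunks above therefore pull the helpers from their defining module instead. A minimal sketch of an alternative, version-tolerant guard (not part of this commit; the fallback branch and version behavior are assumptions):

try:
    # Defining module; present on recent transformers releases.
    from transformers.modeling_attn_mask_utils import (
        _prepare_4d_causal_attention_mask,
        _prepare_4d_causal_attention_mask_for_sdpa,
    )
except ImportError:
    # Some older releases re-exported the helpers from the llama module.
    from transformers.models.llama.modeling_llama import (
        _prepare_4d_causal_attention_mask,
        _prepare_4d_causal_attention_mask_for_sdpa,
    )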
@@ -9,13 +9,15 @@ from transformers.modeling_outputs import (
 )

 try:
+    from transformers.modeling_attn_mask_utils import (
+        _prepare_4d_causal_attention_mask,
+        _prepare_4d_causal_attention_mask_for_sdpa,
+    )
     from transformers.models.qwen2.modeling_qwen2 import (
         Qwen2Attention,
         Qwen2ForCausalLM,
         Qwen2ForSequenceClassification,
         Qwen2Model,
-        _prepare_4d_causal_attention_mask,
-        _prepare_4d_causal_attention_mask_for_sdpa,
         apply_rotary_pos_emb,
         repeat_kv,
     )
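
This qwen2 hunk makes the same switch, but inside the existing try block that gates Qwen2 support on the installed transformers version. The diff cuts off before the matching except branch, so the sketch below fills it in with an assumed fallback; the HAS_QWEN2 flag name is hypothetical, not taken from the commit:

# Sketch of the surrounding guard; the except branch is an assumption.
try:
    from transformers.modeling_attn_mask_utils import (
        _prepare_4d_causal_attention_mask,
        _prepare_4d_causal_attention_mask_for_sdpa,
    )
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

    HAS_QWEN2 = True  # hypothetical flag, checked before enabling Qwen2 policies
except ImportError:
    # Installed transformers predates Qwen2 or the mask utilities.
    HAS_QWEN2 = False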