mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix import (#5788)
parent
5ead00ffc5
commit
73e88a5553
|
@ -8,6 +8,10 @@ import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_attn_mask_utils import (
|
||||||
|
_prepare_4d_causal_attention_mask,
|
||||||
|
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||||
|
)
|
||||||
from transformers.modeling_outputs import (
|
from transformers.modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
|
@ -17,8 +21,6 @@ from transformers.models.llama.modeling_llama import (
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
LlamaForSequenceClassification,
|
LlamaForSequenceClassification,
|
||||||
LlamaModel,
|
LlamaModel,
|
||||||
_prepare_4d_causal_attention_mask,
|
|
||||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
|
||||||
apply_rotary_pos_emb,
|
apply_rotary_pos_emb,
|
||||||
repeat_kv,
|
repeat_kv,
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,13 +9,15 @@ from transformers.modeling_outputs import (
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from transformers.modeling_attn_mask_utils import (
|
||||||
|
_prepare_4d_causal_attention_mask,
|
||||||
|
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||||
|
)
|
||||||
from transformers.models.qwen2.modeling_qwen2 import (
|
from transformers.models.qwen2.modeling_qwen2 import (
|
||||||
Qwen2Attention,
|
Qwen2Attention,
|
||||||
Qwen2ForCausalLM,
|
Qwen2ForCausalLM,
|
||||||
Qwen2ForSequenceClassification,
|
Qwen2ForSequenceClassification,
|
||||||
Qwen2Model,
|
Qwen2Model,
|
||||||
_prepare_4d_causal_attention_mask,
|
|
||||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
|
||||||
apply_rotary_pos_emb,
|
apply_rotary_pos_emb,
|
||||||
repeat_kv,
|
repeat_kv,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue