@@ -8,6 +8,10 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.cache_utils import Cache
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -17,8 +21,6 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaForSequenceClassification,
     LlamaModel,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
     apply_rotary_pos_emb,
     repeat_kv,
 )