mirror of https://github.com/InternLM/InternLM
importerror
parent
472671688f
commit
e57ca246d9
|
@ -28,7 +28,6 @@ import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.generation.streamers import BaseStreamer
|
|
||||||
from transformers.modeling_outputs import (
|
from transformers.modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
|
@ -42,6 +41,11 @@ from transformers.utils import (
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers.generation.streamers import BaseStreamer
|
||||||
|
except: # noqa # pylint: disable=bare-except
|
||||||
|
BaseStreamer = None
|
||||||
|
|
||||||
from .configuration_internlm import InternLMConfig
|
from .configuration_internlm import InternLMConfig
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
@ -113,6 +117,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
|
||||||
base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
|
base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
|
||||||
device (Any, optional): Running device. Defaults to None.
|
device (Any, optional): Running device. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
||||||
|
|
Loading…
Reference in New Issue