importerror

pull/532/head
lijiaxing 2023-12-11 13:53:48 +08:00
parent 472671688f
commit e57ca246d9
1 changed files with 6 additions and 1 deletions

View File

@ -28,7 +28,6 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.generation.streamers import BaseStreamer
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -42,6 +41,11 @@ from transformers.utils import (
replace_return_docstrings,
)
try:
from transformers.generation.streamers import BaseStreamer
except: # noqa # pylint: disable=bare-except
BaseStreamer = None
from .configuration_internlm import InternLMConfig
logger = logging.get_logger(__name__)
@ -113,6 +117,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
device (Any, optional): Running device. Defaults to None.
"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))