From e57ca246d991906ed6f3ba7269dde088830bbee2 Mon Sep 17 00:00:00 2001
From: lijiaxing
Date: Mon, 11 Dec 2023 13:53:48 +0800
Subject: [PATCH] importerror

---
 tools/transformers/internlm_model/modeling_internlm.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tools/transformers/internlm_model/modeling_internlm.py b/tools/transformers/internlm_model/modeling_internlm.py
index e2d52ed..9ea7f17 100644
--- a/tools/transformers/internlm_model/modeling_internlm.py
+++ b/tools/transformers/internlm_model/modeling_internlm.py
@@ -28,7 +28,6 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
-from transformers.generation.streamers import BaseStreamer
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -42,6 +41,11 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 
+try:
+    from transformers.generation.streamers import BaseStreamer
+except:  # noqa # pylint: disable=bare-except
+    BaseStreamer = None
+
 from .configuration_internlm import InternLMConfig
 
 logger = logging.get_logger(__name__)
@@ -113,6 +117,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
         base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
         device (Any, optional): Running device. Defaults to None.
     """
+
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
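
Note (not part of the patch): the change above is a standard compatibility guard. `transformers.generation.streamers.BaseStreamer` only exists in newer `transformers` releases, so importing it unconditionally raises ImportError on older versions; the patch instead falls back to `BaseStreamer = None` and defers the failure to call sites. A minimal sketch of how downstream code might consume that fallback is below; the helper name and signature are illustrative assumptions, not taken from the patch, and ImportError is used in place of the patch's bare `except:`.

    from typing import Any, Optional

    try:
        # Present in newer transformers releases; missing in older ones.
        from transformers.generation.streamers import BaseStreamer
    except ImportError:  # narrower than the patch's bare except
        BaseStreamer = None  # sentinel: streaming support unavailable


    def check_streamer(streamer: Optional[Any]) -> None:
        # Illustrative helper (hypothetical, not in the patch): only
        # type-check the streamer when BaseStreamer could be imported.
        if streamer is not None and BaseStreamer is not None:
            if not isinstance(streamer, BaseStreamer):
                raise TypeError("streamer must subclass BaseStreamer")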