fix(web): support new version transformers

pull/786/head
MCplayerFromPRC 2024-08-08 23:04:44 +08:00
parent 20af9072be
commit 7f148b3c2c
1 changed files with 2 additions and 1 deletions

View File

@ -23,6 +23,8 @@ from typing import Callable, List, Optional
import streamlit as st import streamlit as st
import torch import torch
from torch import nn from torch import nn
import transformers
from transformers.generation.utils import (LogitsProcessorList, from transformers.generation.utils import (LogitsProcessorList,
StoppingCriteriaList) StoppingCriteriaList)
from transformers.utils import logging from transformers.utils import logging
@ -126,7 +128,6 @@ def generate_interactive(
generation_config=generation_config, generation_config=generation_config,
stopping_criteria=stopping_criteria) stopping_criteria=stopping_criteria)
import transformers
if transformers.__version__ >= "4.42.0": if transformers.__version__ >= "4.42.0":
logits_warper = model._get_logits_warper(generation_config, device="cuda") logits_warper = model._get_logits_warper(generation_config, device="cuda")
else: else: