diff --git a/chat/web_demo.py b/chat/web_demo.py index 74432f4..82873b8 100644 --- a/chat/web_demo.py +++ b/chat/web_demo.py @@ -23,6 +23,8 @@ from typing import Callable, List, Optional import streamlit as st import torch from torch import nn + +import transformers from transformers.generation.utils import (LogitsProcessorList, StoppingCriteriaList) from transformers.utils import logging @@ -126,7 +128,6 @@ def generate_interactive( generation_config=generation_config, stopping_criteria=stopping_criteria) - import transformers if transformers.__version__ >= "4.42.0": logits_warper = model._get_logits_warper(generation_config, device="cuda") else: