From 6920fa080dc9673cf615766f8fe91bc51902f61a Mon Sep 17 00:00:00 2001
From: Chang Cheng <1953414760@qq.com>
Date: Thu, 8 Aug 2024 23:22:31 +0800
Subject: [PATCH] [Fix] Support new version of transformers for web_demo (#786)

---
 chat/web_demo.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/chat/web_demo.py b/chat/web_demo.py
index cc5f07c..2be2a7c 100644
--- a/chat/web_demo.py
+++ b/chat/web_demo.py
@@ -23,6 +23,8 @@ from typing import Callable, List, Optional
 import streamlit as st
 import torch
 from torch import nn
+
+import transformers
 from transformers.generation.utils import (LogitsProcessorList,
                                            StoppingCriteriaList)
 from transformers.utils import logging
@@ -125,7 +127,12 @@ def generate_interactive(
     stopping_criteria = model._get_stopping_criteria(
         generation_config=generation_config,
         stopping_criteria=stopping_criteria)
-    logits_warper = model._get_logits_warper(generation_config)
+
+    if transformers.__version__ >= '4.42.0':
+        logits_warper = model._get_logits_warper(generation_config,
+                                                 device='cuda')
+    else:
+        logits_warper = model._get_logits_warper(generation_config)
     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
     scores = None
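
Note on the guard in this patch: comparing `transformers.__version__` as a raw string is lexicographic, so for example `'4.9.0' >= '4.42.0'` evaluates to `True` even though 4.9 is an older release than 4.42. A minimal sketch of a version-safe variant, assuming the `packaging` library is available (transformers itself depends on it); the helper name `get_logits_warper` is hypothetical and not part of the patch:

```python
# Sketch only: a version-safe form of the check introduced in this patch.
# Assumes `packaging` is installed; transformers pulls it in as a dependency.
import transformers
from packaging import version


def get_logits_warper(model, generation_config):
    """Call _get_logits_warper with the signature expected by the installed
    transformers release (4.42.0 and later take an extra `device` argument)."""
    if version.parse(transformers.__version__) >= version.parse('4.42.0'):
        return model._get_logits_warper(generation_config, device='cuda')
    return model._get_logits_warper(generation_config)
```

`version.parse` compares release numbers numerically component by component, so the branch selection stays correct across future transformers releases such as 4.100.x.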