[Fix] Support new version of transformers for web_demo (#786)

pull/802/head
Chang Cheng 2024-08-08 23:22:31 +08:00 committed by GitHub
parent ef74b41aca
commit 6920fa080d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 8 additions and 1 deletions

View File

@ -23,6 +23,8 @@ from typing import Callable, List, Optional
import streamlit as st
import torch
from torch import nn
import transformers
from transformers.generation.utils import (LogitsProcessorList,
StoppingCriteriaList)
from transformers.utils import logging
@ -125,7 +127,12 @@ def generate_interactive(
stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria)
logits_warper = model._get_logits_warper(generation_config)
if transformers.__version__ >= '4.42.0':
logits_warper = model._get_logits_warper(generation_config,
device='cuda')
else:
logits_warper = model._get_logits_warper(generation_config)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
scores = None