[Fix] Support new version of transformers for web_demo (#786)

pull/802/head
Chang Cheng 2024-08-08 23:22:31 +08:00 committed by GitHub
parent ef74b41aca
commit 6920fa080d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 8 additions and 1 deletions

View File

@ -23,6 +23,8 @@ from typing import Callable, List, Optional
import streamlit as st import streamlit as st
import torch import torch
from torch import nn from torch import nn
import transformers
from transformers.generation.utils import (LogitsProcessorList, from transformers.generation.utils import (LogitsProcessorList,
StoppingCriteriaList) StoppingCriteriaList)
from transformers.utils import logging from transformers.utils import logging
@ -125,7 +127,12 @@ def generate_interactive(
stopping_criteria = model._get_stopping_criteria( stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config, generation_config=generation_config,
stopping_criteria=stopping_criteria) stopping_criteria=stopping_criteria)
logits_warper = model._get_logits_warper(generation_config)
if transformers.__version__ >= '4.42.0':
logits_warper = model._get_logits_warper(generation_config,
device='cuda')
else:
logits_warper = model._get_logits_warper(generation_config)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
scores = None scores = None