fix(web): support new version transformers

pull/786/head
MCplayerFromPRC 2024-08-08 22:45:45 +08:00
parent ef74b41aca
commit 20af9072be
1 changed files with 6 additions and 1 deletions

View File

@ -125,6 +125,11 @@ def generate_interactive(
stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria)
import transformers
if transformers.__version__ >= "4.42.0":
logits_warper = model._get_logits_warper(generation_config, device="cuda")
else:
logits_warper = model._get_logits_warper(generation_config)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)