From 20af9072be2c43de3dd0d61dff51e00b436dd663 Mon Sep 17 00:00:00 2001 From: MCplayerFromPRC <1953414760@qq.com> Date: Thu, 8 Aug 2024 22:45:45 +0800 Subject: [PATCH] fix(web): support new version transformers --- chat/web_demo.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/chat/web_demo.py b/chat/web_demo.py index cc5f07c..74432f4 100644 --- a/chat/web_demo.py +++ b/chat/web_demo.py @@ -125,7 +125,12 @@ def generate_interactive( stopping_criteria = model._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria) - logits_warper = model._get_logits_warper(generation_config) + + import transformers + if transformers.__version__ >= "4.42.0": + logits_warper = model._get_logits_warper(generation_config, device="cuda") + else: + logits_warper = model._get_logits_warper(generation_config) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) scores = None