fix(web): support new version transformers

pull/786/head
MCplayerFromPRC 2024-08-08 23:19:17 +08:00
parent 7f148b3c2c
commit 1de230c055
1 changed files with 4 additions and 3 deletions

View File

@ -127,9 +127,10 @@ def generate_interactive(
stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria)
if transformers.__version__ >= "4.42.0":
logits_warper = model._get_logits_warper(generation_config, device="cuda")
if transformers.__version__ >= '4.42.0':
logits_warper = model._get_logits_warper(generation_config,
device='cuda')
else:
logits_warper = model._get_logits_warper(generation_config)