diff --git a/chat/web_demo.py b/chat/web_demo.py
index cc5f07c..74432f4 100644
--- a/chat/web_demo.py
+++ b/chat/web_demo.py
@@ -125,7 +125,20 @@ def generate_interactive(
     stopping_criteria = model._get_stopping_criteria(
         generation_config=generation_config,
         stopping_criteria=stopping_criteria)
-    logits_warper = model._get_logits_warper(generation_config)
+
+    # Releases from 4.42 onward are called with a ``device`` argument.
+    # Compare the version numerically: a plain string comparison is
+    # lexicographic and would rank e.g. "4.5.0" above "4.42.0".
+    import transformers
+    major_minor = tuple(
+        int(part) for part in transformers.__version__.split(".")[:2])
+    if major_minor >= (4, 42):
+        # Use the device the inputs already live on instead of assuming
+        # CUDA, so CPU-only and other-accelerator runs keep working.
+        logits_warper = model._get_logits_warper(
+            generation_config, device=input_ids.device)
+    else:
+        logits_warper = model._get_logits_warper(generation_config)
 
     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
     scores = None