diff --git a/chat/web_demo.py b/chat/web_demo.py index 82873b8..2be2a7c 100644 --- a/chat/web_demo.py +++ b/chat/web_demo.py @@ -127,9 +127,10 @@ def generate_interactive( stopping_criteria = model._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria) - - if transformers.__version__ >= "4.42.0": - logits_warper = model._get_logits_warper(generation_config, device="cuda") + + if transformers.__version__ >= '4.42.0': + logits_warper = model._get_logits_warper(generation_config, + device='cuda') else: logits_warper = model._get_logits_warper(generation_config)