mirror of https://github.com/InternLM/InternLM
[Fix] Support new version of transformers for web_demo (#786)
parent
ef74b41aca
commit
6920fa080d
|
@ -23,6 +23,8 @@ from typing import Callable, List, Optional
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
import transformers
|
||||||
from transformers.generation.utils import (LogitsProcessorList,
|
from transformers.generation.utils import (LogitsProcessorList,
|
||||||
StoppingCriteriaList)
|
StoppingCriteriaList)
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
@ -125,7 +127,12 @@ def generate_interactive(
|
||||||
stopping_criteria = model._get_stopping_criteria(
|
stopping_criteria = model._get_stopping_criteria(
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
stopping_criteria=stopping_criteria)
|
stopping_criteria=stopping_criteria)
|
||||||
logits_warper = model._get_logits_warper(generation_config)
|
|
||||||
|
if transformers.__version__ >= '4.42.0':
|
||||||
|
logits_warper = model._get_logits_warper(generation_config,
|
||||||
|
device='cuda')
|
||||||
|
else:
|
||||||
|
logits_warper = model._get_logits_warper(generation_config)
|
||||||
|
|
||||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||||
scores = None
|
scores = None
|
||||||
|
|
Loading…
Reference in New Issue