mirror of https://github.com/THUDM/ChatGLM-6B
fixed web_demo to fit chat_stream
parent
17ecc57266
commit
09c619f5a3
13
web_demo.py
13
web_demo.py
|
@ -12,12 +12,19 @@ MAX_BOXES = MAX_TURNS * 2
|
|||
def predict(input, max_length, top_p, temperature, history=None):
|
||||
if history is None:
|
||||
history = []
|
||||
for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
|
||||
flag = True
|
||||
response = ''
|
||||
for delta, seq, history in model.chat_stream(tokenizer, input, history, max_length=max_length, top_p=top_p,
|
||||
temperature=temperature):
|
||||
updates = []
|
||||
for query, response in history:
|
||||
updates.append(gr.update(visible=True, value="用户:" + query))
|
||||
response += delta
|
||||
if flag:
|
||||
updates.append(gr.update(visible=True, value="用户:" + input))
|
||||
updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
|
||||
flag = False
|
||||
else:
|
||||
updates[-2]=gr.update(visible=True, value="用户:" + input)
|
||||
updates[-1]=gr.update(visible=True, value="ChatGLM-6B:" + response)
|
||||
if len(updates) < MAX_BOXES:
|
||||
updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
|
||||
yield [history] + updates
|
||||
|
|
Loading…
Reference in New Issue