diff --git a/web_demo.py b/web_demo.py
index 39709a7..02cd0aa 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -12,12 +12,13 @@ MAX_BOXES = MAX_TURNS * 2
 def predict(input, max_length, top_p, temperature, history=None):
     if history is None:
         history = []
-    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+    # chat_stream yields incremental deltas; accumulate them into the full response.
+    response = ''
+    for delta, seq, history in model.chat_stream(tokenizer, input, history, max_length=max_length, top_p=top_p,
                                                temperature=temperature):
         updates = []
-        for query, response in history:
-            updates.append(gr.update(visible=True, value="用户:" + query))
+        response += delta
+        updates.append(gr.update(visible=True, value="用户:" + input))
         updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
         if len(updates) < MAX_BOXES:
             updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
         yield [history] + updates