mirror of https://github.com/THUDM/ChatGLM-6B
Add support for streaming output
parent
6cda36633e
commit
2ed89f3898
28
cli_demo.py
28
cli_demo.py
|
@@ -7,7 +7,18 @@ model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).ha
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
|
|
||||||
os_name = platform.system()
|
os_name = platform.system()
|
||||||
|
clear_command = 'cls' if os_name == 'Windows' else 'clear'
|
||||||
|
|
||||||
|
|
||||||
|
def build_prompt(history):
|
||||||
|
prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
|
||||||
|
for query, response in history:
|
||||||
|
prompt += f"\n用户:{query}"
|
||||||
|
prompt += f"\nChatGLM-6B:{response}"
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
history = []
|
history = []
|
||||||
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
|
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
|
||||||
while True:
|
while True:
|
||||||
|
@@ -16,9 +27,18 @@ while True:
|
||||||
break
|
break
|
||||||
if query == "clear":
|
if query == "clear":
|
||||||
history = []
|
history = []
|
||||||
command = 'cls' if os_name == 'Windows' else 'clear'
|
os.system(clear_command)
|
||||||
os.system(command)
|
|
||||||
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
|
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
|
||||||
continue
|
continue
|
||||||
response, history = model.chat(tokenizer, query, history=history)
|
count = 0
|
||||||
print(f"ChatGLM-6B:{response}")
|
for response, history in model.stream_chat(tokenizer, query, history=history):
|
||||||
|
count += 1
|
||||||
|
if count % 8 == 0:
|
||||||
|
os.system(clear_command)
|
||||||
|
print(build_prompt(history), flush=True)
|
||||||
|
os.system(clear_command)
|
||||||
|
print(build_prompt(history), flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
|
@@ -12,15 +12,15 @@ MAX_BOXES = MAX_TURNS * 2
|
||||||
def predict(input, max_length, top_p, temperature, history=None):
|
def predict(input, max_length, top_p, temperature, history=None):
|
||||||
if history is None:
|
if history is None:
|
||||||
history = []
|
history = []
|
||||||
response, history = model.chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
|
for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
|
||||||
temperature=temperature)
|
temperature=temperature):
|
||||||
updates = []
|
updates = []
|
||||||
for query, response in history:
|
for query, response in history:
|
||||||
updates.append(gr.update(visible=True, value="用户:" + query))
|
updates.append(gr.update(visible=True, value="用户:" + query))
|
||||||
updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
|
updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
|
||||||
if len(updates) < MAX_BOXES:
|
if len(updates) < MAX_BOXES:
|
||||||
updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
|
updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
|
||||||
return [history] + updates
|
yield [history] + updates
|
||||||
|
|
||||||
|
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
|
|
Loading…
Reference in New Issue