|
|
|
@ -1,14 +1,15 @@
|
|
|
|
|
import os |
|
|
|
|
import platform |
|
|
|
|
import signal |
|
|
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) |
|
|
|
|
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() |
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("./model", trust_remote_code=True) |
|
|
|
|
model = AutoModel.from_pretrained("./model", trust_remote_code=True).half().cuda() |
|
|
|
|
model = model.eval() |
|
|
|
|
|
|
|
|
|
os_name = platform.system() |
|
|
|
|
clear_command = 'cls' if os_name == 'Windows' else 'clear' |
|
|
|
|
|
|
|
|
|
stop_stream = False |
|
|
|
|
|
|
|
|
|
def build_prompt(history): |
|
|
|
|
prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序" |
|
|
|
@ -17,9 +18,13 @@ def build_prompt(history):
|
|
|
|
|
prompt += f"\n\nChatGLM-6B:{response}" |
|
|
|
|
return prompt |
|
|
|
|
|
|
|
|
|
def signal_handler(signal, frame): |
|
|
|
|
global stop_stream |
|
|
|
|
stop_stream = True |
|
|
|
|
|
|
|
|
|
def main(): |
|
|
|
|
history = [] |
|
|
|
|
global stop_stream |
|
|
|
|
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") |
|
|
|
|
while True: |
|
|
|
|
query = input("\n用户:") |
|
|
|
@ -32,10 +37,15 @@ def main():
|
|
|
|
|
continue |
|
|
|
|
count = 0 |
|
|
|
|
for response, history in model.stream_chat(tokenizer, query, history=history): |
|
|
|
|
count += 1 |
|
|
|
|
if count % 8 == 0: |
|
|
|
|
os.system(clear_command) |
|
|
|
|
print(build_prompt(history), flush=True) |
|
|
|
|
if stop_stream: |
|
|
|
|
stop_stream = False |
|
|
|
|
break |
|
|
|
|
else: |
|
|
|
|
count += 1 |
|
|
|
|
if count % 8 == 0: |
|
|
|
|
os.system(clear_command) |
|
|
|
|
print(build_prompt(history), flush=True) |
|
|
|
|
signal.signal(signal.SIGINT,signal_handler) |
|
|
|
|
os.system(clear_command) |
|
|
|
|
print(build_prompt(history), flush=True) |
|
|
|
|
|
|
|
|
|