ChatGLM-6B/cli_demo.py


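"""Interactive CLI demo for ChatGLM-6B: prompts for the run device, floating-point
precision, and model variant via inquirer, then runs a streaming chat loop in the
terminal."""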
import os
import platform
import signal
from transformers import AutoTokenizer, AutoModel
import inquirer
import torch

# Options: each choice is an inquirer (display label, stored value) pair
choices_jobType = [("GPU", 1), ("CPU", 2)]
choices_floatType = [("half", 1), ("float", 2)]
choices_model = [("Default (chatglm-6b)", 'chatglm-6b'), ("Quantized int4 (chatglm-6b-int4)", 'chatglm-6b-int4')]


def print_list(choices, v):
    for element in choices:
        if element[1] == v:
            return element[0]
    return None


def print_confirm(v):
    # Display label for a Confirm (yes/no) answer
    if v:
        return 'Yes'
    else:
        return 'No'


def print_confirm2(display, v1, v2=True, v3=True):
    if v1 and v2 and v3:
        return display
    else:
        return ''


os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False


def build_prompt(history):
    prompt = "Welcome to the ChatGLM-6B model. Type a message to chat, 'clear' to clear the history, 'stop' to exit the program."
    for query, response in history:
        prompt += f"\n\nUser: {query}"
        prompt += f"\n\nChatGLM-6B: {response}"
    return prompt


def signal_handler(signal, frame):
    # SIGINT (Ctrl-C) sets a flag that stops the current streaming response
    global stop_stream
    stop_stream = True
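
# For reference, a hypothetical `answers` dict as returned by inquirer.prompt
# for the questions defined under __main__ (values are illustrative only):
#   {'job_type': 1, 'float_type': 1, 'isLocal': True,
#    'path': './models/', 'model': 'chatglm-6b-int4'}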


def main(answers):
    model_name = answers['path'] + answers['model']
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
    # Precision setting
    if answers['float_type'] == 2:
        model = model.float()
    else:
        model = model.half()
    # Device setting
    if answers['job_type'] == 1:
        if os_name == 'Darwin':
            model = model.to("mps")
        else:
            model = model.cuda()
    model = model.eval()
    history = []
    global stop_stream
    print("Welcome to the ChatGLM-6B model. Type a message to chat, 'clear' to clear the history, 'stop' to exit the program.")
    while True:
        query = input("\nUser: ")
        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            history = []
            os.system(clear_command)
            print("Welcome to the ChatGLM-6B model. Type a message to chat, 'clear' to clear the history, 'stop' to exit the program.")
            continue
        count = 0
        # stream_chat yields progressively longer (response, history) pairs;
        # redraw the console every 8 chunks to show the growing response
        for response, history in model.stream_chat(tokenizer, query, history=history):
            if stop_stream:
                stop_stream = False
                break
            else:
                count += 1
                if count % 8 == 0:
                    os.system(clear_command)
                    print(build_prompt(history), flush=True)
                    signal.signal(signal.SIGINT, signal_handler)
        os.system(clear_command)
        print(build_prompt(history), flush=True)


if __name__ == "__main__":
    isGPUSupport = torch.cuda.is_available() or (torch.backends.mps.is_available() if os_name == 'Darwin' else False)
    # Question definitions
    questions = [
        inquirer.List(
            "job_type",
            message="Select the run type?",
            default=1 if isGPUSupport else 2,
            choices=choices_jobType,
            # Default to GPU when one is available;
            # otherwise skip the question and fall back to CPU
            ignore=not isGPUSupport,
        ),
        inquirer.List(
            "float_type",
            message="Select the floating-point precision?",
            # Half precision tends to fail on macOS MPS, so default to float there;
            # default to half everywhere else
            default=2 if os_name == 'Darwin' else 1,
            choices=choices_floatType,
        ),
        inquirer.Confirm(
            "isLocal",
            message="Use a local model?",
            default=True,
        ),
        inquirer.Text(
            "path",
            message="Set the model path",
            # A custom directory only applies to local models
            default=lambda answer: './models/' if answer['isLocal'] else 'THUDM/',
            ignore=lambda answer: not answer['isLocal'],
        ),
        inquirer.List(
            "model",
            message="Select the model?",
            # On macOS this question is skipped and the full chatglm-6b model is used
            default='chatglm-6b' if os_name == 'Darwin' else 'chatglm-6b-int4',
            choices=choices_model,
            ignore=os_name == 'Darwin',
        ),
    ]
    # Process the answers
    answers = inquirer.prompt(questions)
    print('========= Options =========')
    print('Run type: %s' % (print_list(choices_jobType, answers['job_type'])))
    print('Float precision: %s' % (print_list(choices_floatType, answers['float_type'])))
    print('Local model: %s' % (print_confirm(answers['isLocal'])))
    print('Model: %s%s' % (answers['path'], answers['model']))
    if os_name == 'Darwin':
        print('----Notes-----')
        print('On macOS, if running on the GPU raises errors, try:')
        print('1. Install PyTorch-Nightly: pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu')
        print('2. On "LLVM ERROR: Failed to infer result type(s).", set the precision to float')
        print('------------------------')
    main(answers)
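
# Usage sketch (assumes standard PyPI installs; ChatGLM's remote code may pull
# in extra requirements such as sentencepiece and cpm_kernels):
#   pip install torch transformers inquirer
#   python cli_demo.py
# With "use a local model" confirmed, weights are read from ./models/<model>;
# otherwise they are downloaded from the Hugging Face hub under THUDM/<model>.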