mirror of https://github.com/THUDM/ChatGLM-6B
Add interactive parameter options
parent 2aa175710f
commit ff6d7fbeeb
cli_demo.py
@@ -3,9 +3,32 @@ import platform
 import signal
 from transformers import AutoTokenizer, AutoModel
 
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
-model = model.eval()
+import inquirer
+import torch
+# Parameters
+choices_jobType = [("GPU", 1), ("CPU", 2)]
+choices_floatType = [("half", 1), ("float", 2)]
+choices_model = [("Default (chatglm-6b)", 'chatglm-6b'), ("Quantized int4 (chatglm-6b-int4)", 'chatglm-6b-int4')]
+
+def print_list(choices, v):
+    # Map a stored choice value back to its display label.
+    for element in choices:
+        if element[1] == v:
+            return element[0]
+    return None
+
+
+def print_confirm(v):
+    if v:
+        return 'Yes'
+    else:
+        return 'No'
+
+
+def print_confirm2(display, v1, v2=True, v3=True):
+    if v1 and v2 and v3:
+        return display
+    else:
+        return ''
 
 os_name = platform.system()
 clear_command = 'cls' if os_name == 'Windows' else 'clear'
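A note on the choice format introduced above: in python-inquirer, a choice given as a (label, value) tuple displays the label in the menu but stores the value in the answers dict, which is why print_list exists to map a stored value back to its label. A minimal standalone sketch of that behavior (the question name and message mirror the commit's; not part of the commit itself):

import inquirer

questions = [
    inquirer.List(
        "job_type",
        message="Select run type?",
        # (label, value): the label is shown, the value is returned.
        choices=[("GPU", 1), ("CPU", 2)],
    ),
]

answers = inquirer.prompt(questions)
print(answers["job_type"])  # prints 1 or 2, not "GPU"/"CPU"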
@@ -25,7 +48,29 @@ def signal_handler(signal, frame):
     stop_stream = True
 
 
-def main():
+def main(answers):
+    model_name = answers['path'] + answers['model']
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+    # Precision setting
+    if answers['float_type'] == 2:
+        model = model.float()
+    else:
+        model = model.half()
+    # Device setting
+    if answers['job_type'] == 1:
+        if os_name == 'Darwin':
+            model = model.to("mps")
+        else:
+            model = model.cuda()
+
+    model = model.eval()
+
+
+
+
+
     history = []
     global stop_stream
     print("Welcome to the ChatGLM-6B model. Type to chat, enter 'clear' to clear the history, 'stop' to exit the program")
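The precision/device branching inside main() can be exercised without loading ChatGLM at all; a minimal sketch with a placeholder torch module standing in for the real model (the nn.Linear below is illustrative, not from the commit):

import platform
import torch

model = torch.nn.Linear(4, 4)  # placeholder for the real ChatGLM model

# Precision: float on macOS (MPS half precision is fragile), half elsewhere.
model = model.float() if platform.system() == 'Darwin' else model.half()

# Device: prefer CUDA, then MPS on macOS, else stay on CPU.
if torch.cuda.is_available():
    model = model.cuda()
elif platform.system() == 'Darwin' and torch.backends.mps.is_available():
    model = model.to("mps")

model = model.eval()

Note that torch.backends.mps.is_available() requires PyTorch 1.12 or later, which is why the commit also pins torch>=1.10 and recommends a nightly build on macOS.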
@@ -54,4 +99,63 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    isGPUSupport = torch.cuda.is_available() or (torch.backends.mps.is_available() if os_name == 'Darwin' else False)
+    # Set up the questions
+    questions = [
+        inquirer.List(
+            "job_type",
+            message="Select run type?",
+            default=1 if isGPUSupport else 2,
+            choices=choices_jobType,
+            # If a GPU is supported, default to GPU;
+            # otherwise default to CPU and hide the question.
+            ignore=not isGPUSupport,
+        ),
+        inquirer.List(
+            "float_type",
+            message="Select floating-point precision?",
+            # Half precision is error-prone on macOS MPS, so default to float there;
+            # default to half otherwise.
+            default=2 if os_name == 'Darwin' else 1,
+            choices=choices_floatType,
+
+        ),
+        inquirer.Confirm(
+            "isLocal",
+            message="Use a local model?",
+            default=True,
+        ),
+        inquirer.Text(
+            "path",
+            message="Set the model path",
+            # When using a local model, a directory can be set here.
+            default=lambda answer: './models/' if answer['isLocal'] else 'THUDM/',
+            ignore=lambda answer: not answer['isLocal'],
+        ),
+        inquirer.List(
+            "model",
+            message="Select a model?",
+            # On macOS default to the full model and hide the question;
+            # elsewhere default to the quantized int4 model.
+            default='chatglm-6b' if os_name == 'Darwin' else 'chatglm-6b-int4',
+            choices=choices_model,
+            ignore=os_name == 'Darwin',
+        ),
+
+    ]
+
+    # Prompt and report the answers
+    answers = inquirer.prompt(questions)
+
+    print('========= Options =========')
+    print('Run type: %s' % (print_list(choices_jobType, answers['job_type'])))
+    print('Float precision: %s' % (print_list(choices_floatType, answers['float_type'])))
+    print('Local model: %s' % (print_confirm(answers['isLocal'])))
+    print('Model: %s%s' % (answers['path'], answers['model']))
+    if os_name == 'Darwin':
+        print('----Notes-----')
+        print('On macOS, if running on the GPU fails, try the following:')
+        print('1. Install PyTorch-Nightly: pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu')
+        print('2. On "LLVM ERROR: Failed to infer result type(s)." set the precision to float')
+        print('------------------------')
+    main(answers)
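The "path" question above relies on inquirer's dynamic values: default and ignore accept callables that receive the answers collected so far, and a skipped question is recorded with its default. A standalone sketch of that pattern (assuming inquirer's documented behavior of falling back to the default for ignored questions):

import inquirer

questions = [
    inquirer.Confirm("isLocal", message="Use a local model?", default=True),
    inquirer.Text(
        "path",
        message="Set the model path",
        # Callables see the answers gathered so far.
        default=lambda answers: './models/' if answers['isLocal'] else 'THUDM/',
        ignore=lambda answers: not answers['isLocal'],
    ),
]

answers = inquirer.prompt(questions)
# If isLocal is answered No, "path" is skipped and falls back to its
# default, so answers['path'] == 'THUDM/'.
print(answers)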
requirements.txt

@@ -4,3 +4,4 @@ icetk
 cpm_kernels
 torch>=1.10
 gradio
+inquirer