增加交互式参数选项

pull/151/head
kingzeus 2023-04-02 16:10:52 +08:00
parent 2aa175710f
commit ff6d7fbeeb
2 changed files with 110 additions and 5 deletions

View File

@ -3,9 +3,32 @@ import platform
import signal
from transformers import AutoTokenizer, AutoModel
# Load the "THUDM/chatglm-6b" tokenizer and model at import time; the model is
# cast to half precision, moved to CUDA, and then put through .eval().
# NOTE(review): main(answers) further down loads its own tokenizer/model from
# the interactive answers, so this module-level load looks like a leftover of
# the pre-change version -- confirm, and drop it if so.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()
import inquirer
import torch
# Option tables for the interactive prompts.
# Each entry is a (display label, value) pair; print_list() maps a value back
# to its label.
choices_jobType = [
    ("GPU", 1),
    ("CPU", 2),
]
choices_floatType = [
    ("half", 1),
    ("float", 2),
]
choices_model = [
    ("默认(chatglm-6b)", 'chatglm-6b'),
    ("量化int4(chatglm-6b-int4)", 'chatglm-6b-int4'),
]
def print_list(choices, v):
    """Return the display label of the first choice whose value equals ``v``.

    ``choices`` is an iterable of indexable pairs where index 0 is the label
    and index 1 is the value. Returns ``None`` when no pair matches.
    """
    matching_labels = (pair[0] for pair in choices if pair[1] == v)
    return next(matching_labels, None)
def print_confirm(v):
    """Return a short printable label for a yes/no answer.

    Args:
        v: The answer to render; evaluated for truthiness.

    Returns:
        '是' (yes) when ``v`` is truthy, otherwise '否' (no).
    """
    # The two outcomes must render differently, otherwise the option summary
    # cannot tell the reader which answer was chosen.
    return '是' if v else '否'
def print_confirm2(display, v1, v2=True, v3=True):
    """Return ``display`` only when every flag is truthy, else an empty string.

    ``v2`` and ``v3`` default to ``True`` so callers may pass one, two or
    three conditions.
    """
    if not all((v1, v2, v3)):
        return ''
    return display
# Host operating system name as reported by platform.system().
os_name = platform.system()

# Screen-clearing command name: 'cls' on Windows, 'clear' everywhere else.
if os_name == 'Windows':
    clear_command = 'cls'
else:
    clear_command = 'clear'
@ -25,7 +48,29 @@ def signal_handler(signal, frame):
stop_stream = True
def main():
def main(answers):
model_name = answers['path'] + answers['model']
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
# 精度设置
if answers['float_type'] == 2:
model = model.float()
else:
model = model.half()
# 设备设置
if answers['job_type'] == 1:
if os_name == 'Darwin':
model = model.to("mps")
else:
model = model.cuda()
model = model.eval()
history = []
global stop_stream
print("欢迎使用 ChatGLM-6B 模型输入内容即可进行对话clear 清空对话历史stop 终止程序")
@ -54,4 +99,63 @@ def main():
if __name__ == "__main__":
    # Script entry point: ask the user how to run the model, echo the chosen
    # options, then start the chat loop via main(answers).

    # torch.backends.mps is missing on older PyTorch builds, so probe for it
    # instead of assuming the attribute exists. MPS is only considered on macOS.
    mps_backend = getattr(torch.backends, 'mps', None)
    is_mps_usable = (
        os_name == 'Darwin'
        and mps_backend is not None
        and mps_backend.is_available()
    )
    isGPUSport = torch.cuda.is_available() or is_mps_usable

    # Prompt definitions.
    questions = [
        inquirer.List(
            "job_type",
            message="选择运行类型?",
            # Default to GPU when one is usable, otherwise to CPU.
            default=1 if isGPUSport else 2,
            choices=choices_jobType,
            # Without a usable GPU the question is skipped entirely.
            ignore=not isGPUSport,
        ),
        inquirer.List(
            "float_type",
            message="选择浮点精度?",
            # Half precision tends to fail on macOS MPS, so default to float
            # there; everywhere else default to half.
            default=2 if os_name == 'Darwin' else 1,
            choices=choices_floatType,
        ),
        inquirer.Confirm(
            "isLocal",
            message="是否使用本地模型",
            default=True,
        ),
        inquirer.Text(
            "path",
            message="设置模型路径",
            # A local model defaults to the ./models/ directory; otherwise the
            # 'THUDM/' prefix is used. The question is only shown for local
            # models.
            default=lambda answer: './models/' if answer['isLocal'] else 'THUDM/',
            ignore=lambda answer: not answer['isLocal'],
        ),
        inquirer.List(
            "model",
            message="选择模型?",
            # On macOS the question is skipped and chatglm-6b is the default;
            # elsewhere the int4-quantized model is the preselected default.
            default='chatglm-6b' if os_name == 'Darwin' else 'chatglm-6b-int4',
            choices=choices_model,
            ignore=os_name == 'Darwin',
        ),
    ]

    # Collect the answers. prompt() yields None when the user cancels, in
    # which case there is nothing to run and indexing the result would fail.
    answers = inquirer.prompt(questions)
    if answers is None:
        raise SystemExit(1)

    # Echo the selected options.
    print('========= 选项 =========')
    print('运行类型: %s' % (print_list(choices_jobType, answers['job_type'])))
    print('浮点精度: %s' % (print_list(choices_floatType, answers['float_type'])))
    print('本地模型: %s' % (print_confirm(answers['isLocal'])))
    print('模型: %s%s' % (answers['path'], answers['model']))
    if os_name == 'Darwin':
        # macOS-specific troubleshooting hints for GPU (MPS) failures.
        print('----说明-----')
        print('MacOS下如果使用GPU报错的话建议')
        print('1.安装 PyTorch-Nightlypip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu')
        print('2.出现 LLVM ERROR: Failed to infer result type(s). 可以把精度设置为float')
    print('------------------------')

    main(answers)

View File

@ -4,3 +4,4 @@ icetk
cpm_kernels
torch>=1.10
gradio
inquirer