From ff6d7fbeeb337328304964123699ae434e4465cb Mon Sep 17 00:00:00 2001
From: kingzeus
Date: Sun, 2 Apr 2023 16:10:52 +0800
Subject: [PATCH] Add interactive parameter options
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 cli_demo.py      | 114 ++++++++++++++++++++++++++++++++++++++++++++---
 requirements.txt |   1 +
 2 files changed, 110 insertions(+), 5 deletions(-)

diff --git a/cli_demo.py b/cli_demo.py
index da80fff..d5ccf6e 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -3,9 +3,32 @@ import platform
 import signal
 from transformers import AutoTokenizer, AutoModel
 
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
-model = model.eval()
+import inquirer
+import torch
+
+# Parameter choices: each entry is a (label, value) tuple
+choices_jobType = [("GPU", 1), ("CPU", 2)]
+choices_floatType = [("half", 1), ("float", 2)]
+choices_model = [("Default (chatglm-6b)", 'chatglm-6b'), ("Quantized int4 (chatglm-6b-int4)", 'chatglm-6b-int4')]
+
+def print_list(choices, v):  # label for value v in a (label, value) list
+    for element in choices:
+        if element[1] == v:
+            return element[0]
+    return None
+
+
+def print_confirm(v):
+    if v:
+        return 'yes'
+    else:
+        return 'no'
+
+
+def print_confirm2(display, v1, v2=True, v3=True):  # currently unused helper
+    if v1 and v2 and v3:
+        return display
+    else:
+        return ''
 
 os_name = platform.system()
 clear_command = 'cls' if os_name == 'Windows' else 'clear'
@@ -25,7 +48,29 @@ def signal_handler(signal, frame):
     stop_stream = True
 
 
-def main():
+def main(answers):
+    model_name = answers['path'] + answers['model']
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+    # Precision
+    if answers['float_type'] == 2:
+        model = model.float()
+    else:
+        model = model.half()
+    # Device
+    if answers['job_type'] == 1:
+        if os_name == 'Darwin':
+            model = model.to("mps")
+        else:
+            model = model.cuda()
+
+    model = model.eval()
+
+    # The chat loop below is unchanged; it now only runs after the model
+    # has been configured from the interactive answers.
+
     history = []
     global stop_stream
     print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
@@ -54,4 +99,63 @@
 
 
 if __name__ == "__main__":
-    main()
+    isGPUSupported = torch.cuda.is_available() or (torch.backends.mps.is_available() if os_name == 'Darwin' else False)
+    # Build the interactive questions
+    questions = [
+        inquirer.List(
+            "job_type",
+            message="Select run type?",
+            # Default to GPU when one is available; on CPU-only machines
+            # the question is skipped and CPU is used.
+            default=1 if isGPUSupported else 2,
+            choices=choices_jobType,
+            ignore=not isGPUSupported,
+        ),
+        inquirer.List(
+            "float_type",
+            message="Select floating-point precision?",
+            # half precision is error-prone on macOS MPS, so default to
+            # float there; default to half everywhere else.
+            default=2 if os_name == 'Darwin' else 1,
+            choices=choices_floatType,
+        ),
+        inquirer.Confirm(
+            "isLocal",
+            message="Use a local model?",
+            default=True,
+        ),
+        inquirer.Text(
+            "path",
+            message="Set the model path",
+            # Only asked for local models; the directory can be customized.
+            default=lambda answer: './models/' if answer['isLocal'] else 'THUDM/',
+            ignore=lambda answer: not answer['isLocal'],
+        ),
+        inquirer.List(
+            "model",
+            message="Select a model?",
+            # Default to the int4-quantized model, except on macOS where
+            # the question is skipped and the full model is used.
+            default='chatglm-6b' if os_name == 'Darwin' else 'chatglm-6b-int4',
+            choices=choices_model,
+            ignore=os_name == 'Darwin',
+        ),
+    ]
+
+    # Prompt the user and collect the answers
+    answers = inquirer.prompt(questions)
+
+    print('========= Options =========')
+    print('Run type: %s' % (print_list(choices_jobType, answers['job_type'])))
+    print('Precision: %s' % (print_list(choices_floatType, answers['float_type'])))
+    print('Local model: %s' % (print_confirm(answers['isLocal'])))
+    print('Model: %s%s' % (answers['path'], answers['model']))
+    if os_name == 'Darwin':
+        print('----Notes-----')
+        print('On macOS, if running on the GPU fails, consider:')
+        print('1. Install PyTorch-Nightly: pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu')
+        print('2. On "LLVM ERROR: Failed to infer result type(s)." set the precision to float')
+        print('------------------------')
+    main(answers)
diff --git a/requirements.txt b/requirements.txt
index 00707fe..960e436 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,4 @@ icetk
 cpm_kernels
 torch>=1.10
 gradio
+inquirer
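
Note: below is a minimal, self-contained sketch of the python-inquirer
pattern the patch relies on (assuming only `pip install inquirer`); the
question names and defaults mirror the patch. Choices are (label, value)
tuples, and `default`/`ignore` may be callables that receive the answers
collected so far, which is how the path question is chained to the
isLocal confirmation:

    import inquirer

    questions = [
        # List displays the labels and records the paired value.
        inquirer.List("job_type", message="Select run type?",
                      choices=[("GPU", 1), ("CPU", 2)], default=1),
        inquirer.Confirm("isLocal", message="Use a local model?", default=True),
        inquirer.Text(
            "path",
            message="Set the model path",
            # Skipped when no local model is requested; an ignored
            # question is recorded with its default value.
            default=lambda a: './models/' if a['isLocal'] else 'THUDM/',
            ignore=lambda a: not a['isLocal'],
        ),
    ]
    answers = inquirer.prompt(questions)
    print(answers)  # e.g. {'job_type': 1, 'isLocal': True, 'path': './models/'}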