diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..5ceb386
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+venv
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..0093bb4
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,6 @@
+[submodule "models/chatglm-6b-int4"]
+	path = models/chatglm-6b-int4
+	url = https://huggingface.co/THUDM/chatglm-6b-int4
+[submodule "models/chatglm-6b"]
+	path = models/chatglm-6b
+	url = https://huggingface.co/THUDM/chatglm-6b
diff --git a/cli_demo.py b/cli_demo.py
index da80fff..d5ccf6e 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -3,9 +3,32 @@ import platform
 import signal
 from transformers import AutoTokenizer, AutoModel
 
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
-model = model.eval()
+import inquirer
+import torch
+# Prompt options: each choice is a (display label, value) pair
+choices_jobType = [("GPU", 1), ("CPU", 2)]
+choices_floatType = [("half", 1), ("float", 2)]
+choices_model = [("Default (chatglm-6b)", 'chatglm-6b'), ("Quantized int4 (chatglm-6b-int4)", 'chatglm-6b-int4')]
+
+def print_list(choices, v):
+    for element in choices:
+        if element[1] == v:
+            return element[0]
+    return None
+
+
+def print_confirm(v):
+    if v:
+        return 'Yes'
+    else:
+        return 'No'
+
+
+def print_confirm2(display, v1, v2=True, v3=True):
+    if v1 and v2 and v3:
+        return display
+    else:
+        return ''
 
 os_name = platform.system()
 clear_command = 'cls' if os_name == 'Windows' else 'clear'
@@ -25,7 +48,29 @@ def signal_handler(signal, frame):
     stop_stream = True
 
 
-def main():
+def main(answers):
+    model_name = answers['path'] + answers['model']
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+    # Precision setting
+    if answers['float_type'] == 2:
+        model = model.float()
+    else:
+        model = model.half()
+    # Device setting
+    if answers['job_type'] == 1:
+        if os_name == 'Darwin':
+            model = model.to("mps")
+        else:
+            model = model.cuda()
+
+    model = model.eval()
+
+
+
+
+
     history = []
     global stop_stream
     print("Welcome to the ChatGLM-6B model. Enter text to chat, 'clear' to clear the conversation history, 'stop' to quit the program")
@@ -54,4 +99,63 @@
 
 
 if __name__ == "__main__":
-    main()
+    isGPUSupport = torch.cuda.is_available() or (torch.backends.mps.is_available() if os_name == 'Darwin' else False)
+    # Define the prompt questions
+    questions = [
+        inquirer.List(
+            "job_type",
+            message="Select run type?",
+            default=1 if isGPUSupport else 2,
+            choices=choices_jobType,
+            # default to GPU when one is available;
+            # without GPU support, default to CPU and skip this question
+            ignore=not isGPUSupport,
+        ),
+        inquirer.List(
+            "float_type",
+            message="Select floating-point precision?",
+            # half precision is error-prone on macOS MPS, so default to float there;
+            # default to half everywhere else
+            default=2 if os_name == 'Darwin' else 1,
+            choices=choices_floatType,
+
+        ),
+        inquirer.Confirm(
+            "isLocal",
+            message="Use a local model?",
+            default=True,
+        ),
+        inquirer.Text(
+            "path",
+            message="Set the model path",
+            # when using a local model, the directory can be set here
+            default=lambda answer: './models/' if answer['isLocal'] else 'THUDM/',
+            ignore=lambda answer: not answer['isLocal'],
+        ),
+        inquirer.List(
+            "model",
+            message="Select model?",
+            # the int4 quantized model is unreliable on macOS, so force the
+            # full chatglm-6b there and skip this question
+            default='chatglm-6b' if os_name == 'Darwin' else 'chatglm-6b-int4',
+            choices=choices_model,
+            ignore=os_name == 'Darwin',
+        ),
+
+    ]
+
+    # Run the prompt and collect the answers
+    answers = inquirer.prompt(questions)
+
+    print('========= Options =========')
+    print('Run type: %s' % (print_list(choices_jobType, answers['job_type'])))
+    print('Float precision: %s' % (print_list(choices_floatType, answers['float_type'])))
+    print('Local model: %s' % (print_confirm(answers['isLocal'])))
+    print('Model: %s%s' % (answers['path'], answers['model']))
+    if os_name == 'Darwin':
+        print('----Notes-----')
+        print('On macOS, if running on the GPU fails, it is recommended to:')
+        print('1. Install PyTorch-Nightly: pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu')
+        print('2. On "LLVM ERROR: Failed to infer result type(s).", set the precision to float')
+        print('------------------------')
+    main(answers)
diff --git a/models/chatglm-6b b/models/chatglm-6b
new file mode 160000
index 0000000..08bc851
--- /dev/null
+++ b/models/chatglm-6b
@@ -0,0 +1 @@
+Subproject commit 08bc85104db4e8da2c215a29c469218953056251
diff --git a/models/chatglm-6b-int4 b/models/chatglm-6b-int4
new file mode 160000
index 0000000..7458231
--- /dev/null
+++ b/models/chatglm-6b-int4
@@ -0,0 +1 @@
+Subproject commit 7458231b5ac7f19cc49496c35617a4ea66f0533e
diff --git a/requirements.txt b/requirements.txt
index 00707fe..960e436 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,4 @@ icetk
 cpm_kernels
 torch>=1.10
 gradio
+inquirer
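
Note: a minimal standalone sketch (not part of the patch) of the python-inquirer pattern cli_demo.py relies on above, where default and ignore may be callables that receive the answers gathered so far, letting later questions depend on earlier ones; the question names here are illustrative only.

    import inquirer

    questions = [
        inquirer.Confirm("isLocal", message="Use a local model?", default=True),
        inquirer.Text(
            "path",
            message="Set the model path",
            # inquirer resolves callable default/ignore against the answers
            # collected so far, so this question adapts to isLocal above
            default=lambda answers: './models/' if answers['isLocal'] else 'THUDM/',
            ignore=lambda answers: not answers['isLocal'],
        ),
    ]
    answers = inquirer.prompt(questions)  # dict keyed by question name
    print(answers)                        # e.g. {'isLocal': True, 'path': './models/'}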