mirror of https://github.com/THUDM/ChatGLM-6B
Merge ff6d7fbeeb
into 801b1bb576
commit
7a5ba16878
|
@ -0,0 +1 @@
|
||||||
|
venv
|
|
@ -0,0 +1,6 @@
|
||||||
|
[submodule "models/chatglm-6b-int4"]
|
||||||
|
path = models/chatglm-6b-int4
|
||||||
|
url = https://huggingface.co/THUDM/chatglm-6b-int4
|
||||||
|
[submodule "models/chatglm-6b"]
|
||||||
|
path = models/chatglm-6b
|
||||||
|
url = https://huggingface.co/THUDM/chatglm-6b
|
114
cli_demo.py
114
cli_demo.py
|
@ -3,9 +3,32 @@ import platform
|
||||||
import signal
|
import signal
|
||||||
from transformers import AutoTokenizer, AutoModel
|
from transformers import AutoTokenizer, AutoModel
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
|
import inquirer
|
||||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
|
import torch
|
||||||
model = model.eval()
|
# 参数
|
||||||
|
choices_jobType = [("GPU", 1), ("CPU", 2)]
|
||||||
|
choices_floatType = [("half", 1), ("float", 2)]
|
||||||
|
choices_model = [("默认(chatglm-6b)", 'chatglm-6b'), ("量化int4(chatglm-6b-int4)", 'chatglm-6b-int4')]
|
||||||
|
|
||||||
|
def print_list(choices, v):
|
||||||
|
for element in choices:
|
||||||
|
if element[1] == v:
|
||||||
|
return element[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def print_confirm(v):
|
||||||
|
if v:
|
||||||
|
return '是'
|
||||||
|
else:
|
||||||
|
return '否'
|
||||||
|
|
||||||
|
|
||||||
|
def print_confirm2(display, v1, v2=True, v3=True):
|
||||||
|
if v1 and v2 and v3:
|
||||||
|
return display
|
||||||
|
else:
|
||||||
|
return ''
|
||||||
|
|
||||||
os_name = platform.system()
|
os_name = platform.system()
|
||||||
clear_command = 'cls' if os_name == 'Windows' else 'clear'
|
clear_command = 'cls' if os_name == 'Windows' else 'clear'
|
||||||
|
@ -25,7 +48,29 @@ def signal_handler(signal, frame):
|
||||||
stop_stream = True
|
stop_stream = True
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main(answers):
|
||||||
|
model_name = answers['path'] + answers['model']
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
# 精度设置
|
||||||
|
if answers['float_type'] == 2:
|
||||||
|
model = model.float()
|
||||||
|
else:
|
||||||
|
model = model.half()
|
||||||
|
# 设备设置
|
||||||
|
if answers['job_type'] == 1:
|
||||||
|
if os_name == 'Darwin':
|
||||||
|
model = model.to("mps")
|
||||||
|
else:
|
||||||
|
model = model.cuda()
|
||||||
|
|
||||||
|
model = model.eval()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
history = []
|
history = []
|
||||||
global stop_stream
|
global stop_stream
|
||||||
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
|
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
|
||||||
|
@ -54,4 +99,63 @@ def main():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
isGPUSport = torch.cuda.is_available() or (torch.backends.mps.is_available() if os_name == 'Darwin' else False)
|
||||||
|
# 设置选项
|
||||||
|
questions = [
|
||||||
|
inquirer.List(
|
||||||
|
"job_type",
|
||||||
|
message="选择运行类型?",
|
||||||
|
default=1 if isGPUSport else 2,
|
||||||
|
choices=choices_jobType,
|
||||||
|
# 如果支持GPU,默认GPU
|
||||||
|
# 如果不支持GPU的话,默认CPU,不显示
|
||||||
|
ignore= not isGPUSport,
|
||||||
|
),
|
||||||
|
inquirer.List(
|
||||||
|
"float_type",
|
||||||
|
message="选择浮点精度?",
|
||||||
|
# mac mps半精度容易报错,默认float
|
||||||
|
# 默认使用half
|
||||||
|
default=2 if os_name == 'Darwin' else 1,
|
||||||
|
choices=choices_floatType,
|
||||||
|
|
||||||
|
),
|
||||||
|
inquirer.Confirm(
|
||||||
|
"isLocal",
|
||||||
|
message="是否使用本地模型",
|
||||||
|
default=True,
|
||||||
|
),
|
||||||
|
inquirer.Text(
|
||||||
|
"path",
|
||||||
|
message="设置模型路径",
|
||||||
|
# 使用本地模型的话,可以设置目录
|
||||||
|
default=lambda answer: './models/' if answer['isLocal'] else 'THUDM/',
|
||||||
|
ignore=lambda answer: not answer['isLocal'],
|
||||||
|
),
|
||||||
|
inquirer.List(
|
||||||
|
"model",
|
||||||
|
message="选择模型?",
|
||||||
|
# mac mps半精度容易报错,默认float
|
||||||
|
# 默认使用half
|
||||||
|
default='chatglm-6b' if os_name == 'Darwin' else 'chatglm-6b-int4',
|
||||||
|
choices=choices_model,
|
||||||
|
ignore=os_name == 'Darwin',
|
||||||
|
),
|
||||||
|
|
||||||
|
]
|
||||||
|
|
||||||
|
# 处理选项
|
||||||
|
answers = inquirer.prompt(questions)
|
||||||
|
|
||||||
|
print('========= 选项 =========')
|
||||||
|
print('运行类型: %s' % (print_list(choices_jobType, answers['job_type'])))
|
||||||
|
print('浮点精度: %s' % (print_list(choices_floatType, answers['float_type'])))
|
||||||
|
print('本地模型: %s' % (print_confirm(answers['isLocal'])))
|
||||||
|
print('模型: %s%s' % (answers['path'], answers['model']))
|
||||||
|
if os_name == 'Darwin':
|
||||||
|
print('----说明-----')
|
||||||
|
print('MacOS下,如果使用GPU报错的话,建议:')
|
||||||
|
print('1.安装 PyTorch-Nightly:pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu')
|
||||||
|
print('2.出现 LLVM ERROR: Failed to infer result type(s). 可以把精度设置为float')
|
||||||
|
print('------------------------')
|
||||||
|
main(answers)
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 08bc85104db4e8da2c215a29c469218953056251
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 7458231b5ac7f19cc49496c35617a4ea66f0533e
|
|
@ -4,3 +4,4 @@ icetk
|
||||||
cpm_kernels
|
cpm_kernels
|
||||||
torch>=1.10
|
torch>=1.10
|
||||||
gradio
|
gradio
|
||||||
|
inquirer
|
||||||
|
|
Loading…
Reference in New Issue