feat: use the int4-quantized model (THUDM/chatglm-6b-int4)

pull/922/head
xudafeng 2023-05-05 00:19:43 +08:00
parent 614211d928
commit c81e046c1e
3 changed files with 7 additions and 7 deletions

4
api.py
View File

@ -50,7 +50,7 @@ async def create_item(request: Request):
if __name__ == '__main__': if __name__ == '__main__':
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
model.eval() model.eval()
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1) uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)

View File

@ -3,8 +3,8 @@ import platform
import signal import signal
from transformers import AutoTokenizer, AutoModel from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
model = model.eval() model = model.eval()
os_name = platform.system() os_name = platform.system()

View File

@ -2,8 +2,8 @@ from transformers import AutoModel, AutoTokenizer
import gradio as gr import gradio as gr
import mdtex2html import mdtex2html
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
model = model.eval() model = model.eval()
"""Override Chatbot.postprocess""" """Override Chatbot.postprocess"""
@ -98,4 +98,4 @@ with gr.Blocks() as demo:
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
demo.queue().launch(share=False, inbrowser=True) demo.queue().launch(share=False, inbrowser=False, server_name="0.0.0.0")