From c81e046c1e1d3bf5737c8de5c4003dedd44bfeea Mon Sep 17 00:00:00 2001
From: xudafeng
Date: Fri, 5 May 2023 00:19:43 +0800
Subject: [PATCH] feat: use int4

---
 api.py      | 4 ++--
 cli_demo.py | 4 ++--
 web_demo.py | 6 +++---
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/api.py b/api.py
index 693c70a..aac78b6 100644
--- a/api.py
+++ b/api.py
@@ -50,7 +50,7 @@ async def create_item(request: Request):
 
 
 if __name__ == '__main__':
-    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
+    model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
     model.eval()
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
diff --git a/cli_demo.py b/cli_demo.py
index da80fff..8fdb308 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -3,8 +3,8 @@
 import platform
 import signal
 from transformers import AutoTokenizer, AutoModel
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
+model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
 model = model.eval()
 
 os_name = platform.system()
diff --git a/web_demo.py b/web_demo.py
index 97ea622..bf24438 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -2,8 +2,8 @@
 from transformers import AutoModel, AutoTokenizer
 import gradio as gr
 import mdtex2html
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
+model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
 model = model.eval()
 
 """Override Chatbot.postprocess"""
@@ -98,4 +98,4 @@ with gr.Blocks() as demo:
     emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
 
 
-demo.queue().launch(share=False, inbrowser=True)
+demo.queue().launch(share=False, inbrowser=False, server_name="0.0.0.0")