From d6bd5be4cd59b08cc38acb4ac72b91f462835505 Mon Sep 17 00:00:00 2001
From: yuchuanliu
Date: Wed, 15 Mar 2023 15:20:51 +0800
Subject: [PATCH] feat: add `.gitignore` and `cache_dir`

---
 .gitignore  |  6 ++++++
 cli_demo.py |  8 ++++++--
 web_demo.py | 13 +++++++++----
 3 files changed, 21 insertions(+), 6 deletions(-)
 create mode 100644 .gitignore

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2823a1d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+# override cache dir
+.cache
+
+# for ide
+.vscode
+.idea
\ No newline at end of file
diff --git a/cli_demo.py b/cli_demo.py
index d87f707..5941e2a 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -1,9 +1,13 @@
 import os
+import pathlib
 import platform
+
 from transformers import AutoTokenizer, AutoModel
 
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+cache_dir = os.path.join(pathlib.Path(__file__).parent.resolve(), ".cache")
+
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, cache_dir=cache_dir)
+model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, cache_dir=cache_dir).half().cuda()
 model = model.eval()
 
 os_name = platform.system()
diff --git a/web_demo.py b/web_demo.py
index 315978e..934236b 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -1,8 +1,13 @@
-from transformers import AutoModel, AutoTokenizer
-import gradio as gr
+import os
+import pathlib
 
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+import gradio as gr
+from transformers import AutoModel, AutoTokenizer
+
+cache_dir = os.path.join(pathlib.Path(__file__).parent.resolve(), ".cache")
+
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, cache_dir=cache_dir)
+model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, cache_dir=cache_dir).half().cuda()
 model = model.eval()
 
 MAX_TURNS = 20