diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2823a1d --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +# override cache dir +.cache + +# for ide +.vscode +.idea \ No newline at end of file diff --git a/cli_demo.py b/cli_demo.py index d87f707..5941e2a 100644 --- a/cli_demo.py +++ b/cli_demo.py @@ -1,9 +1,13 @@ import os +import pathlib import platform + from transformers import AutoTokenizer, AutoModel -tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) -model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() +cache_dir = os.path.join(pathlib.Path(__file__).parent.resolve(), ".cache") + +tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, cache_dir=cache_dir) +model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, cache_dir=cache_dir).half().cuda() model = model.eval() os_name = platform.system() diff --git a/web_demo.py b/web_demo.py index 315978e..934236b 100644 --- a/web_demo.py +++ b/web_demo.py @@ -1,8 +1,13 @@ -from transformers import AutoModel, AutoTokenizer -import gradio as gr +import os +import pathlib -tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) -model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() +import gradio as gr +from transformers import AutoModel, AutoTokenizer + +cache_dir = os.path.join(pathlib.Path(__file__).parent.resolve(), ".cache") + +tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, cache_dir=cache_dir) +model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, cache_dir=cache_dir).half().cuda() model = model.eval() MAX_TURNS = 20