pull/241/head
saber 2023-03-26 15:29:15 +08:00
parent d9c45f0286
commit 8101d75ab8
7 changed files with 19 additions and 17 deletions

View File

@ -165,10 +165,11 @@ model = AutoModel.from_pretrained("your local path", trust_remote_code=True).hal
```shell ```shell
pip install accelerate pip install accelerate
``` ```
```python
from utils import load_mode_and_tokenizer
model, tokenizer = load_mode_and_tokenizer("your local path", num_gpus=2) ```python
from utils import load_model_and_tokenizer
model, tokenizer = load_model_and_tokenizer("your local path", num_gpus=2)
``` ```
即可将模型部署到多卡上进行推理。 即可将模型部署到多卡上进行推理。
## ChatGLM-6B 示例 ## ChatGLM-6B 示例

View File

@ -154,10 +154,11 @@ model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=Tru
```shell ```shell
pip install accelerate pip install accelerate
``` ```
```python
from utils import load_mode_and_tokenizer
model, tokenizer = load_mode_and_tokenizer("your local path", num_gpus=2) ```python
from utils import load_model_and_tokenizer
model, tokenizer = load_model_and_tokenizer("your local path", num_gpus=2)
``` ```
## ChatGLM-6B Examples ## ChatGLM-6B Examples

4
api.py
View File

@ -4,7 +4,7 @@ import json
import uvicorn import uvicorn
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from utils import load_mode_and_tokenizer from utils import load_model_and_tokenizer
app = FastAPI() app = FastAPI()
@ -34,4 +34,4 @@ async def create_item(request: Request):
if __name__ == '__main__': if __name__ == '__main__':
uvicorn.run('api:app', host='0.0.0.0', port=8000, workers=1) uvicorn.run('api:app', host='0.0.0.0', port=8000, workers=1)
model, tokenizer = load_mode_and_tokenizer("THUDM/chatglm-6b", num_gpus=1) model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1)

View File

@ -1,9 +1,9 @@
import os import os
import platform import platform
from utils import load_mode_and_tokenizer from utils import load_model_and_tokenizer
model, tokenizer = load_mode_and_tokenizer("THUDM/chatglm-6b", num_gpus=1) model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1)
os_name = platform.system() os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear' clear_command = 'cls' if os_name == 'Windows' else 'clear'

View File

@ -54,7 +54,7 @@ def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike],
return model return model
def load_mode_and_tokenizer(checkpoint_path: Union[str, os.PathLike], def load_model_and_tokenizer(checkpoint_path: Union[str, os.PathLike],
multi_gpu_model_cache_dir: Union[str, os.PathLike] = "./temp_model_dir", multi_gpu_model_cache_dir: Union[str, os.PathLike] = "./temp_model_dir",
num_gpus: int = 1, **kwargs) -> Tuple[AutoModel, AutoTokenizer]: num_gpus: int = 1, **kwargs) -> Tuple[AutoModel, AutoTokenizer]:
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs) tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs)

View File

@ -1,7 +1,7 @@
import gradio as gr import gradio as gr
from utils import load_mode_and_tokenizer from utils import load_model_and_tokenizer
model, tokenizer = load_mode_and_tokenizer("THUDM/chatglm-6b", num_gpus=1) model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1)
MAX_TURNS = 20 MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2 MAX_BOXES = MAX_TURNS * 2

View File

@ -1,7 +1,7 @@
from transformers import AutoModel, AutoTokenizer from transformers import AutoModel, AutoTokenizer
import streamlit as st import streamlit as st
from streamlit_chat import message from streamlit_chat import message
from utils import load_mode_and_tokenizer from utils import load_model_and_tokenizer
st.set_page_config( st.set_page_config(
page_title="ChatGLM-6b 演示", page_title="ChatGLM-6b 演示",
@ -11,7 +11,7 @@ st.set_page_config(
@st.cache_resource @st.cache_resource
def get_model(): def get_model():
model, tokenizer = load_mode_and_tokenizer("THUDM/chatglm-6b", num_gpus=1) model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1)
return tokenizer, model return tokenizer, model