mirror of https://github.com/THUDM/ChatGLM-6B
commit
43b7241e67
11
README.md
11
README.md
|
@ -167,6 +167,17 @@ model = AutoModel.from_pretrained("your local path", trust_remote_code=True).hal
|
||||||
```
|
```
|
||||||
即可使用在 Mac 上使用 GPU 加速模型推理。
|
即可使用在 Mac 上使用 GPU 加速模型推理。
|
||||||
|
|
||||||
|
### 多卡部署
|
||||||
|
```shell
|
||||||
|
pip install accelerate
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
from utils import load_model_and_tokenizer
|
||||||
|
|
||||||
|
model, tokenizer = load_model_and_tokenizer("your local path", num_gpus=2)
|
||||||
|
```
|
||||||
|
即可将模型部署到多卡上进行推理。
|
||||||
## ChatGLM-6B 示例
|
## ChatGLM-6B 示例
|
||||||
|
|
||||||
以下是一些使用 `web_demo.py` 得到的示例截图。更多 ChatGLM-6B 的可能,等待你来探索发现!
|
以下是一些使用 `web_demo.py` 得到的示例截图。更多 ChatGLM-6B 的可能,等待你来探索发现!
|
||||||
|
|
12
README_en.md
12
README_en.md
|
@ -156,6 +156,18 @@ model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=Tru
|
||||||
|
|
||||||
**For Mac users**: if your encounter the error `RuntimeError: Unknown platform: darwin`, please refer to this [Issue](https://github.com/THUDM/ChatGLM-6B/issues/6#issuecomment-1470060041).
|
**For Mac users**: if your encounter the error `RuntimeError: Unknown platform: darwin`, please refer to this [Issue](https://github.com/THUDM/ChatGLM-6B/issues/6#issuecomment-1470060041).
|
||||||
|
|
||||||
|
### Multi-GPU Deployment
|
||||||
|
|
||||||
|
```shell
|
||||||
|
pip install accelerate
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
from utils import load_model_and_tokenizer
|
||||||
|
|
||||||
|
model, tokenizer = load_model_and_tokenizer("your local path", num_gpus=2)
|
||||||
|
```
|
||||||
|
|
||||||
## ChatGLM-6B Examples
|
## ChatGLM-6B Examples
|
||||||
|
|
||||||
The following are some Chinese examples with `web_demo.py`. Welcome to explore more possibility with ChatGLM-6B.
|
The following are some Chinese examples with `web_demo.py`. Welcome to explore more possibility with ChatGLM-6B.
|
||||||
|
|
12
api.py
12
api.py
|
@ -1,6 +1,10 @@
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from transformers import AutoTokenizer, AutoModel
|
|
||||||
import uvicorn, json, datetime
|
from utils import load_model_and_tokenizer
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
@ -30,6 +34,4 @@ async def create_item(request: Request):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
uvicorn.run('api:app', host='0.0.0.0', port=8000, workers=1)
|
uvicorn.run('api:app', host='0.0.0.0', port=8000, workers=1)
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
|
model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1)
|
||||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
|
|
||||||
model.eval()
|
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
from transformers import AutoTokenizer, AutoModel
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
|
from utils import load_model_and_tokenizer
|
||||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
|
|
||||||
model = model.eval()
|
model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1)
|
||||||
|
|
||||||
os_name = platform.system()
|
os_name = platform.system()
|
||||||
clear_command = 'cls' if os_name == 'Windows' else 'clear'
|
clear_command = 'cls' if os_name == 'Windows' else 'clear'
|
||||||
|
|
|
@ -4,3 +4,4 @@ icetk
|
||||||
cpm_kernels
|
cpm_kernels
|
||||||
torch>=1.10
|
torch>=1.10
|
||||||
gradio
|
gradio
|
||||||
|
accelerate
|
|
@ -0,0 +1,81 @@
|
||||||
|
import os
|
||||||
|
from typing import Dict, Tuple, Union, Optional
|
||||||
|
|
||||||
|
from torch.nn import Module
|
||||||
|
from transformers import AutoModel, AutoTokenizer
|
||||||
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
|
||||||
|
# transformer.word_embeddings 占用1层
|
||||||
|
# transformer.final_layernorm 和 lm_head 占用1层
|
||||||
|
# transformer.layers 占用 28 层
|
||||||
|
# 总共30层分配到num_gpus张卡上
|
||||||
|
num_trans_layers = 28
|
||||||
|
per_gpu_layers = 30 / num_gpus
|
||||||
|
|
||||||
|
# bugfix: 在linux中调用torch.embedding传入的weight,input不在同一device上,导致RuntimeError
|
||||||
|
# windows下 model.device 会被设置成 transformer.word_embeddings.device
|
||||||
|
# linux下 model.device 会被设置成 lm_head.device
|
||||||
|
# 在调用chat或者stream_chat时,input_ids会被放到model.device上
|
||||||
|
# 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError
|
||||||
|
# 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上
|
||||||
|
device_map = {'transformer.word_embeddings': 0,
|
||||||
|
'transformer.final_layernorm': 0, 'lm_head': 0}
|
||||||
|
|
||||||
|
used = 2
|
||||||
|
gpu_target = 0
|
||||||
|
for i in range(num_trans_layers):
|
||||||
|
if used >= per_gpu_layers:
|
||||||
|
gpu_target += 1
|
||||||
|
used = 0
|
||||||
|
assert gpu_target < num_gpus
|
||||||
|
device_map[f'transformer.layers.{i}'] = gpu_target
|
||||||
|
used += 1
|
||||||
|
|
||||||
|
return device_map
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
|
||||||
|
multi_gpu_model_cache_dir: Union[str, os.PathLike] = "./temp_model_dir",
|
||||||
|
device_map: Optional[Dict[str, int]] = None,
|
||||||
|
tokenizer: Optional[PreTrainedTokenizer] = None, **kwargs) -> Module:
|
||||||
|
from accelerate import load_checkpoint_and_dispatch
|
||||||
|
|
||||||
|
model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs)
|
||||||
|
model = model.eval()
|
||||||
|
|
||||||
|
if device_map is None:
|
||||||
|
device_map = auto_configure_device_map(num_gpus)
|
||||||
|
try:
|
||||||
|
model = load_checkpoint_and_dispatch(
|
||||||
|
model, checkpoint_path, device_map=device_map, offload_folder="offload", offload_state_dict=True).half()
|
||||||
|
except ValueError:
|
||||||
|
# index.json not found
|
||||||
|
print(f"index.json not found, auto fixing and saving model to {multi_gpu_model_cache_dir} ...")
|
||||||
|
|
||||||
|
assert multi_gpu_model_cache_dir is not None, "using auto fix, cache_dir must not be None"
|
||||||
|
model.save_pretrained(multi_gpu_model_cache_dir, max_shard_size='2GB')
|
||||||
|
model = load_checkpoint_and_dispatch(
|
||||||
|
model, multi_gpu_model_cache_dir, device_map=device_map,
|
||||||
|
offload_folder="offload", offload_state_dict=True).half()
|
||||||
|
|
||||||
|
if tokenizer is not None:
|
||||||
|
tokenizer.save_pretrained(multi_gpu_model_cache_dir)
|
||||||
|
print(f"loading model successfully, you should use checkpoint_path={multi_gpu_model_cache_dir} next time")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_and_tokenizer(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 1,
|
||||||
|
multi_gpu_model_cache_dir: Union[str, os.PathLike] = "./temp_model_dir",
|
||||||
|
**kwargs) -> Tuple[Module, PreTrainedTokenizer]:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs)
|
||||||
|
if num_gpus < 2:
|
||||||
|
model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
|
||||||
|
model = model.eval()
|
||||||
|
else:
|
||||||
|
model = load_model_on_gpus(checkpoint_path, num_gpus=num_gpus,
|
||||||
|
multi_gpu_model_cache_dir=multi_gpu_model_cache_dir,
|
||||||
|
tokenizer=tokenizer, **kwargs)
|
||||||
|
return model, tokenizer
|
|
@ -1,9 +1,8 @@
|
||||||
from transformers import AutoModel, AutoTokenizer
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
|
from utils import load_model_and_tokenizer
|
||||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
|
|
||||||
model = model.eval()
|
model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1)
|
||||||
|
|
||||||
MAX_TURNS = 20
|
MAX_TURNS = 20
|
||||||
MAX_BOXES = MAX_TURNS * 2
|
MAX_BOXES = MAX_TURNS * 2
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from transformers import AutoModel, AutoTokenizer
|
from transformers import AutoModel, AutoTokenizer
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from streamlit_chat import message
|
from streamlit_chat import message
|
||||||
|
from utils import load_model_and_tokenizer
|
||||||
|
|
||||||
st.set_page_config(
|
st.set_page_config(
|
||||||
page_title="ChatGLM-6b 演示",
|
page_title="ChatGLM-6b 演示",
|
||||||
|
@ -11,9 +11,7 @@ st.set_page_config(
|
||||||
|
|
||||||
@st.cache_resource
|
@st.cache_resource
|
||||||
def get_model():
|
def get_model():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
|
model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1)
|
||||||
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
|
|
||||||
model = model.eval()
|
|
||||||
return tokenizer, model
|
return tokenizer, model
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue