mirror of https://github.com/THUDM/ChatGLM-6B
Add a streaming API endpoint; move the demo and API code into a new directory; update the README.
parent aeced3619b
commit e2039a8b87

README.md (62 lines changed)

@@ -90,47 +90,56 @@ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm-6b

## Demo & API

### Web Demo (Gradio)

![web-demo](resources/web-demo.gif)

First install Gradio: `pip install gradio mdtex2html`, then run [web_demo_gradio.py](demo_and_api/web_demo_gradio.py) from the repository:

```shell
cd demo_and_api
python web_demo_gradio.py
```

The program starts a web server and prints its address; open that address in a browser to use the demo. The latest demo adds a typewriter (streaming) effect, which greatly improves the perceived speed. Note that access to Gradio's network is slow from inside China: when launching with `demo.queue().launch(share=True, inbrowser=True)`, all traffic is relayed through Gradio's servers, which badly degrades the typewriter experience. The default launch mode has therefore been changed to `share=False`; if you need public access, change it back to `share=True`.
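
As a point of reference, here is a minimal, self-contained Gradio sketch (not the actual web_demo_gradio.py; the echo function stands in for the model) showing where the `share` flag goes in the launch call described above:

```python
import gradio as gr

# Illustrative stand-in for the real demo: a trivial echo app launched the same way.
# Set share=True only if you need a public gradio.live link; all traffic is then
# relayed through Gradio's servers, which slows down the streaming output.
with gr.Blocks() as demo:
    prompt = gr.Textbox(label="prompt")
    reply = gr.Textbox(label="response")
    prompt.submit(lambda text: text, inputs=prompt, outputs=reply)

demo.queue().launch(share=False, inbrowser=True)
```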

### Web Demo (Streamlit)

First install Streamlit: `pip install streamlit streamlit-chat`, then run [web_demo_streamlit.py](demo_and_api/web_demo_streamlit.py) from the repository:

```shell
cd demo_and_api
streamlit run web_demo_streamlit.py --server.port 6006
```

*Thanks to [@AdamBear](https://github.com/AdamBear) for contributing this implementation; see [#117](https://github.com/THUDM/ChatGLM-6B/pull/117).*

### Command-line Demo

|

|
||||||
|
|
||||||
运行仓库中 [cli_demo.py](cli_demo.py):
|
运行仓库中 [cli_demo.py](cli_demo.py):
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
cd demo_and_api
|
||||||
python cli_demo.py
|
python cli_demo.py
|
||||||
```
|
```
|
||||||
|
|
||||||
程序会在命令行中进行交互式的对话,在命令行中输入指示并回车即可生成回复,输入 `clear` 可以清空对话历史,输入 `stop` 终止程序。
|
程序会在命令行中进行交互式的对话,在命令行中输入指示并回车即可生成回复,输入 `clear` 可以清空对话历史,输入 `stop` 终止程序。

### API Deployment

First install the extra dependencies: `pip install fastapi uvicorn pydantic`, then run [api.py](demo_and_api/api.py) from the repository:

```shell
cd demo_and_api
python api.py
```

The API provides a regular endpoint (`/chat`) and a streaming endpoint (`/stream_chat`). The streaming endpoint enables a typewriter effect; for an example of calling it, see [web_demo_streamlit_with_api.py](demo_and_api/web_demo_streamlit_with_api.py).

By default the API is served locally on port 8000 and is called via POST:

```shell
curl -X POST "http://127.0.0.1:8000/chat" \
     -H 'Content-Type: application/json' \
     -d '{"prompt": "你好", "history": []}'
```

@@ -144,6 +153,23 @@ curl -X POST "http://127.0.0.1:8000" \
}
```
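
For the streaming endpoint, a minimal client sketch using `requests` might look like the following (it mirrors the parsing in web_demo_streamlit_with_api.py; the server double-encodes each chunk, hence the nested `json.loads`):

```python
import json
import requests

# Sketch of a /stream_chat client; field names follow the Params model in api.py.
req = {"prompt": "你好", "history": [], "max_length": 2048, "top_p": 0.7, "temperature": 0.95}
res = requests.post("http://127.0.0.1:8000/stream_chat", json=req, stream=True)

for chunk in res.iter_lines(delimiter=b"\ndata: "):
    chunk = chunk.decode("utf-8").strip()
    if not chunk:
        continue
    # Each chunk is json.dumps() of answer.json(), i.e. a JSON-encoded JSON string,
    # so it has to be decoded twice.
    answer = json.loads(json.loads(chunk))
    print(answer["response"])
```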
|
|
||||||
|
### 网页版(基于streamlit和API) Demo
|
||||||
|
streamlit作为前端,api作为后端,使用了api的流式接口
|
||||||
|
|
||||||
|
启动后端
|
||||||
|
api依赖安装参照前文
|
||||||
|
```shell
|
||||||
|
cd demo_and_api
|
||||||
|
python api.py
|
||||||
|
```
|
||||||
|
|
||||||
|
启动前端
|
||||||
|
streamlit依赖安装参照前文
|
||||||
|
```shell
|
||||||
|
cd demo_and_api
|
||||||
|
streamlit run web_demo_streamlit_with_api.py --server.port 6006
|
||||||
|
```
|
||||||
|
|
||||||

## Low-Cost Deployment

### Model Quantization

By default the model is loaded in FP16 precision, and running the code above requires roughly 13 GB of GPU memory. If your GPU memory is limited, you can try loading the model in quantized form; usage is as follows:
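
A rough, illustrative sketch of what quantized loading typically looks like for chatglm-6b (not necessarily the README's exact snippet; `quantize(8)` is assumed to select INT8 and `quantize(4)` INT4):

```python
from transformers import AutoModel, AutoTokenizer

# Sketch of quantized loading: quantize(8) for INT8, quantize(4) for INT4 (assumed API of the
# remote chatglm-6b code); memory use drops accordingly compared with plain FP16 loading.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda()
model = model.eval()
```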

@@ -211,7 +237,7 @@ model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)

## ChatGLM-6B Examples

The following screenshots were produced with `web_demo_gradio.py`. More of what ChatGLM-6B can do is waiting for you to explore!

<details><summary><b>Self Cognition</b></summary>

@@ -85,7 +85,7 @@ cd ChatGLM-6B

#### Web Demo

![web-demo](resources/web-demo.gif)

Install Gradio with `pip install gradio`, and run [web_demo.py](web_demo.py):

api.py (56 lines removed)

@@ -1,56 +0,0 @@
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime
import torch

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


app = FastAPI()


@app.post("/")
async def create_item(request: Request):
    global model, tokenizer
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history')
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length if max_length else 2048,
                                   top_p=top_p if top_p else 0.7,
                                   temperature=temperature if temperature else 0.95)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": time
    }
    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
    print(log)
    torch_gc()
    return answer


if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)

demo_and_api/api.py (73 lines added)

@@ -0,0 +1,73 @@
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModel
from pydantic import BaseModel, Field
import uvicorn, json, datetime
import torch

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE

app = FastAPI()


class Params(BaseModel):
    # Request body shared by /chat and /stream_chat.
    prompt: str = 'hello'
    history: list[list[str]] = []
    max_length: int = 2048
    top_p: float = 0.7
    temperature: float = 0.95


class Answer(BaseModel):
    status: int = 200
    # default_factory so the timestamp is generated per response rather than once at import time
    time: str = Field(default_factory=lambda: datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    response: str
    history: list[list[str]] = []


def torch_gc():
    # Free cached GPU memory after each request.
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


async def create_chat(params: Params):
    # Non-streaming generation: returns the complete response in a single Answer.
    global model, tokenizer
    response, history = model.chat(tokenizer,
                                   params.prompt,
                                   history=params.history,
                                   max_length=params.max_length,
                                   top_p=params.top_p,
                                   temperature=params.temperature)
    answer_ok = Answer(response=response, history=history)
    print(answer_ok.json())
    torch_gc()
    return answer_ok


async def create_stream_chat(params: Params):
    # Streaming generation: yields one chunk per partial response, each prefixed with
    # "\ndata: ". The chunk body is json.dumps() of answer_ok.json(), i.e. a
    # JSON-encoded JSON string, so clients must call json.loads() twice
    # (see web_demo_streamlit_with_api.py).
    global model, tokenizer
    for response, history in model.stream_chat(tokenizer,
                                               params.prompt,
                                               history=params.history,
                                               max_length=params.max_length,
                                               top_p=params.top_p,
                                               temperature=params.temperature):
        answer_ok = Answer(response=response, history=history)
        # print(answer_ok.json())
        yield "\ndata: " + json.dumps(answer_ok.json())

    torch_gc()


@app.post("/chat")
async def post_chat(params: Params):
    answer = await create_chat(params)
    return answer


@app.post("/stream_chat")
async def post_stream_chat(params: Params):
    return StreamingResponse(create_stream_chat(params))


if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)

demo_and_api/requirements.txt (7 lines added)

@@ -0,0 +1,7 @@
pydantic
fastapi
uvicorn
gradio
mdtex2html
streamlit
streamlit-chat

demo_and_api/web_demo_streamlit_with_api.py (72 lines added)

@@ -0,0 +1,72 @@
import streamlit as st
from streamlit_chat import message
import requests
import json

st.set_page_config(
    page_title="ChatGLM-6b 演示",
    page_icon=":robot:"
)

MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2
url = "http://localhost:8000/stream_chat"


def predict(input, max_length, top_p, temperature, history=None):
    if history is None:
        history = []

    with container:
        # Replay the previous turns of the conversation.
        if len(history) > 0:
            for i, (query, response) in enumerate(history):
                message(query, avatar_style="big-smile", key=str(i) + "_user")
                message(response, avatar_style="bottts", key=str(i))

        message(input, avatar_style="big-smile", key=str(len(history)) + "_user")
        st.write("AI正在回复:")
        with st.empty():
            req = {
                "prompt": input,
                "history": history,
                "max_length": max_length,
                "top_p": top_p,
                "temperature": temperature
            }
            # Call the streaming endpoint of api.py and render partial responses as they arrive.
            res = requests.post(url=url, json=req, stream=True)
            for line in res.iter_lines(delimiter=b'\ndata: '):
                line = line.decode(encoding='utf-8')
                if line.strip() == '':
                    continue
                # api.py yields json.dumps(answer.json()), a JSON-encoded JSON string,
                # so the payload has to be decoded twice.
                response_json = json.loads(json.loads(line))
                response = response_json['response']
                history = response_json['history']
                st.write(response)

    return history


container = st.container()

# create a prompt text for the text generation
prompt_text = st.text_area(label="用户命令输入",
                           height=100,
                           placeholder="请在这儿输入您的命令")

max_length = st.sidebar.slider(
    'max_length', 0, 4096, 2048, step=1
)
top_p = st.sidebar.slider(
    'top_p', 0.0, 1.0, 0.6, step=0.01
)
temperature = st.sidebar.slider(
    'temperature', 0.0, 1.0, 0.95, step=0.01
)

if 'state' not in st.session_state:
    st.session_state['state'] = []

if st.button("发送", key="predict"):
    with st.spinner("AI正在思考,请稍等........"):
        # text generation
        st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])

requirements.txt (gradio and mdtex2html removed)

@@ -2,7 +2,5 @@ protobuf
transformers==4.27.1
cpm_kernels
torch>=1.10
sentencepiece
accelerate