Add a streaming API endpoint; move the demos and the API into a new directory; update the README

pull/808/head
liseri 2023-04-25 04:56:30 +00:00
parent aeced3619b
commit e2039a8b87
13 changed files with 197 additions and 77 deletions


@@ -90,47 +90,56 @@ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm-6b
## Demo & API
We provide a web demo based on [Gradio](https://gradio.app) and a command-line demo. To use them, first clone this repository:
```shell
git clone https://github.com/THUDM/ChatGLM-6B
cd ChatGLM-6B
```
### Web Demo (based on Gradio)
![web-demo](resources/web-demo-gradio.gif)
First install Gradio: `pip install gradio mdtex2html`, then run [web_demo_gradio.py](demo_and_api/web_demo_gradio.py) from the repository:
```shell
cd demo_and_api
python web_demo_gradio.py
```
The program runs a web server and prints its address; open the printed address in a browser to use the demo. The latest version of the demo implements a typewriter effect and feels much more responsive. Note that because network access to Gradio from within mainland China is relatively slow, launching with `demo.queue().launch(share=True, inbrowser=True)` relays all traffic through the Gradio servers, which greatly degrades the typewriter experience. The default launch mode has therefore been changed to `share=False`; if you need public network access, change it back to `share=True`.
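For reference, a minimal sketch of the launch call this note refers to, assuming `demo` is the Gradio `Blocks` object built in web_demo_gradio.py (check the script itself for the exact line):
```python
# Default: no public Gradio share link; the local URL is opened in the browser automatically.
demo.queue().launch(share=False, inbrowser=True)

# With a temporary public URL, traffic is relayed through Gradio's servers,
# which noticeably slows down the streaming (typewriter) output:
# demo.queue().launch(share=True, inbrowser=True)
```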
### Web Demo (based on Streamlit)
First install Streamlit: `pip install streamlit streamlit-chat`, then run [web_demo_streamlit.py](demo_and_api/web_demo_streamlit.py) from the repository:
```shell
cd demo_and_api
streamlit run web_demo_streamlit.py --server.port 6006
```
*Thanks to [@AdamBear](https://github.com/AdamBear) for contributing this implementation; see [#117](https://github.com/THUDM/ChatGLM-6B/pull/117) for details.*
### Command-Line Demo
![cli-demo](resources/cli-demo.png)
Run [cli_demo.py](demo_and_api/cli_demo.py) from the repository:
```shell
cd demo_and_api
python cli_demo.py
```
The program runs an interactive conversation in the command line: type an instruction and press Enter to generate a reply, type `clear` to clear the conversation history, and type `stop` to terminate the program.
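For orientation, a minimal sketch of how such a loop handles the `clear` and `stop` commands; this is illustrative only, not the actual cli_demo.py source:
```python
from transformers import AutoTokenizer, AutoModel

# Illustrative CLI chat loop with `clear` / `stop` handling
# (see demo_and_api/cli_demo.py for the real implementation).
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model.eval()

history = []
while True:
    query = input("\n用户: ")
    if query.strip() == "stop":      # terminate the program
        break
    if query.strip() == "clear":     # reset the conversation history
        history = []
        continue
    response, history = model.chat(tokenizer, query, history=history)
    print(f"ChatGLM-6B: {response}")
```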
### API Deployment
First install the extra dependencies with `pip install fastapi uvicorn pydantic`, then run [api.py](demo_and_api/api.py) from the repository:
```shell
cd demo_and_api
python api.py
```
The API exposes a regular endpoint (`/chat`) and a streaming endpoint (`/stream_chat`).
The streaming endpoint enables a typewriter effect; see [web_demo_streamlit_with_api.py](demo_and_api/web_demo_streamlit_with_api.py) for an example of how to call it.
By default the API is served on local port 8000 and is called via the POST method:
```shell
curl -X POST "http://127.0.0.1:8000/chat" \
     -H 'Content-Type: application/json' \
     -d '{"prompt": "你好", "history": []}'
```
@@ -144,6 +153,23 @@ curl -X POST "http://127.0.0.1:8000" \
}
```
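The streaming endpoint is called in the same way but returns incremental chunks. A minimal Python client sketch, assuming the server from [api.py](demo_and_api/api.py) is running on its default host and port and emits `\ndata: `-separated, double-encoded JSON chunks as that script does:
```python
import json
import requests

# Stream partial responses from the /stream_chat endpoint of demo_and_api/api.py.
req = {"prompt": "你好", "history": []}
res = requests.post("http://127.0.0.1:8000/stream_chat", json=req, stream=True)
for line in res.iter_lines(delimiter=b"\ndata: "):
    line = line.decode("utf-8")
    if not line.strip():
        continue
    # The server applies json.dumps() to the pydantic .json() string, so decode twice.
    answer = json.loads(json.loads(line))
    print(answer["response"])  # the response text grows with each chunk (typewriter effect)
```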
### Web Demo (Streamlit front end + API)
Streamlit acts as the front end and the API as the back end, using the API's streaming endpoint.
Start the back end (install the API dependencies as described above):
```shell
cd demo_and_api
python api.py
```
Start the front end (install the Streamlit dependencies as described above):
```shell
cd demo_and_api
streamlit run web_demo_streamlit_with_api.py --server.port 6006
```
## Low-Cost Deployment
### Model Quantization
By default the model is loaded in FP16 precision, and running the code above requires about 13 GB of GPU memory. If your GPU memory is limited, you can try loading the model in quantized form, as follows:
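For reference, quantized loading in ChatGLM-6B typically looks like the sketch below; a minimal sketch, assuming on-the-fly quantization from the FP16 checkpoint via the model's `quantize` method:
```python
from transformers import AutoModel

# Quantize the FP16 weights to INT4 on the fly to cut GPU memory usage;
# use .quantize(8) instead for INT8 quantization.
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(4).half().cuda()
```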
@@ -211,7 +237,7 @@ model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
## ChatGLM-6B Examples
Below are some example screenshots obtained with `web_demo_gradio.py`. More possibilities of ChatGLM-6B are waiting for you to explore!
<details><summary><b>Self-Cognition</b></summary>


@@ -85,7 +85,7 @@ cd ChatGLM-6B
#### Web Demo
![web-demo](resources/web-demo-gradio.png)
Install Gradio with `pip install gradio`, and run [web_demo.py](web_demo.py):

api.py (deleted)

@@ -1,56 +0,0 @@
```python
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime
import torch

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


app = FastAPI()


@app.post("/")
async def create_item(request: Request):
    global model, tokenizer
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history')
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length if max_length else 2048,
                                   top_p=top_p if top_p else 0.7,
                                   temperature=temperature if temperature else 0.95)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": time
    }
    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
    print(log)
    torch_gc()
    return answer


if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
```

demo_and_api/api.py (new file)

@@ -0,0 +1,73 @@
```python
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModel
from pydantic import BaseModel, Field
import uvicorn, json, datetime
import torch

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE

app = FastAPI()


class Params(BaseModel):
    prompt: str = 'hello'
    history: list[list[str]] = []
    max_length: int = 2048
    top_p: float = 0.7
    temperature: float = 0.95


class Answer(BaseModel):
    status: int = 200
    # default_factory so the timestamp reflects each response, not the moment the class was defined
    time: str = Field(default_factory=lambda: datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    response: str
    history: list[list[str]] = []


def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


async def create_chat(params: Params):
    global model, tokenizer
    response, history = model.chat(tokenizer,
                                   params.prompt,
                                   history=params.history,
                                   max_length=params.max_length,
                                   top_p=params.top_p,
                                   temperature=params.temperature)
    answer_ok = Answer(response=response, history=history)
    print(answer_ok.json())
    torch_gc()
    return answer_ok


async def create_stream_chat(params: Params):
    global model, tokenizer
    for response, history in model.stream_chat(tokenizer,
                                               params.prompt,
                                               history=params.history,
                                               max_length=params.max_length,
                                               top_p=params.top_p,
                                               temperature=params.temperature):
        answer_ok = Answer(response=response, history=history)
        # print(answer_ok.json())
        # Each chunk is prefixed with "\ndata: " and json.dumps() is applied to the
        # already-serialized .json() string, so clients must json.loads() twice.
        yield "\ndata: " + json.dumps(answer_ok.json())
    torch_gc()


@app.post("/chat")
async def post_chat(params: Params):
    answer = await create_chat(params)
    return answer


@app.post("/stream_chat")
async def post_stream_chat(params: Params):
    return StreamingResponse(create_stream_chat(params))


if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
```
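A quick way to exercise the non-streaming endpoint from Python (a sketch, assuming the server above is running on its default host and port):
```python
import requests

# Call the /chat endpoint; the request body maps directly onto the Params model above.
resp = requests.post(
    "http://127.0.0.1:8000/chat",
    json={"prompt": "你好", "history": [], "max_length": 2048, "top_p": 0.7, "temperature": 0.95},
)
answer = resp.json()  # fields of the Answer model: status, time, response, history
print(answer["response"])
```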


@@ -0,0 +1,7 @@
pydantic
fastapi
uvicorn
gradio
mdtex2html
streamlit
streamlit-chat


@@ -0,0 +1,72 @@
```python
import streamlit as st
from streamlit_chat import message
import requests
import json

st.set_page_config(
    page_title="ChatGLM-6b 演示",
    page_icon=":robot:"
)

MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2

url = "http://localhost:8000/stream_chat"


def predict(input, max_length, top_p, temperature, history=None):
    if history is None:
        history = []
    with container:
        if len(history) > 0:
            for i, (query, response) in enumerate(history):
                message(query, avatar_style="big-smile", key=str(i) + "_user")
                message(response, avatar_style="bottts", key=str(i))
        message(input, avatar_style="big-smile", key=str(len(history)) + "_user")
        st.write("AI正在回复:")
        with st.empty():
            req = {
                "prompt": input,
                "history": history,
                "max_length": max_length,
                "top_p": top_p,
                "temperature": temperature
            }
            res = requests.post(url=url, json=req, stream=True)
            # The API streams "\ndata: "-separated chunks, each a double-encoded JSON string,
            # hence the nested json.loads() calls below.
            for line in res.iter_lines(delimiter=b'\ndata: '):
                line = line.decode(encoding='utf-8')
                if line.strip() == '':
                    continue
                response_json = json.loads(json.loads(line))
                response = response_json['response']
                history = response_json['history']
                st.write(response)
    return history


container = st.container()

# create a prompt text for the text generation
prompt_text = st.text_area(label="用户命令输入",
                           height=100,
                           placeholder="请在这儿输入您的命令")

max_length = st.sidebar.slider(
    'max_length', 0, 4096, 2048, step=1
)
top_p = st.sidebar.slider(
    'top_p', 0.0, 1.0, 0.6, step=0.01
)
temperature = st.sidebar.slider(
    'temperature', 0.0, 1.0, 0.95, step=0.01
)

if 'state' not in st.session_state:
    st.session_state['state'] = []

if st.button("发送", key="predict"):
    with st.spinner("AI正在思考请稍等........"):
        # text generation
        st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])
```


@@ -2,7 +2,5 @@ protobuf
transformers==4.27.1
cpm_kernels
torch>=1.10
gradio
mdtex2html
sentencepiece
accelerate

Image file (before: 1.6 MiB, after: 1.6 MiB)

Image file (before: 587 KiB, after: 587 KiB)