mirror of https://github.com/THUDM/ChatGLM-6B
Add a streaming API endpoint; move the demos and the API into a new directory; update the README
parent aeced3619b
commit e2039a8b87
README.md (62)
@@ -90,47 +90,56 @@ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm-6b
## Demo & API

We provide a web demo based on [Gradio](https://gradio.app) and a command-line demo. To use them, first download this repository:

### Web Demo (based on Gradio)

[web demo GIF]

First install Gradio: `pip install gradio mdtex2html`, then run [web_demo_gradio.py](demo_and_api/web_demo_gradio.py) from the repository:

```shell
git clone https://github.com/THUDM/ChatGLM-6B
cd ChatGLM-6B
```

#### Web Demo

[web demo GIF]

First install Gradio: `pip install gradio`, then run [web_demo.py](web_demo.py) from the repository:

```shell
python web_demo.py
cd demo_and_api
python web_demo_gradio.py
```

The program starts a web server and prints its address; open that address in a browser to use the demo. The latest demo implements a typewriter effect, which greatly improves the perceived speed. Note that network access to Gradio from mainland China is slow: with `demo.queue().launch(share=True, inbrowser=True)`, all traffic is relayed through Gradio's servers, which severely degrades the typewriter experience. The default launch mode has therefore been changed to `share=False`; if you need public access, change it back to `share=True`.
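
A minimal sketch of the two launch modes described above (a hypothetical stand-in app, not the actual web_demo_gradio.py; only the `demo.queue().launch(...)` call is taken from the text):

```python
import gradio as gr

# Hypothetical placeholder UI; web_demo_gradio.py builds a full chat interface instead.
with gr.Blocks() as demo:
    gr.Markdown("ChatGLM-6B demo placeholder")

# Default in this commit: local access only, so streamed (typewriter) output stays fast.
demo.queue().launch(share=False, inbrowser=True)
# For public access, traffic is relayed through Gradio's servers and streaming slows down:
# demo.queue().launch(share=True, inbrowser=True)
```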

Thanks to [@AdamBear](https://github.com/AdamBear) for implementing a Streamlit-based web demo; see [#117](https://github.com/THUDM/ChatGLM-6B/pull/117) for how to run it.

### Web Demo (based on Streamlit)

#### CLI Demo

First install Streamlit: `pip install streamlit streamlit-chat`, then run [web_demo_streamlit.py](demo_and_api/web_demo_streamlit.py) from the repository:

```shell
cd demo_and_api
streamlit run web_demo_streamlit.py --server.port 6006
```

*Thanks to [@AdamBear](https://github.com/AdamBear) for contributing this implementation; see [#117](https://github.com/THUDM/ChatGLM-6B/pull/117) for details.*

### CLI Demo

[CLI demo GIF]

Run [cli_demo.py](cli_demo.py) from the repository:

```shell
cd demo_and_api
python cli_demo.py
```

The program holds an interactive conversation on the command line: type an instruction and press Enter to generate a reply, type `clear` to clear the conversation history, and type `stop` to terminate the program.
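
For illustration, a rough sketch of the loop this paragraph describes (not the actual cli_demo.py; the model-loading lines mirror api.py from this commit, and the `clear`/`stop` handling follows the text above):

```python
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model.eval()

history = []
while True:
    query = input("\n用户: ")
    if query.strip() == "stop":    # terminate the program
        break
    if query.strip() == "clear":   # clear the conversation history
        history = []
        continue
    response, history = model.chat(tokenizer, query, history=history)
    print(f"ChatGLM-6B: {response}")
```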

### API Deployment

First install the extra dependencies `pip install fastapi uvicorn`, then run [api.py](api.py) from the repository:
First install the extra dependencies `pip install fastapi uvicorn pydantic`, then run [api.py](demo_and_api/api.py) from the repository:

```shell
cd demo_and_api
python api.py
```

The API exposes a regular endpoint (/chat) and a streaming endpoint (/stream_chat);
the streaming endpoint enables a typewriter effect; see [web_demo_streamlit_with_api.py](demo_and_api/web_demo_streamlit_with_api.py) for how to call it.

By default the API is served on local port 8000 and is called via POST:
```shell
curl -X POST "http://127.0.0.1:8000" \
curl -X POST "http://127.0.0.1:8000/chat" \
     -H 'Content-Type: application/json' \
     -d '{"prompt": "你好", "history": []}'
```

@@ -144,6 +153,23 @@ curl -X POST "http://127.0.0.1:8000" \
}
```
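
Besides curl, a minimal Python sketch of a /stream_chat client (the `\ndata: ` chunk prefix and the twice-encoded JSON payload match api.py and web_demo_streamlit_with_api.py from this commit; the parameter values are just examples):

```python
import json
import requests

req = {"prompt": "你好", "history": [], "max_length": 2048, "top_p": 0.7, "temperature": 0.95}
res = requests.post("http://127.0.0.1:8000/stream_chat", json=req, stream=True)

# Each chunk arrives as "\ndata: " followed by a JSON-encoded JSON string,
# so split on the prefix and decode the payload twice to get a dict.
for line in res.iter_lines(delimiter=b'\ndata: '):
    line = line.decode('utf-8')
    if line.strip() == '':
        continue
    answer = json.loads(json.loads(line))  # keys: status, time, response, history
    print(answer['response'])
```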

### Web Demo (based on Streamlit and the API)

Streamlit serves as the frontend and the API as the backend, using the API's streaming endpoint.

Start the backend (install the API dependencies as described above):

```shell
cd demo_and_api
python api.py
```

Start the frontend (install the Streamlit dependencies as described above):

```shell
cd demo_and_api
streamlit run web_demo_streamlit_with_api.py --server.port 6006
```

## Low-Cost Deployment

### Model Quantization

By default the model is loaded in FP16 precision, and running the code above requires roughly 13 GB of GPU memory. If your GPU memory is limited, you can try loading the model in a quantized form; usage is as follows:
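
The quantized-loading snippet itself lies outside this diff hunk; purely as a hedged illustration, 8-bit loading for this checkpoint is commonly invoked through the `quantize()` helper exposed by the model's remote code (treat the exact call as an assumption, not part of this commit):

```python
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
# Assumed usage: quantize to 8 bits before moving the model to the GPU.
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda()
model.eval()
```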

@@ -211,7 +237,7 @@ model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)

## ChatGLM-6B Examples

The following are some screenshots produced with `web_demo.py`. More possibilities of ChatGLM-6B are waiting for you to explore!
The following are some screenshots produced with `web_demo_gradio.py`. More possibilities of ChatGLM-6B are waiting for you to explore!

<details><summary><b>Self Cognition</b></summary>

@@ -85,7 +85,7 @@ cd ChatGLM-6B

#### Web Demo

[web demo GIF]

Install Gradio `pip install gradio`, and run [web_demo.py](web_demo.py):

api.py (56)

@@ -1,56 +0,0 @@
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime
import torch

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


app = FastAPI()


@app.post("/")
async def create_item(request: Request):
    global model, tokenizer
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history')
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length if max_length else 2048,
                                   top_p=top_p if top_p else 0.7,
                                   temperature=temperature if temperature else 0.95)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": time
    }
    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
    print(log)
    torch_gc()
    return answer


if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)

@@ -0,0 +1,73 @@
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModel
from pydantic import BaseModel
import uvicorn, json, datetime
import torch

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE

app = FastAPI()


# Request body shared by /chat and /stream_chat.
class Params(BaseModel):
    prompt: str = 'hello'
    history: list[list[str]] = []
    max_length: int = 2048
    top_p: float = 0.7
    temperature: float = 0.95


class Answer(BaseModel):
    status: int = 200
    time: str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    response: str
    history: list[list[str]] = []


def torch_gc():
    # Free cached GPU memory between requests.
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


async def create_chat(params: Params):
    global model, tokenizer
    response, history = model.chat(tokenizer,
                                   params.prompt,
                                   history=params.history,
                                   max_length=params.max_length,
                                   top_p=params.top_p,
                                   temperature=params.temperature)
    answer_ok = Answer(response=response, history=history)
    print(answer_ok.json())
    torch_gc()
    return answer_ok


async def create_stream_chat(params: Params):
    global model, tokenizer
    for response, history in model.stream_chat(tokenizer,
                                               params.prompt,
                                               history=params.history,
                                               max_length=params.max_length,
                                               top_p=params.top_p,
                                               temperature=params.temperature):
        answer_ok = Answer(response=response, history=history)
        # print(answer_ok.json())
        # Each chunk is prefixed with "\ndata: "; answer_ok.json() is already a JSON
        # string and json.dumps() encodes it again, so clients must decode twice.
        yield "\ndata: " + json.dumps(answer_ok.json())

    torch_gc()


@app.post("/chat")
async def post_chat(params: Params):
    answer = await create_chat(params)
    return answer


@app.post("/stream_chat")
async def post_stream_chat(params: Params):
    return StreamingResponse(create_stream_chat(params))


if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)

@@ -0,0 +1,7 @@
pydantic
fastapi
uvicorn
gradio
mdtex2html
streamlit
streamlit-chat

@@ -0,0 +1,72 @@
import streamlit as st
from streamlit_chat import message
import requests
import json

st.set_page_config(
    page_title="ChatGLM-6b 演示",
    page_icon=":robot:"
)

MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2
url = "http://localhost:8000/stream_chat"


def predict(input, max_length, top_p, temperature, history=None):
    if history is None:
        history = []

    with container:
        if len(history) > 0:
            for i, (query, response) in enumerate(history):
                message(query, avatar_style="big-smile", key=str(i) + "_user")
                message(response, avatar_style="bottts", key=str(i))

        message(input, avatar_style="big-smile", key=str(len(history)) + "_user")
        st.write("AI正在回复:")
        with st.empty():
            req = {
                "prompt": input,
                "history": history,
                "max_length": max_length,
                "top_p": top_p,
                "temperature": temperature
            }
            res = requests.post(url=url, json=req, stream=True)
            for line in res.iter_lines(delimiter=b'\ndata: '):
                line = line.decode(encoding='utf-8')
                if line.strip() == '':
                    continue
                # The API double-encodes each chunk, so decode the JSON twice.
                response_json = json.loads(json.loads(line))
                response = response_json['response']
                history = response_json['history']
                st.write(response)

    return history


container = st.container()

# create a prompt text for the text generation
prompt_text = st.text_area(label="用户命令输入",
                           height=100,
                           placeholder="请在这儿输入您的命令")

max_length = st.sidebar.slider(
    'max_length', 0, 4096, 2048, step=1
)
top_p = st.sidebar.slider(
    'top_p', 0.0, 1.0, 0.6, step=0.01
)
temperature = st.sidebar.slider(
    'temperature', 0.0, 1.0, 0.95, step=0.01
)

if 'state' not in st.session_state:
    st.session_state['state'] = []

if st.button("发送", key="predict"):
    with st.spinner("AI正在思考,请稍等........"):
        # text generation
        st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])

@@ -2,7 +2,5 @@ protobuf
transformers==4.27.1
cpm_kernels
torch>=1.10
gradio
mdtex2html
sentencepiece
accelerate

Binary image, 1.6 MiB before and after (unchanged)
Binary image, 587 KiB before and after (unchanged)