pull/236/merge
Lucien 2023-03-31 11:39:57 +08:00 committed by GitHub
commit c79935cb77
7 changed files with 198 additions and 24 deletions

.gitignore vendored Normal file

@@ -0,0 +1,4 @@
__pycache__
.idea
THUDM
*.tar.gz

README.md

@@ -96,7 +96,7 @@ python cli_demo.py
The program runs an interactive dialogue in the command line: type an instruction and press Enter to generate a reply, type `clear` to clear the dialogue history, and type `stop` to terminate the program.
-### API Deployment
+## API Deployment
First install the additional dependencies with `pip install fastapi uvicorn`, then run [api.py](api.py) from the repository:
```shell
python api.py
@@ -117,6 +117,18 @@ curl -X POST "http://127.0.0.1:8000" \
}
```
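The same endpoint can also be called from Python. The snippet below is only a minimal sketch: it assumes the default port 8000 and the `prompt`/`history` JSON payload used by [api.py](api.py).
```python
import requests

# Assumes api.py is running locally; the prompt/history payload mirrors api.py's request format.
resp = requests.post(
    "http://127.0.0.1:8000",
    json={"prompt": "你好", "history": []},
)
data = resp.json()
print(data["response"])     # generated reply
history = data["history"]   # pass back with the next request to keep context
```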
### WebSocket API with Streaming Responses
Since the API above does not support streaming responses, WebSocket support has been added on top of FastAPI.
First install the additional dependencies with `pip install 'fastapi~=0.95.0' 'websockets~=10.4' 'uvicorn~=0.21.1'`, then run [websocket_api.py](./websocket_api.py):
```shell
python websocket_api.py
```
Open `http://localhost:8000` to see [websocket_demo.html](./websocket_demo.html).
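Besides the bundled page, the endpoint can also be driven programmatically. Below is a minimal client sketch using the `websockets` package installed above; the `/ws` path and the `query`/`history`/`status` JSON protocol follow [websocket_api.py](./websocket_api.py), and the host and port are assumed to be the defaults.
```python
import asyncio
import json

import websockets


async def chat(query, history=None):
    # Connect to the locally running websocket_api.py server.
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        await ws.send(json.dumps({"query": query, "history": history or []}))
        while True:
            body = json.loads(await ws.recv())
            if body["status"] == 200:   # the server signals the end of the stream
                return history
            print(body["response"])     # partial response streamed so far
            history = body["history"]


asyncio.run(chat("你好"))
```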
## Low-Cost Deployment
### Model Quantization
By default, the model is loaded in FP16 precision, and running the code above requires roughly 13 GB of GPU memory. If your GPU memory is limited, you can try loading the model in quantized form, as follows:
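The quantization snippet itself lies outside this diff hunk; the following is only a sketch of what the quantized load looks like, assuming the `quantize()` helper provided by the ChatGLM-6B model code loaded via `trust_remote_code`.
```python
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
# INT8 quantization; quantize(4) trades more precision for an even smaller memory footprint.
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(8).cuda()
model = model.eval()
```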

README_en.md

@@ -96,7 +96,7 @@ python cli_demo.py
The command runs an interactive program in the shell. Type your instruction in the shell and hit enter to generate the response. Type `clear` to clear the dialogue history and `stop` to terminate the program.
## API Deployment
-First install the additional dependency `pip install fastapi uvicorn`. The run [api.py](api.py) in the repo.
+First install the additional dependency `pip install fastapi uvicorn`. Then run [api.py](api.py) in the repo.
```shell
python api.py
```
@@ -116,6 +116,18 @@ The returned value is
}
```
### WebSocket API with Streaming Responses
This demo shows how to embed ChatGLM-6B into your own website over a WebSocket. HTML file: [websocket_demo.html](./websocket_demo.html).
First install the additional dependencies `pip install 'fastapi~=0.95.0' 'websockets~=10.4' 'uvicorn~=0.21.1'`. Then run [websocket_api.py](./websocket_api.py) in the repo.
```shell
python websocket_api.py
```
Then open `http://localhost:8000` (the default address) to see [websocket_demo.html](./websocket_demo.html).
## Deployment
### Quantization

mock_transformers.py Normal file

@@ -0,0 +1,26 @@
from abc import ABC

from transformers import GPT2Model, GPT2Config


class AutoTokenizer:
    # Stand-in for transformers.AutoTokenizer; no real tokenizer is needed.
    @classmethod
    def from_pretrained(cls, *_, **__):
        return None


class AutoModel:
    # Stand-in for transformers.AutoModel that fakes ChatGLM's stream_chat interface.
    @classmethod
    def from_pretrained(cls, *_, **__):
        class MockModel(GPT2Model, ABC):
            @classmethod
            def stream_chat(cls, _, query, history) -> list:
                # Yield three incremental chunks, one second apart, to mimic streaming.
                from time import sleep
                current_response = ''
                for i in range(3):
                    current_response += str(i)
                    yield current_response, history + [[query, current_response]]
                    sleep(1)

            def cuda(self, *args, **kwargs):
                # No-op so the demos can call .cuda() without a GPU.
                return self

        return MockModel(GPT2Config())
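This mock lets the streaming interface be exercised without downloading the real checkpoint. A hypothetical smoke test follows; the PR does not show how the mock is wired into the demos, so the direct import below is only an illustration.
```python
# Illustrative only: drive the mock's stream_chat without a GPU or model download.
from mock_transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b")  # returns None
model = AutoModel.from_pretrained("THUDM/chatglm-6b").cuda()   # .cuda() is a no-op here

history = []
for response, history in model.stream_chat(tokenizer, "你好", history):
    print(response)  # prints "0", then "01", then "012", about one second apart
```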

web_demo.py

@@ -9,11 +9,11 @@ MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2


-def predict(input, max_length, top_p, temperature, history=None):
+def predict(query, max_length, top_p, temperature, history=None):
    if history is None:
        history = []
-    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
-                                               temperature=temperature):
+    for _, history in model.stream_chat(tokenizer, query, history, max_length=max_length, top_p=top_p,
+                                        temperature=temperature):
        updates = []
        for query, response in history:
            updates.append(gr.update(visible=True, value="用户:" + query))
@@ -23,23 +23,28 @@ def predict(input, max_length, top_p, temperature, history=None):
        yield [history] + updates


-with gr.Blocks() as demo:
-    state = gr.State([])
-    text_boxes = []
-    for i in range(MAX_BOXES):
-        if i % 2 == 0:
-            text_boxes.append(gr.Markdown(visible=False, label="提问:"))
-        else:
-            text_boxes.append(gr.Markdown(visible=False, label="回复:"))
+def main():
+    with gr.Blocks() as demo:
+        state = gr.State([])
+        text_boxes = []
+        for i in range(MAX_BOXES):
+            if i % 2 == 0:
+                text_boxes.append(gr.Markdown(visible=False, label="提问:"))
+            else:
+                text_boxes.append(gr.Markdown(visible=False, label="回复:"))

-    with gr.Row():
-        with gr.Column(scale=4):
-            txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
-                container=False)
-        with gr.Column(scale=1):
-            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
-            button = gr.Button("Generate")
-    button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
-demo.queue().launch(share=False, inbrowser=True)
+        with gr.Row():
+            with gr.Column(scale=4):
+                txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
+                    container=False)
+            with gr.Column(scale=1):
+                max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+                top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+                temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+                button = gr.Button("Generate")
+        button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
+    demo.queue().launch(share=False, inbrowser=True)
+
+
+if __name__ == '__main__':
+    main()

websocket_api.py Normal file

@@ -0,0 +1,56 @@
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModel
import uvicorn

pretrained = "THUDM/chatglm-6b"
tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
model = AutoModel.from_pretrained(pretrained, trust_remote_code=True).half().cuda()
model = model.eval()

app = FastAPI()
# CORS middleware added with its default settings.
app.add_middleware(
    CORSMiddleware
)

with open('websocket_demo.html') as f:
    html = f.read()


@app.get("/")
async def get():
    # Serve the demo page at the root URL.
    return HTMLResponse(html)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    input: JSON string of {"query": "", "history": []}
    output: JSON string of {"response": "", "history": [], "status": 200}
    A status of 200 means the response has finished; any other status means it is still streaming.
    """
    await websocket.accept()
    try:
        while True:
            json_request = await websocket.receive_json()
            query = json_request['query']
            history = json_request['history']
            # Stream partial responses with status 202, then signal completion with status 200.
            for response, history in model.stream_chat(tokenizer, query, history=history):
                await websocket.send_json({
                    "response": response,
                    "history": history,
                    "status": 202,
                })
            await websocket.send_json({"status": 200})
    except WebSocketDisconnect:
        pass


def main():
    uvicorn.run(f"{__name__}:app", host='0.0.0.0', port=8000, workers=1)


if __name__ == '__main__':
    main()

websocket_demo.html Normal file

@@ -0,0 +1,59 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <title>Chat</title>
</head>
<body>
<h1>WebSocket Chat</h1>
<form action="" onsubmit="return false;" id="form">
    <label for="messageText"></label>
    <input type="text" id="messageText" autocomplete="off"/>
    <button type="submit">Send</button>
</form>
<ul id='messageBox'>
</ul>
<script>
    let ws = new WebSocket("ws://" + location.host + "/ws");
    let history = [];
    let last_message_element = null;

    function appendMessage(text, sender, dom = null) {
        // Create a new list item unless an existing (still streaming) message is being updated.
        if (dom === null) {
            let messageBox = document.getElementById('messageBox');
            dom = document.createElement('li');
            messageBox.appendChild(dom);
        }
        dom.innerText = sender + '：' + text;
        return dom
    }

    function sendMessage(event) {
        if (last_message_element !== null) { // the bot has not finished replying yet
            return;
        }
        let input = document.getElementById("messageText");
        if (input.value === "") {
            return;
        }
        let body = {"query": input.value, 'history': history};
        ws.send(JSON.stringify(body));
        appendMessage(input.value, '用户')
        input.value = '';
        event.preventDefault();
    }

    document.getElementById("form").addEventListener('submit', sendMessage)
    ws.onmessage = function (event) {
        let body = JSON.parse(event.data);
        let status = body['status']
        if (status === 200) { // the answer has finished
            last_message_element = null;
        } else {
            history = body['history']
            last_message_element = appendMessage(body['response'], 'ChatGLM-6B', last_message_element)
        }
    };
</script>
</body>
</html>