mirror of https://github.com/THUDM/ChatGLM-6B

Add a WebSocket API to support streaming responses, add a .gitignore, and add a mock model for convenience.

parent 963d5645ef
commit 03b38d2248
.gitignore (new file)

@@ -0,0 +1,2 @@
+__pycache__
+.idea

README.md (14 changed lines)
@@ -91,7 +91,7 @@ python cli_demo.py
 The program runs an interactive dialogue in the command line: type an instruction and press Enter to generate a reply, type `clear` to clear the dialogue history, and type `stop` to terminate the program.

-### API Deployment
+## API Deployment
 First install the additional dependencies with `pip install fastapi uvicorn`, then run [api.py](api.py) from the repository:
 ```shell
 python api.py

@@ -112,6 +112,18 @@ curl -X POST "http://127.0.0.1:8000" \
 }
 ```

+### WebSocket API with Streaming Responses
+
+Since the API deployment above does not support streaming responses, WebSocket support has been added on top of FastAPI.
+
+First install the additional dependencies `pip install 'fastapi~=0.95.0' 'websockets~=10.4'`, then run [websocket_api.py](./websocket_api.py).
+
+```shell
+python websocket_api.py
+```
+
+Visit `http://localhost:8000` to open [websocket_demo.html](./websocket_demo.html).
+
 ## Low-Cost Deployment
 ### Model Quantization
 By default, the model is loaded in FP16 precision, and running the code above requires about 13 GB of GPU memory. If your GPU memory is limited, you can try loading the model with quantization, as follows:

README_en.md (15 changed lines)
@@ -89,6 +89,9 @@ python cli_demo.py
 The command runs an interactive program in the shell. Type your instruction in the shell and hit enter to generate the response. Type `clear` to clear the dialogue history and `stop` to terminate the program.

 ## API Deployment
+
+### Sync API
+
 First install the additional dependency `pip install fastapi uvicorn`. Then run [api.py](api.py) in the repo.
 ```shell
 python api.py

@@ -109,6 +112,18 @@ The returned value is
 }
 ```

+### WebSocket API with Streaming Responses
+
+This demo shows how you can embed ChatGLM-6B into your own website through a WebSocket. HTML file: [websocket_demo.html](./websocket_demo.html).
+
+First install the additional dependencies `pip install 'fastapi~=0.95.0' 'websockets~=10.4'`. Then run [websocket_api.py](./websocket_api.py) in the repo.
+
+```shell
+python websocket_api.py
+```
+
+Then open `http://localhost:8000` (the default address) to see [websocket_demo.html](./websocket_demo.html).
+
 ## Deployment
 ### Quantization

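For scripting against the synchronous API described above, a minimal Python sketch follows. It is not part of this commit, and the `prompt`/`history` field names are assumptions about api.py's request schema (the diff only shows the `curl` endpoint); check api.py for the actual format.

```python
# Minimal sketch of calling the synchronous HTTP API started by `python api.py`.
# NOTE: the "prompt"/"history" field names are assumptions, not shown in this diff.
import requests

resp = requests.post(
    "http://127.0.0.1:8000",
    json={"prompt": "Hello", "history": []},
)
print(resp.json())  # the server replies with a JSON object
```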
mock_transformers.py (new file)

@@ -0,0 +1,26 @@
from abc import ABC

from transformers import GPT2Model, GPT2Config
class AutoTokenizer:
    @classmethod
    def from_pretrained(cls, *_, **__):
        return None


class AutoModel:
    @classmethod
    def from_pretrained(cls, *_, **__):
        class MockModel(GPT2Model, ABC):
            @classmethod
            def stream_chat(cls, _, query, history) -> list:
                from time import sleep
                current_response = ''
                for i in range(3):
                    current_response += str(i)
                    yield current_response, history + [[query, current_response]]
                    sleep(1)

            def cuda(self, *args, **kwargs):
                return self

        return MockModel(GPT2Config())
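The file above stands in for `transformers.AutoTokenizer`/`AutoModel` so the demos can run without downloading ChatGLM-6B or using a GPU. A minimal sketch of how the mock behaves (illustrative only, not part of the commit; the mock ignores the model name and every other argument):

```python
# Illustrative use of the mock: the tokenizer is None, cuda() is a no-op,
# and stream_chat yields a growing dummy response plus the updated history.
from mock_transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b")        # returns None
model = AutoModel.from_pretrained("THUDM/chatglm-6b").half().cuda()  # stays on CPU

for response, history in model.stream_chat(tokenizer, "hi", []):
    print(response, history)  # "0", then "01", then "012", about one line per second
```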
web_demo.py

@@ -9,10 +9,10 @@ MAX_TURNS = 20
 MAX_BOXES = MAX_TURNS * 2


-def predict(input, max_length, top_p, temperature, history=None):
+def predict(query, max_length, top_p, temperature, history=None):
     if history is None:
         history = []
-    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+    for _, history in model.stream_chat(tokenizer, query, history, max_length=max_length, top_p=top_p,
                                                temperature=temperature):
         updates = []
         for query, response in history:

@@ -23,6 +23,7 @@ def predict(input, max_length, top_p, temperature, history=None):
         yield [history] + updates


+def main():
 with gr.Blocks() as demo:
     state = gr.State([])
     text_boxes = []

@@ -43,3 +44,7 @@ with gr.Blocks() as demo:
     button = gr.Button("Generate")
     button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
 demo.queue().launch(share=False, inbrowser=True)
+
+
+if __name__ == '__main__':
+    main()
websocket_api.py (new file)

@@ -0,0 +1,51 @@
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
# Mock stand-ins for local testing; use `from transformers import AutoTokenizer, AutoModel`
# to load the real ChatGLM-6B instead.
from mock_transformers import AutoTokenizer, AutoModel

import uvicorn

pretrained = "THUDM/chatglm-6b"
tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
model = AutoModel.from_pretrained(pretrained, trust_remote_code=True).half().cuda()
model = model.eval()
app = FastAPI()

with open('websocket_demo.html') as f:
    html = f.read()


@app.get("/")
async def get():
    return HTMLResponse(html)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    input: JSON String of {"query": "", "history": []}
    output: JSON String of {"response": "", "history": [], "status": 200}
    status 200 means the response has ended; any other status means more chunks follow
    """
    await websocket.accept()
    try:
        while True:
            json_request = await websocket.receive_json()
            query = json_request['query']
            history = json_request['history']
            for response, history in model.stream_chat(tokenizer, query, history=history):
                await websocket.send_json({
                    "response": response,
                    "history": history,
                    "status": 202,
                })
            await websocket.send_json({"status": 200})
    except WebSocketDisconnect:
        pass


def main():
    uvicorn.run(f"{__name__}:app", host='0.0.0.0', port=8000, workers=1)


if __name__ == '__main__':
    main()
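For reference, a minimal Python client sketch for the `/ws` endpoint above. It is not part of the commit; it assumes the `websockets` package mentioned in the README, and the message shapes follow the endpoint's docstring (status 202 for partial responses, status 200 to end the stream).

```python
# Minimal client sketch for websocket_api.py running locally (not part of the commit).
import asyncio
import json

import websockets


async def chat(query, history=None):
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        await ws.send(json.dumps({"query": query, "history": history or []}))
        while True:
            msg = json.loads(await ws.recv())
            if msg["status"] == 200:      # 200 marks the end of the stream
                return history
            history = msg["history"]      # 202: partial response, keep reading
            print(msg["response"])


asyncio.run(chat("Hello"))
```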
websocket_demo.html (new file)

@@ -0,0 +1,59 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <title>Chat</title>
</head>
<body>
<h1>WebSocket Chat</h1>
<form action="" onsubmit="return false;" id="form">
    <label for="messageText"></label>
    <input type="text" id="messageText" autocomplete="off"/>
    <button type="submit">Send</button>
</form>
<ul id='messageBox'>
</ul>
<script>
    let ws = new WebSocket("ws://localhost:8000/ws");
    let history = [];
    let last_message_element = null;

    function appendMessage(text, sender, dom = null) {
        if (dom === null) {
            let messageBox = document.getElementById('messageBox');
            dom = document.createElement('li');
            messageBox.appendChild(dom);
        }
        dom.innerText = sender + ':' + text;
        return dom
    }

    function sendMessage(event) {
        if (last_message_element !== null) { // the bot has not finished replying yet
            return;
        }
        let input = document.getElementById("messageText");
        if (input.value === "") {
            return;
        }
        let body = {"query": input.value, 'history': history};
        ws.send(JSON.stringify(body));
        appendMessage(input.value, '用户')
        input.value = '';
        event.preventDefault();
    }

    document.getElementById("form").addEventListener('submit', sendMessage)

    ws.onmessage = function (event) {
        let body = JSON.parse(event.data);
        let status = body['status']
        if (status === 200) { // the answer has finished
            last_message_element = null;
        } else {
            history = body['history']
            last_message_element = appendMessage(body['response'], 'ChatGLM-6B', last_message_element)
        }
    };
</script>
</body>
</html>