mirror of https://github.com/THUDM/ChatGLM-6B
Merge 6f9c0aa228 into 24e24d5d6c
commit c79935cb77
@@ -0,0 +1,4 @@
__pycache__
.idea
THUDM
*.tar.gz
README.md (14 changed lines)
@@ -96,7 +96,7 @@ python cli_demo.py
The program runs an interactive dialogue in the command line: type an instruction and press Enter to generate a reply, type `clear` to clear the dialogue history, and type `stop` to terminate the program.

-### API Deployment
+## API Deployment
First install the additional dependencies with `pip install fastapi uvicorn`, then run [api.py](api.py) from the repository:
```shell
python api.py
```
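For reference, a minimal client for this API might look like the sketch below (not part of the diff). It assumes api.py exposes a POST endpoint on port 8000 that accepts a JSON body with `prompt` and `history` fields and returns `response` and `history`, as the curl example in the next hunk suggests; adjust the field names if api.py defines different ones.

```python
# Sketch of a client for api.py; the endpoint and JSON field names are assumptions
# based on the curl example referenced in the README.
import requests


def chat(prompt, history=None):
    payload = {"prompt": prompt, "history": history or []}
    resp = requests.post("http://127.0.0.1:8000", json=payload, timeout=300)
    resp.raise_for_status()
    data = resp.json()
    return data["response"], data["history"]


if __name__ == "__main__":
    answer, _ = chat("你好")
    print(answer)
```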
@@ -117,6 +117,18 @@ curl -X POST "http://127.0.0.1:8000" \
}
```

### WebSocket API with Streaming Responses

Since the API above does not support streaming responses, WebSocket support has been added on top of FastAPI.

First install the additional dependencies `pip install 'fastapi~=0.95.0' 'websockets~=10.4' 'uvicorn~=0.21.1'`, then run [websocket_api.py](./websocket_api.py).

```shell
python websocket_api.py
```

Visit `http://localhost:8000` to open [websocket_demo.html](./websocket_demo.html).

## Low-Cost Deployment
### Model Quantization
By default the model is loaded in FP16 precision, and running the code above requires roughly 13 GB of GPU memory. If your GPU memory is limited, you can try loading the model with quantization as follows:
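The quantized-loading snippet that the sentence above introduces falls outside this hunk. For reference, a minimal sketch of quantized loading, assuming the ChatGLM-6B model implementation exposes a `quantize(bits)` helper as described elsewhere in this repository:

```python
# Sketch (not part of this diff): load ChatGLM-6B with 8-bit weight quantization
# to reduce GPU memory use. The quantize(bits) method is assumed to be provided
# by the model code shipped with the THUDM/chatglm-6b checkpoint.
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda()
model = model.eval()
```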
README_en.md (14 changed lines)
@@ -96,7 +96,7 @@ python cli_demo.py
The command runs an interactive program in the shell. Type your instruction in the shell and hit Enter to generate the response. Type `clear` to clear the dialogue history and `stop` to terminate the program.

## API Deployment
-First install the additional dependency `pip install fastapi uvicorn`. The run [api.py](api.py) in the repo.
+First install the additional dependency `pip install fastapi uvicorn`. Then run [api.py](api.py) in the repo.
```shell
python api.py
```
@@ -116,6 +116,18 @@ The returned value is
}
```

### WebSocket API with Streaming Responses

This demo shows how you can embed ChatGLM-6B into your own website over a WebSocket connection. HTML file: [websocket_demo.html](./websocket_demo.html).

First install the additional dependencies `pip install 'fastapi~=0.95.0' 'websockets~=10.4' 'uvicorn~=0.21.1'`. Then run [websocket_api.py](./websocket_api.py) in the repo.

```shell
python websocket_api.py
```

Then you can view [websocket_demo.html](./websocket_demo.html) by visiting `http://localhost:8000` (the default address).

## Deployment
### Quantization
@@ -0,0 +1,26 @@
from abc import ABC

from transformers import GPT2Model, GPT2Config
class AutoTokenizer:
    @classmethod
    def from_pretrained(cls, *_, **__):
        return None


class AutoModel:
    @classmethod
    def from_pretrained(cls, *_, **__):
        class MockModel(GPT2Model, ABC):
            @classmethod
            def stream_chat(cls, _, query, history) -> list:
                from time import sleep
                current_response = ''
                for i in range(3):
                    current_response += str(i)
                    yield current_response, history + [[query, current_response]]
                    sleep(1)

            def cuda(self, *args, **kwargs):
                return self

        return MockModel(GPT2Config())
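The file above provides drop-in stand-ins for the transformers `AutoTokenizer` and `AutoModel` classes: `from_pretrained` returns a mock model whose `stream_chat` yields "0", "01", "012" with one-second pauses, so the demos can be exercised without downloading the real ChatGLM-6B weights. A hypothetical usage sketch follows (the mock module's file name is not shown in this diff, so the import below is an assumption):

```python
# Hypothetical usage of the mock above; 'mock_transformers' is an assumed module
# name because the file name does not appear in this diff. Runs without a GPU or
# the real ChatGLM-6B weights and simply streams "0", "01", "012".
from mock_transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b")
model = AutoModel.from_pretrained("THUDM/chatglm-6b").cuda()

for response, history in model.stream_chat(tokenizer, "hello", []):
    print(response, history)
```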
web_demo.py (49 changed lines)
@@ -9,11 +9,11 @@ MAX_TURNS = 20
 MAX_BOXES = MAX_TURNS * 2


-def predict(input, max_length, top_p, temperature, history=None):
+def predict(query, max_length, top_p, temperature, history=None):
     if history is None:
         history = []
-    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
-                                               temperature=temperature):
+    for _, history in model.stream_chat(tokenizer, query, history, max_length=max_length, top_p=top_p,
+                                        temperature=temperature):
         updates = []
         for query, response in history:
             updates.append(gr.update(visible=True, value="用户:" + query))
@@ -23,23 +23,28 @@ def predict(input, max_length, top_p, temperature, history=None):
         yield [history] + updates


-with gr.Blocks() as demo:
-    state = gr.State([])
-    text_boxes = []
-    for i in range(MAX_BOXES):
-        if i % 2 == 0:
-            text_boxes.append(gr.Markdown(visible=False, label="提问:"))
-        else:
-            text_boxes.append(gr.Markdown(visible=False, label="回复:"))
+def main():
+    with gr.Blocks() as demo:
+        state = gr.State([])
+        text_boxes = []
+        for i in range(MAX_BOXES):
+            if i % 2 == 0:
+                text_boxes.append(gr.Markdown(visible=False, label="提问:"))
+            else:
+                text_boxes.append(gr.Markdown(visible=False, label="回复:"))

-    with gr.Row():
-        with gr.Column(scale=4):
-            txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
-                container=False)
-        with gr.Column(scale=1):
-            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
-            button = gr.Button("Generate")
-    button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
-demo.queue().launch(share=False, inbrowser=True)
+        with gr.Row():
+            with gr.Column(scale=4):
+                txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
+                    container=False)
+            with gr.Column(scale=1):
+                max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+                top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+                temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+                button = gr.Button("Generate")
+        button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
+    demo.queue().launch(share=False, inbrowser=True)
+
+
+if __name__ == '__main__':
+    main()
websocket_api.py (new file)
@@ -0,0 +1,56 @@
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModel

import uvicorn

pretrained = "THUDM/chatglm-6b"
tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
model = AutoModel.from_pretrained(pretrained, trust_remote_code=True).half().cuda()
model = model.eval()
app = FastAPI()

app.add_middleware(
    CORSMiddleware
)

with open('websocket_demo.html') as f:
    html = f.read()


@app.get("/")
async def get():
    return HTMLResponse(html)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    input: JSON String of {"query": "", "history": []}
    output: JSON String of {"response": "", "history": [], "status": 200}
    status 200 stands for "response ended"; any other status means it is still streaming
    """
    await websocket.accept()
    try:
        while True:
            json_request = await websocket.receive_json()
            query = json_request['query']
            history = json_request['history']
            for response, history in model.stream_chat(tokenizer, query, history=history):
                await websocket.send_json({
                    "response": response,
                    "history": history,
                    "status": 202,
                })
            await websocket.send_json({"status": 200})
    except WebSocketDisconnect:
        pass


def main():
    uvicorn.run(f"{__name__}:app", host='0.0.0.0', port=8000, workers=1)


if __name__ == '__main__':
    main()
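For reference, a minimal Python client for this protocol (not part of the PR). It mirrors what the JavaScript in websocket_demo.html does, using the `websockets` package that the README already lists as a dependency: send `{"query": ..., "history": [...]}` and read streamed messages until the server reports `"status": 200`.

```python
# Sketch of a streaming client for websocket_api.py: send {"query", "history"},
# print each partial "response" (status 202) until the server sends status 200.
import asyncio
import json

import websockets  # installed via 'pip install websockets~=10.4'


async def chat(query, history=None):
    history = history or []
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        await ws.send(json.dumps({"query": query, "history": history}))
        while True:
            message = json.loads(await ws.recv())
            if message.get("status") == 200:   # end of this response
                return history
            print(message["response"])         # partial response so far
            history = message["history"]


if __name__ == "__main__":
    asyncio.run(chat("你好"))
```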
websocket_demo.html (new file)
@@ -0,0 +1,59 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <title>Chat</title>
</head>
<body>
<h1>WebSocket Chat</h1>
<form action="" onsubmit="return false;" id="form">
    <label for="messageText"></label>
    <input type="text" id="messageText" autocomplete="off"/>
    <button type="submit">Send</button>
</form>
<ul id='messageBox'>
</ul>
<script>
    let ws = new WebSocket("ws://" + location.host + "/ws");
    let history = [];
    let last_message_element = null;

    function appendMessage(text, sender, dom = null) {
        if (dom === null) {
            let messageBox = document.getElementById('messageBox');
            dom = document.createElement('li');
            messageBox.appendChild(dom);
        }
        dom.innerText = sender + ':' + text;
        return dom
    }

    function sendMessage(event) {
        if (last_message_element !== null) { // the bot has not finished replying yet
            return;
        }
        let input = document.getElementById("messageText");
        if (input.value === "") {
            return;
        }
        let body = {"query": input.value, 'history': history};
        ws.send(JSON.stringify(body));
        appendMessage(input.value, '用户')
        input.value = '';
        event.preventDefault();
    }

    document.getElementById("form").addEventListener('submit', sendMessage)

    ws.onmessage = function (event) {
        let body = JSON.parse(event.data);
        let status = body['status']
        if (status === 200) { // the answer has finished
            last_message_element = null;
        } else {
            history = body['history']
            last_message_element = appendMessage(body['response'], 'ChatGLM-6B', last_message_element)
        }
    };
</script>
</body>
</html>