mirror of https://github.com/THUDM/ChatGLM-6B
Merge 6f9c0aa228 into 24e24d5d6c
commit c79935cb77

@@ -0,0 +1,4 @@
__pycache__
.idea
THUDM
*.tar.gz

README.md
@@ -96,7 +96,7 @@ python cli_demo.py

The program runs an interactive dialogue in the command line: type an instruction and press Enter to generate a reply, type `clear` to clear the dialogue history, and type `stop` to terminate the program.

-### API Deployment
+## API Deployment

First install the additional dependencies `pip install fastapi uvicorn`, then run [api.py](api.py) from the repository:

```shell
python api.py

@@ -117,6 +117,18 @@ curl -X POST "http://127.0.0.1:8000" \
}
```
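
As an illustration, a minimal Python client for this endpoint might look like the sketch below; the `prompt` and `history` field names are assumptions based on api.py, so check that file for the exact payload:

```python
import requests  # pip install requests

# Hedged sketch of calling the HTTP API started by `python api.py`.
# The "prompt"/"history" field names are assumptions; see api.py for the real schema.
resp = requests.post("http://127.0.0.1:8000", json={"prompt": "你好", "history": []})
data = resp.json()
print(data["response"])    # the model's reply
history = data["history"]  # send this back with the next request to keep context
```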

### Websocket API with Streaming Support

Since the API above does not support streaming responses, websocket support has been added on top of fastapi.

First install the additional dependencies `pip install 'fastapi~=0.95.0' 'websockets~=10.4' 'uvicorn~=0.21.1'`, then run [websocket_api.py](./websocket_api.py):

```shell
python websocket_api.py
```

Visit `http://localhost:8000` to open [websocket_demo.html](./websocket_demo.html).
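
For reference, a small Python client that consumes the streamed replies could look roughly like this; it assumes the `websockets` package and the `/ws` endpoint and message format defined in [websocket_api.py](./websocket_api.py) later in this diff:

```python
import asyncio
import json

import websockets  # pip install websockets


async def chat(query, history=None):
    # Connect to the /ws endpoint exposed by websocket_api.py and stream the reply.
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        await ws.send(json.dumps({"query": query, "history": history or []}))
        async for message in ws:
            data = json.loads(message)
            if data["status"] == 200:  # 200 marks the end of the stream
                break
            print(data["response"])    # partial response so far
            history = data["history"]
    return history


asyncio.run(chat("你好"))
```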

## Low-Cost Deployment

### Model Quantization

By default the model is loaded at FP16 precision, and running the code above needs roughly 13GB of GPU memory. If your GPU memory is limited, you can try loading the model with quantization, as follows:
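
For context, the quantized load described elsewhere in the ChatGLM-6B documentation looks roughly like the sketch below; the `quantize()` helper comes with the remote model code (`trust_remote_code=True`), and the exact call chain may differ from the repo's wording:

```python
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
# quantize(8) loads 8-bit weights; quantize(4) trades more accuracy for less GPU memory.
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda()
model = model.eval()
```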

README_en.md
@@ -96,7 +96,7 @@ python cli_demo.py

The command runs an interactive program in the shell. Type your instruction in the shell and hit enter to generate the response. Type `clear` to clear the dialogue history and `stop` to terminate the program.

## API Deployment

-First install the additional dependency `pip install fastapi uvicorn`. The run [api.py](api.py) in the repo.
+First install the additional dependency `pip install fastapi uvicorn`. Then run [api.py](api.py) in the repo.

```shell
python api.py
```

@@ -116,6 +116,18 @@ The returned value is
}
```

### Websocket API with Streaming Support

This demo shows that you can embed ChatGLM-6B into your own website over a websocket. HTML file: [websocket_demo.html](./websocket_demo.html).

First install the additional dependencies `pip install 'fastapi~=0.95.0' 'websockets~=10.4' 'uvicorn~=0.21.1'`. Then run [websocket_api.py](./websocket_api.py) in the repo.

```shell
python websocket_api.py
```

Then open `http://localhost:8000` to see [websocket_demo.html](./websocket_demo.html).

## Deployment

### Quantization

@@ -0,0 +1,26 @@
from abc import ABC

from transformers import GPT2Model, GPT2Config


# Mock stand-ins for transformers.AutoTokenizer / AutoModel: from_pretrained returns a
# tiny GPT2-based model whose stream_chat yields a dummy incremental response
# ("0", "01", "012"), presumably so the demos can run without the real ChatGLM-6B weights.
class AutoTokenizer:
    @classmethod
    def from_pretrained(cls, *_, **__):
        return None


class AutoModel:
    @classmethod
    def from_pretrained(cls, *_, **__):
        class MockModel(GPT2Model, ABC):
            @classmethod
            def stream_chat(cls, _, query, history) -> list:
                from time import sleep
                current_response = ''
                for i in range(3):
                    current_response += str(i)
                    yield current_response, history + [[query, current_response]]
                    sleep(1)

            def cuda(self, *args, **kwargs):
                return self

        return MockModel(GPT2Config())
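
A hypothetical way to exercise the mock above (the module name `mock_transformers` is an assumption; the diff does not show the file's path). The call mirrors how websocket_api.py invokes `stream_chat`:

```python
# Hypothetical import path; rename to wherever the mock file actually lives.
from mock_transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).cuda()

for response, history in model.stream_chat(tokenizer, "你好", history=[]):
    print(response)  # "0", then "01", then "012", one second apart
```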

web_demo.py
@@ -9,11 +9,11 @@ MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2


-def predict(input, max_length, top_p, temperature, history=None):
+def predict(query, max_length, top_p, temperature, history=None):
    if history is None:
        history = []
-    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+    for _, history in model.stream_chat(tokenizer, query, history, max_length=max_length, top_p=top_p,
                                               temperature=temperature):
        updates = []
        for query, response in history:
            updates.append(gr.update(visible=True, value="用户:" + query))

@@ -23,23 +23,28 @@ def predict(input, max_length, top_p, temperature, history=None):
        yield [history] + updates


-with gr.Blocks() as demo:
-    state = gr.State([])
-    text_boxes = []
-    for i in range(MAX_BOXES):
-        if i % 2 == 0:
-            text_boxes.append(gr.Markdown(visible=False, label="提问:"))
-        else:
-            text_boxes.append(gr.Markdown(visible=False, label="回复:"))
-
-    with gr.Row():
-        with gr.Column(scale=4):
-            txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
-                container=False)
-        with gr.Column(scale=1):
-            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
-            button = gr.Button("Generate")
-    button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
-demo.queue().launch(share=False, inbrowser=True)
+def main():
+    with gr.Blocks() as demo:
+        state = gr.State([])
+        text_boxes = []
+        for i in range(MAX_BOXES):
+            if i % 2 == 0:
+                text_boxes.append(gr.Markdown(visible=False, label="提问:"))
+            else:
+                text_boxes.append(gr.Markdown(visible=False, label="回复:"))
+
+        with gr.Row():
+            with gr.Column(scale=4):
+                txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
+                    container=False)
+            with gr.Column(scale=1):
+                max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+                top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+                temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+                button = gr.Button("Generate")
+        button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
+    demo.queue().launch(share=False, inbrowser=True)
+
+
+if __name__ == '__main__':
+    main()

websocket_api.py
@@ -0,0 +1,56 @@
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModel

import uvicorn

pretrained = "THUDM/chatglm-6b"
tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
model = AutoModel.from_pretrained(pretrained, trust_remote_code=True).half().cuda()
model = model.eval()
app = FastAPI()

app.add_middleware(
    CORSMiddleware
)

with open('websocket_demo.html') as f:
    html = f.read()


@app.get("/")
async def get():
    return HTMLResponse(html)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """
    input: JSON String of {"query": "", "history": []}
    output: JSON String of {"response": "", "history": [], "status": 200}
    status 200 means the response has ended; any other status means it is still streaming
    """
    await websocket.accept()
    try:
        while True:
            json_request = await websocket.receive_json()
            query = json_request['query']
            history = json_request['history']
            for response, history in model.stream_chat(tokenizer, query, history=history):
                await websocket.send_json({
                    "response": response,
                    "history": history,
                    "status": 202,
                })
            await websocket.send_json({"status": 200})
    except WebSocketDisconnect:
        pass


def main():
    uvicorn.run(f"{__name__}:app", host='0.0.0.0', port=8000, workers=1)


if __name__ == '__main__':
    main()

websocket_demo.html
@@ -0,0 +1,59 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <title>Chat</title>
</head>
<body>
<h1>WebSocket Chat</h1>
<form action="" onsubmit="return false;" id="form">
    <label for="messageText"></label>
    <input type="text" id="messageText" autocomplete="off"/>
    <button type="submit">Send</button>
</form>
<ul id='messageBox'>
</ul>
<script>
    let ws = new WebSocket("ws://" + location.host + "/ws");
    let history = [];
    let last_message_element = null;

    function appendMessage(text, sender, dom = null) {
        if (dom === null) {
            let messageBox = document.getElementById('messageBox');
            dom = document.createElement('li');
            messageBox.appendChild(dom);
        }
        dom.innerText = sender + ':' + text;
        return dom
    }

    function sendMessage(event) {
        if (last_message_element !== null) { // the bot has not finished replying yet
            return;
        }
        let input = document.getElementById("messageText");
        if (input.value === "") {
            return;
        }
        let body = {"query": input.value, 'history': history};
        ws.send(JSON.stringify(body));
        appendMessage(input.value, '用户')
        input.value = '';
        event.preventDefault();
    }

    document.getElementById("form").addEventListener('submit', sendMessage)

    ws.onmessage = function (event) {
        let body = JSON.parse(event.data);
        let status = body['status']
        if (status === 200) { // the reply has finished
            last_message_element = null;
        } else {
            history = body['history']
            last_message_element = appendMessage(body['response'], 'ChatGLM-6B', last_message_element)
        }
    };
</script>
</body>
</html>