From 03b38d224882a4a045dbc8af5e91a55e88e63937 Mon Sep 17 00:00:00 2001
From: Lucien
Date: Sat, 25 Mar 2023 20:10:04 +0800
Subject: [PATCH 1/5] add websocket api to support stream response, add .gitignore, add mock model for convenience

---
 .gitignore           |  2 ++
 README.md            | 14 ++++++++++-
 README_en.md         | 15 +++++++++++
 mock_transformers.py | 26 +++++++++++++++++++
 web_demo.py          | 49 +++++++++++++++++++-----------------
 websocket_api.py     | 51 ++++++++++++++++++++++++++++++++++++++
 websocket_demo.html  | 59 ++++++++++++++++++++++++++++++++++++++++++++
 7 files changed, 193 insertions(+), 23 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 mock_transformers.py
 create mode 100644 websocket_api.py
 create mode 100644 websocket_demo.html

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..083d732
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+__pycache__
+.idea
diff --git a/README.md b/README.md
index 58f3751..2824b7d 100644
--- a/README.md
+++ b/README.md
@@ -91,7 +91,7 @@ python cli_demo.py
 程序会在命令行中进行交互式的对话,在命令行中输入指示并回车即可生成回复,输入`clear`可以清空对话历史,输入`stop`终止程序。
 
-### API部署
+## API 部署
 首先需要安装额外的依赖`pip install fastapi uvicorn`,然后运行仓库中的[api.py](api.py):
 ```shell
 python api.py
@@ -112,6 +112,18 @@ curl -X POST "http://127.0.0.1:8000" \
 }
 ```
 
+### 支持流式返回的 Websocket API
+
+由于上述 API 部署提供不支持流式返回,故在 FastAPI 的基础上增加了对 Websocket 的支持。
+
+首先安装额外的依赖 `pip install 'fastapi~=0.95.0' 'websockets~=10.4'`,然后运行 [websocket_api.py](./websocket_api.py) 即可。
+
+```shell
+python websocket_api.py
+```
+
+访问 `http://localhost:8000` 即可看到 [websocket_demo.html](./websocket_demo.html)。
+
 ## 低成本部署
 ### 模型量化
 默认情况下,模型以 FP16 精度加载,运行上述代码需要大概 13GB 显存。如果你的 GPU 显存有限,可以尝试以量化方式加载模型,使用方法如下:
diff --git a/README_en.md b/README_en.md
index 3a8b17b..a01fb8a 100644
--- a/README_en.md
+++ b/README_en.md
@@ -89,6 +89,9 @@ python cli_demo.py
 The command runs an interactive program in the shell. Type your instruction in the shell and hit enter to generate the response. Type `clear` to clear the dialogue history and `stop` to terminate the program.
 
 ## API Deployment
+
+### Sync API
+
 First install the additional dependency `pip install fastapi uvicorn`. The run [api.py](api.py) in the repo.
 ```shell
 python api.py
@@ -109,6 +112,18 @@ The returned value is
 }
 ```
 
+### Stream Response Supported Websocket API
+
+This DEMO showed that you can embed ChatGLM-6B to your own website through websocket. HTML file: [websocket_demo.html](./websocket_demo.html).
+
+First install the additional dependency `pip install 'fastapi~=0.95.0' 'websockets~=10.4'`. Then run [websocket_api.py](./websocket_api.py) in the repo.
+
+```shell
+python websocket_api.py
+```
+
+Then you can see [websocket_demo.html](./websocket_demo.html) through access `http://localhost:8000` by default.
+
 ## Deployment
 ### Quantization
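The new Websocket sections above only explain how to start the server. As a quick illustration of the client side, here is a minimal sketch (not part of the patch) that assumes `websocket_api.py` from this patch is already running on `localhost:8000` and uses the `websockets` package from the install command above; the message format is the one implemented in `websocket_api.py` further down.

```python
# Illustrative client sketch (not part of the patch): stream one reply from /ws.
import asyncio
import json

import websockets  # installed via `pip install 'websockets~=10.4'`


async def chat(query, history):
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        await ws.send(json.dumps({"query": query, "history": history}))
        while True:
            message = json.loads(await ws.recv())
            if message["status"] == 200:  # the server signals that the reply is complete
                return history
            print(message["response"])    # partial reply, grows with every frame
            history = message["history"]


asyncio.run(chat("你好", []))
```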
diff --git a/mock_transformers.py b/mock_transformers.py
new file mode 100644
index 0000000..b83f06f
--- /dev/null
+++ b/mock_transformers.py
@@ -0,0 +1,26 @@
+from abc import ABC
+
+from transformers import GPT2Model, GPT2Config
+class AutoTokenizer:
+    @classmethod
+    def from_pretrained(cls, *_, **__):
+        return None
+
+
+class AutoModel:
+    @classmethod
+    def from_pretrained(cls, *_, **__):
+        class MockModel(GPT2Model, ABC):
+            @classmethod
+            def stream_chat(cls, _, query, history) -> list:
+                from time import sleep
+                current_response = ''
+                for i in range(3):
+                    current_response += str(i)
+                    yield current_response, history + [[query, current_response]]
+                    sleep(1)
+
+            def cuda(self, *args, **kwargs):
+                return self
+
+        return MockModel(GPT2Config())
diff --git a/web_demo.py b/web_demo.py
index 88a6dc8..c7231d1 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -9,11 +9,11 @@ MAX_TURNS = 20
 MAX_BOXES = MAX_TURNS * 2
 
 
-def predict(input, max_length, top_p, temperature, history=None):
+def predict(query, max_length, top_p, temperature, history=None):
     if history is None:
         history = []
-    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
-                                               temperature=temperature):
+    for _, history in model.stream_chat(tokenizer, query, history, max_length=max_length, top_p=top_p,
+                                        temperature=temperature):
         updates = []
         for query, response in history:
             updates.append(gr.update(visible=True, value="用户:" + query))
@@ -23,23 +23,28 @@ def predict(input, max_length, top_p, temperature, history=None):
     yield [history] + updates
 
 
-with gr.Blocks() as demo:
-    state = gr.State([])
-    text_boxes = []
-    for i in range(MAX_BOXES):
-        if i % 2 == 0:
-            text_boxes.append(gr.Markdown(visible=False, label="提问:"))
-        else:
-            text_boxes.append(gr.Markdown(visible=False, label="回复:"))
+def main():
+    with gr.Blocks() as demo:
+        state = gr.State([])
+        text_boxes = []
+        for i in range(MAX_BOXES):
+            if i % 2 == 0:
+                text_boxes.append(gr.Markdown(visible=False, label="提问:"))
+            else:
+                text_boxes.append(gr.Markdown(visible=False, label="回复:"))
 
-    with gr.Row():
-        with gr.Column(scale=4):
-            txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
-                container=False)
-        with gr.Column(scale=1):
-            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
-            button = gr.Button("Generate")
-    button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
-demo.queue().launch(share=False, inbrowser=True)
+        with gr.Row():
+            with gr.Column(scale=4):
+                txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
+                    container=False)
+            with gr.Column(scale=1):
+                max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+                top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+                temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+                button = gr.Button("Generate")
+        button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
+    demo.queue().launch(share=False, inbrowser=True)
+
+
+if __name__ == '__main__':
+    main()
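The `mock_transformers.py` module added above is a lightweight stand-in for the real checkpoint, so the streaming plumbing can be tried without a GPU or a model download. A sketch of how it behaves (again not part of the patch; the mock ignores the model name passed to `from_pretrained`):

```python
# Illustrative sketch (not part of the patch): exercising the mock directly.
from mock_transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b")  # the mock simply returns None
model = AutoModel.from_pretrained("THUDM/chatglm-6b")          # randomly initialised GPT-2, nothing is downloaded

for response, history in model.stream_chat(tokenizer, "hello", []):
    print(response, history)  # prints "0", "01", "012" with the updated history, one second apart
```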
diff --git a/websocket_api.py b/websocket_api.py
new file mode 100644
index 0000000..4aae12d
--- /dev/null
+++ b/websocket_api.py
@@ -0,0 +1,51 @@
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi.responses import HTMLResponse
+from mock_transformers import AutoTokenizer, AutoModel
+
+import uvicorn
+
+pretrained = "THUDM/chatglm-6b"
+tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
+model = AutoModel.from_pretrained(pretrained, trust_remote_code=True).half().cuda()
+model = model.eval()
+app = FastAPI()
+
+with open('websocket_demo.html') as f:
+    html = f.read()
+
+
+@app.get("/")
+async def get():
+    return HTMLResponse(html)
+
+
+@app.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket):
+    """
+    input: JSON String of {"query": "", "history": []}
+    output: JSON String of {"response": "", "history": [], "status": 200}
+    status 200 stand for response ended, else not
+    """
+    await websocket.accept()
+    try:
+        while True:
+            json_request = await websocket.receive_json()
+            query = json_request['query']
+            history = json_request['history']
+            for response, history in model.stream_chat(tokenizer, query, history=history):
+                await websocket.send_json({
+                    "response": response,
+                    "history": history,
+                    "status": 202,
+                })
+            await websocket.send_json({"status": 200})
+    except WebSocketDisconnect:
+        pass
+
+
+def main():
+    uvicorn.run(f"{__name__}:app", host='0.0.0.0', port=8000, workers=1)
+
+
+if __name__ == '__main__':
+    main()
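The docstring above pins down the wire format: the client sends `{"query": ..., "history": ...}`, intermediate frames carry `"status": 202`, and a bare `{"status": 200}` ends the turn. For completeness, a sketch of driving the endpoint in-process with Starlette's test client (not part of the patch; it assumes the code runs from the repository root so `websocket_demo.html` is found, that `requests` is installed for the test client, and that the machine can actually load the model, or the mock from `mock_transformers.py`):

```python
# Illustrative sketch (not part of the patch): in-process round trip against /ws.
from fastapi.testclient import TestClient

from websocket_api import app  # importing this loads the (possibly mocked) model

client = TestClient(app)
with client.websocket_connect("/ws") as ws:
    ws.send_json({"query": "你好", "history": []})
    while True:
        message = ws.receive_json()
        if message["status"] == 200:  # turn finished
            break
        print(message["response"])
```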
diff --git a/websocket_demo.html b/websocket_demo.html
new file mode 100644
index 0000000..b57814e
--- /dev/null
+++ b/websocket_demo.html
@@ -0,0 +1,59 @@
[The 59-line page itself is not recoverable here; the surviving fragments show a minimal chat page titled "Chat" with a "WebSocket Chat" heading, a text-input form with a send button, an area for the conversation, and a script that opens a WebSocket to the /ws endpoint and renders the streamed replies.]

From a41456a83184c0a48efc09145f63f92f63ae50c2 Mon Sep 17 00:00:00 2001
From: Lucien
Date: Sat, 25 Mar 2023 20:19:53 +0800
Subject: [PATCH 2/5] Fix README typo

---
 README.md    | 2 +-
 README_en.md | 5 +----
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 2824b7d..8471fd8 100644
--- a/README.md
+++ b/README.md
@@ -114,7 +114,7 @@ curl -X POST "http://127.0.0.1:8000" \
 
 ### 支持流式返回的 Websocket API
 
-由于上述 API 部署提供不支持流式返回,故在 FastAPI 的基础上增加了对 Websocket 的支持。
+由于上述 API 不支持流式返回,故在 fastapi 的基础上增加了对 websocket 的支持。
 
 首先安装额外的依赖 `pip install 'fastapi~=0.95.0' 'websockets~=10.4'`,然后运行 [websocket_api.py](./websocket_api.py) 即可。
 
diff --git a/README_en.md b/README_en.md
index a01fb8a..4c43de5 100644
--- a/README_en.md
+++ b/README_en.md
@@ -89,10 +89,7 @@ python cli_demo.py
 The command runs an interactive program in the shell. Type your instruction in the shell and hit enter to generate the response. Type `clear` to clear the dialogue history and `stop` to terminate the program.
 
 ## API Deployment
-
-### Sync API
-
-First install the additional dependency `pip install fastapi uvicorn`. The run [api.py](api.py) in the repo.
+First install the additional dependency `pip install fastapi uvicorn`. Then run [api.py](api.py) in the repo.
 ```shell
 python api.py
 ```

From dc3bb5ee38d92422335612321efcc8656cdff35d Mon Sep 17 00:00:00 2001
From: Lucien
Date: Sat, 25 Mar 2023 20:24:31 +0800
Subject: [PATCH 3/5] remove mock

---
 websocket_api.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/websocket_api.py b/websocket_api.py
index 4aae12d..8a33fcd 100644
--- a/websocket_api.py
+++ b/websocket_api.py
@@ -1,6 +1,6 @@
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from fastapi.responses import HTMLResponse
-from mock_transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel
 
 import uvicorn
 

From d38563a73d9c06fe0f5da18e1d40ff6ef76ceee9 Mon Sep 17 00:00:00 2001
From: Lucien
Date: Sat, 25 Mar 2023 20:45:53 +0800
Subject: [PATCH 4/5] fix html to support not only localhost access

---
 websocket_demo.html | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/websocket_demo.html b/websocket_demo.html
index b57814e..59b3928 100644
--- a/websocket_demo.html
+++ b/websocket_demo.html
@@ -13,7 +13,7 @@
[The changed lines of this hunk are not recoverable here; judging by the subject, the page's script no longer hard-codes a localhost WebSocket URL, so the demo also works when served from a non-localhost address.]