diff --git a/README.md b/README.md
index 814277d..29b6855 100644
--- a/README.md
+++ b/README.md
@@ -121,6 +121,81 @@ curl -X POST "http://127.0.0.1:8000" \
 }
 ```

### Stream API Deployment
First install the additional dependency with `pip install flask`, then run the script:
```shell
python stream_api.py
```
By default the service listens on port 8000 of the local machine and exposes two interfaces that must be used together; `application/json` is the default Content-Type. The service is described below.

The streaming API uses Flask as the web layer and builds a waiting queue and a processing queue on top of a thread pool; developers can choose the queue lengths according to the actual performance of their hardware. In real applications, a streaming API gives a better user-side experience and higher overall utilization.

#### 1. Interface 1: /chat

Opens one round of dialogue (one question and one answer). After calling this interface, keep calling Interface 2 with the same `request_id` to fetch the response content as it is generated.

- Example request data is shown below; `request_id` is chosen by the caller and identifies the conversation:

```json
{
    "history": [["你是谁?","我是智能机器人"]],
    "query": "你好",
    "request_id": "73"
}
```

- Example response data is shown below; it indicates a normal response, meaning the server has started processing the request or placed it in the queue:

```json
{
    "code": 0,
    "msg": "start process"
}
```

#### 2. Interface 2: /get_response

Uses `request_id` to fetch the response content of the dialogue. Call this interface periodically until it returns `is_finished = true`, which means inference for this round of dialogue has completed.

- Example request data is shown below; `request_id` is the one specified in Interface 1:

```json
{
    "request_id": "73"
}
```

- Example response data 1 (the request is still in the waiting queue and inference has not started yet):

```json
{
    "code": 0,
    "msg": "success",
    "response": {
        "is_finished": false,
        "is_handled": false,
        "response": "",
        "timestamp": 1679813631.926929
    }
}
```

- Example response data 2 (the request has entered the processing queue and inference is not finished yet):

```json
{
    "code": 0,
    "msg": "success",
    "response": {
        "is_finished": false,
        "is_handled": true,
        "response": "我是智能机器人,请问",
        "timestamp": 1679813631.926929
    }
}
```
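For reference, a minimal client for this flow might look like the sketch below. It assumes the service is already running at the default `http://127.0.0.1:8000` and that the `requests` package is installed (`pip install requests`); the `request_id` value is an arbitrary example.

```python
import time

import requests  # assumed extra dependency: pip install requests

BASE_URL = "http://127.0.0.1:8000"  # default host and port of stream_api.py


def stream_chat(query, history=None, request_id="demo-1"):
    # Step 1: open one round of dialogue; request_id must be unique per conversation.
    start = requests.post(f"{BASE_URL}/chat", json={
        "query": query,
        "history": history or [],
        "request_id": request_id,
    }).json()
    if start["code"] != 0:
        raise RuntimeError(f"server did not accept the request: {start}")

    # Step 2: poll /get_response until is_finished is true, printing partial output.
    while True:
        resp = requests.post(f"{BASE_URL}/get_response",
                             json={"request_id": request_id}).json()
        state = resp["response"]
        print(f"\rpartial response: {state['response']}", end="", flush=True)
        if state["is_finished"]:
            print()
            return state["response"]
        time.sleep(1)  # poll about once per second


if __name__ == "__main__":
    stream_chat("你好")
```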
## Low-Cost Deployment
### Model Quantization
By default, the model is loaded in FP16 precision, and running the code above requires roughly 13 GB of GPU memory. If your GPU memory is limited, you can try loading the model in a quantized way as follows:

diff --git a/README_en.md b/README_en.md
index 7b84f9c..935f8f5 100644
--- a/README_en.md
+++ b/README_en.md
@@ -116,6 +116,84 @@ The returned value is
 }
 ```

### Stream API Deployment

First install the additional dependency with `pip install flask`, then run `stream_api.py` from the repo:

```shell
python stream_api.py
```

By default the API runs on port `8000` of the local machine. It exposes two interfaces that must be used together; `application/json` is used as the Content-Type. The service is described below.

The streaming API uses Flask as the web layer and builds a waiting queue and a processing queue on top of a thread pool; developers can choose the queue lengths according to the actual performance of their hardware. In real applications, a streaming API gives a better user-side experience and higher overall utilization.

#### 1. Interface 1: /chat

Opens one round of dialogue (one question and one answer). After calling this interface, you should call Interface 2 repeatedly (for example, once per second) with the `request_id` to obtain the response content as it is generated.

- Example request data is shown below, where `request_id` is specified by the caller and identifies the conversation:

```json
{
    "history": [["你是谁?","我是智能机器人"]],
    "query": "你好",
    "request_id": "73"
}
```

- Example response data is shown below; it indicates a normal response, meaning the server has started processing the request or placed it in the queue:

```json
{
    "code": 0,
    "msg": "start process"
}
```

#### 2. Interface 2: /get_response

Uses `request_id` to obtain the response content of the dialogue. Call this interface periodically until it returns `is_finished = true`, which means inference for this round of dialogue has completed.

- Example request data is shown below, where `request_id` is the one specified in Interface 1:

```json
{
    "request_id": "73"
}
```

- Example response data 1 (the request is still in the waiting queue and inference has not started yet):

```json
{
    "code": 0,
    "msg": "success",
    "response": {
        "is_finished": false,
        "is_handled": false,
        "response": "",
        "timestamp": 1679813631.926929
    }
}
```

- Example response data 2 (the request has entered the processing queue and inference is not finished yet):

```json
{
    "code": 0,
    "msg": "success",
    "response": {
        "is_finished": false,
        "is_handled": true,
        "response": "我是智能机器人,请问",
        "timestamp": 1679813631.926929
    }
}
```
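Note that besides `code: 0`, the implementation in `stream_api.py` (added later in this diff) also returns non-zero codes that are not shown in the examples above: `100` (`busy now`, the waiting and processing queues are full), `101` (`request_id already exists`) and `102` (`request_id not exists`). A client should therefore check `code` before it starts polling; a minimal sketch:

```python
# Minimal sketch of client-side handling for the status codes defined in stream_api.py.
def check_chat_response(payload):
    code = payload.get("code")
    if code == 0:
        return  # accepted: start polling /get_response
    if code == 100:
        raise RuntimeError("server busy: the waiting and processing queues are full, retry later")
    if code == 101:
        raise ValueError("request_id already exists: choose a fresh request_id")
    raise RuntimeError(f"unexpected response: {payload}")
```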
## Deployment
### Quantization

diff --git a/stream_api.py b/stream_api.py
new file mode 100644
index 0000000..a5df2f6
--- /dev/null
+++ b/stream_api.py
@@ -0,0 +1,103 @@
from transformers import AutoTokenizer, AutoModel
from threading import Thread
import time
import sched
from flask import Flask, request, jsonify
from multiprocessing.pool import ThreadPool

tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

app = Flask(__name__)

handling_list_max_length = 10  # maximum number of requests processed concurrently
waiting_list_max_length = 20   # maximum number of requests allowed to wait in the queue

current_handling = []   # request_ids that are queued or currently being processed
handling_history = {}   # request_id -> latest state and partial response
pool = ThreadPool(handling_list_max_length)


def chat(request_id, query, history):
    # Worker task: stream the model output and keep the latest partial response
    # in handling_history so that /get_response can read it.
    global current_handling
    handling_history[request_id]["is_handled"] = True
    for response, history in model.stream_chat(tokenizer, query, history=history):
        handling_history[request_id]["response"] = response

    handling_history[request_id]["is_finished"] = True
    current_handling.remove(request_id)


@app.route('/chat', methods=['POST'])
def query():
    global current_handling
    # Read the parameters of the POST request
    data = request.get_json()
    request_id = data.get('request_id')
    history = data.get('history', [])
    query = data.get('query')

    # Return busy when the processing and waiting queues are both full
    if len(current_handling) >= (handling_list_max_length + waiting_list_max_length):
        return jsonify({'code': 100, 'msg': 'busy now'})

    if request_id in handling_history:
        return jsonify({'code': 101, 'msg': 'request_id already exists'})

    current_handling.append(request_id)
    handling_history[request_id] = {
        "timestamp": time.time(),
        "response": "",
        "is_finished": False,
        "is_handled": False
    }

    history_data = []
    for each in history:
        history_data.append((each[0], each[1]))

    # Run the inference task on the thread pool
    pool.apply_async(chat, args=(request_id, query, history_data))

    return jsonify({'code': 0, 'msg': 'start process'})


@app.route('/get_response', methods=['POST'])
def getResponse():
    data = request.get_json()
    request_id = data.get('request_id')

    if request_id not in handling_history:
        return jsonify({'code': 102, 'msg': 'request_id not exists'})

    return jsonify({'code': 0, 'msg': 'success', 'response': handling_history[request_id]})


def clearHistory():
    # Periodically remove finished entries from handling_history so it does not pile up.
    # time.time() is in seconds, so 60 * 60 keeps entries for one hour after creation.
    global handling_history
    now = time.time()
    need_delete = []
    for request_id in handling_history:
        if now - handling_history[request_id]["timestamp"] > 60 * 60 and handling_history[request_id]["is_finished"]:
            need_delete.append(request_id)
    for request_id in need_delete:
        del handling_history[request_id]


def startClean():
    # Run clearHistory once a minute on a background scheduler.
    s = sched.scheduler(time.time, time.sleep)
    while True:
        s.enter(60, 1, clearHistory, ())
        s.run()


if __name__ == '__main__':
    cleanT = Thread(target=startClean)
    cleanT.start()
    app.run(debug=False, port=8000, host='0.0.0.0')
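One design note on the code above: the Flask request handlers, the `ThreadPool` workers, and the cleanup thread all touch `current_handling` and `handling_history` concurrently. CPython's GIL keeps the individual list and dict operations atomic, but the check-then-register sequence in `/chat` is not. If that ever becomes a problem, the shared state could be guarded with a lock; the sketch below is only an illustration (the helper name and `capacity` parameter are hypothetical, not part of `stream_api.py`):

```python
import time
from threading import Lock

state_lock = Lock()  # hypothetical lock, not present in stream_api.py


def try_admit(request_id, current_handling, handling_history, capacity):
    # Perform the capacity check and the registration done in /chat as one
    # atomic step, so two concurrent requests cannot both slip past the limit.
    with state_lock:
        if len(current_handling) >= capacity:
            return 'busy'
        if request_id in handling_history:
            return 'duplicate'
        current_handling.append(request_id)
        handling_history[request_id] = {
            "timestamp": time.time(),
            "response": "",
            "is_finished": False,
            "is_handled": False,
        }
        return 'ok'
```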