From 8d770f1328c2877b688b99c4347d2bf67f08f21e Mon Sep 17 00:00:00 2001
From: PorYoung
Date: Sat, 22 Apr 2023 20:08:03 +0800
Subject: [PATCH] feat: chatglm stream api

---
 cli_chatglm_with_api.py | 157 ++++++++++++++++++++++++++++++++++++++++
 stream_flask_api.py     | 152 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 309 insertions(+)
 create mode 100755 cli_chatglm_with_api.py
 create mode 100644 stream_flask_api.py

diff --git a/cli_chatglm_with_api.py b/cli_chatglm_with_api.py
new file mode 100755
index 0000000..f839010
--- /dev/null
+++ b/cli_chatglm_with_api.py
@@ -0,0 +1,157 @@
+#!/usr/bin/python3
+import os
+import platform
+import requests
+import json
+import time
+
+os_name = platform.system()
+clear_command = "cls" if os_name == "Windows" else "clear"
+stream = True
+api_url = os.environ.get("ChatGLM_API", "http://127.0.0.1:8888") + "/chat"
+max_history_len = 6
+max_role_history_len = 6
+welcome_text = "Welcome to the ChatGLM-6B model. Type a message to chat, 'role' to set a role, 'clear' to clear the conversation history, 'stop' to exit."
+failed_resp_text = "The AI is a little stumped right now..."
+secret = os.environ.get("ChatGLM_SECRET", "721d95ac31da59fa022ec8c12f72f597")
+type_wait_time = 0.05
+
+headers = {
+    "Content-Type": "application/json",
+    "Authorization": secret,
+}
+
+
+def request_chat(data):
+    response = requests.post(api_url, headers=headers, json=data)
+    res = None
+    if response.status_code == 200:
+        res = response.json()
+    else:
+        print(
+            "Request failed with code {}.".format(response.status_code),
+            end="\n\n",
+        )
+    return res
+
+
+def request_stream_chat(data):
+    response = requests.post(api_url, headers=headers, json=data, stream=True)
+    if response.status_code == 200:
+        for line in response.iter_lines(decode_unicode=True):
+            if line:
+                data = json.loads(line)
+                if data.get("status") == 200:
+                    if data.get("stop", True):
+                        break
+                    yield data
+                else:
+                    print(
+                        f"\n\nSystem Error: Status {data.get('status')}",
+                        end="\n\n",
+                    )
+                    break
+    else:
+        print(f"\n\nSystem Error: Status {response.status_code}")
+
+
+def main():
+    global stream
+    role_history = []
+    history = []
+    print(welcome_text, end="\n\n")
+    while True:
+        query = input("\nUser: ")
+        print("", end="\n\n")
+
+        if query.strip() == "stop":
+            break
+        if query.strip() == "clear":
+            history = []
+            role_history = []
+            os.system(clear_command)
+            print(welcome_text, end="\n\n")
+            continue
+        if query.strip() == "role":
+            print(
+                f"Enter the role settings (mind how the prompts are phrased); 'exit' cancels, 'ok' finishes, at most {max_role_history_len} entries",
+                end="\n\n",
+            )
+            cancelled = False
+            new_role = []
+            for i in range(1, max_role_history_len + 1):
+                query = input("\nSetting " + str(i) + ": ")
+                if query.strip() == "exit":
+                    cancelled = True
+                    break
+                if query.strip() == "ok":
+                    break
+                req_data = {
+                    "prompt": query,
+                    "history": new_role,
+                    "stream": False,
+                }
+
+                print("\n\nChatGLM-6B is recording the setting......", end="\n\n")
+
+                if stream:
+                    req_data["stream"] = True
+                    response = None
+                    for res_data in request_stream_chat(req_data):
+                        if res_data:
+                            response = res_data
+                else:
+                    response = request_chat(req_data)
+                if response:
+                    new_role = response.get("history", history)
+                else:
+                    print("This setting failed!", end="\n\n")
+            if not cancelled:
+                role_history = new_role
+                history = role_history
+                print("Role set successfully!", end="\n\n")
+                print(history, end="\n\n")
+            continue
+
+        if len(role_history) > 0:
+            if len(history) > max_history_len + len(role_history):
+                history = role_history + history[-max_history_len:]
+        else:
+            history = history[-max_history_len:]
+
+        print("ChatGLM-6B:", end="")
+
+        # stream chat
+        if stream:
+            completed = None
+            last_stop = 0
+            for res_data in request_stream_chat(
+                {"prompt": query, "history": history, "stream": True}
+            ):
+                if res_data:
+                    text = res_data.get("response", failed_resp_text)
+                    print(text[last_stop:], end="", flush=True)
+                    last_stop = len(text)
+                    completed = res_data
+            print("", end="\n\n", flush=True)
+
+            if completed:
+                history = completed.get("history", history)
+
+        else:
+            res_data = request_chat(
+                {"prompt": query, "history": history, "stream": False}
+            )
+            if res_data:
+                history = res_data.get("history", history)
+                text = res_data.get("response", failed_resp_text)
+                for i in range(0, len(text), 8):
+                    print(text[i : i + 8], end="", flush=True)
+                    time.sleep(type_wait_time)
+                print("", end="\n\n", flush=True)
+            else:
+                print(failed_resp_text, end="\n\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/stream_flask_api.py b/stream_flask_api.py
new file mode 100644
index 0000000..ba5b60a
--- /dev/null
+++ b/stream_flask_api.py
@@ -0,0 +1,152 @@
+import os
+
+from flask import Flask, Response, request, redirect
+from transformers import AutoTokenizer, AutoModel
+import json, time, datetime, hashlib, redis
+import torch
+
+DEVICE = "cuda"
+DEVICE_ID = "0"
+CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+GLM_SECRET_KEY = os.environ.get(
+    "GLM_SECRET_KEY", "721d95ac31da59fa022ec8c12f72f597"
+)
+
+tokenizer = AutoTokenizer.from_pretrained(
+    "THUDM/chatglm-6b", trust_remote_code=True
+)
+model = (
+    AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+    .half()
+    .cuda()
+)
+model.eval()
+
+app = Flask(__name__)
+mq = redis.StrictRedis()
+
+
+def torch_gc():
+    if torch.cuda.is_available():
+        with torch.cuda.device(CUDA_DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
+
+def wrap_answer(prompt, response, history, stop=True):
+    now = datetime.datetime.now()
+    time_str = now.strftime("%Y-%m-%d %H:%M:%S")
+    answer = {
+        "response": response,
+        "history": history,
+        "stop": stop,
+        "status": 200,
+        "time": time_str,
+    }
+    log = (
+        "["
+        + time_str
+        + "] "
+        + 'prompt: "'
+        + prompt
+        + '", response: "'
+        + repr(response)
+        + '"'
+    )
+    print(log)
+    return answer
+
+
+def chat(prompt, history=[], max_length=None, top_p=None, temperature=None):
+    response, history = model.chat(
+        tokenizer,
+        prompt,
+        history=history,
+        max_length=max_length if max_length else 2048,
+        top_p=top_p if top_p else 0.7,
+        temperature=temperature if temperature else 0.95,
+    )
+
+    return wrap_answer(prompt, response, history)
+
+
+def stream_chat(
+    prompt, history=[], max_length=None, top_p=None, temperature=None
+):
+    global model, tokenizer
+    for response, history in model.stream_chat(
+        tokenizer,
+        prompt,
+        history,
+        max_length=max_length if max_length else 2048,
+        top_p=top_p if top_p else 0.7,
+        temperature=temperature if temperature else 0.95,
+    ):
+        answer = wrap_answer(prompt, response, history, stop=False)
+        yield answer
+
+    yield {"stop": True, "status": 200}
+
+
+@app.post("/chat")
+def chat_api():
+    secret = request.headers.get("Authorization")
+    if secret != GLM_SECRET_KEY:
+        return {"status": 403}
+
+    global model, tokenizer
+    json_post_raw = request.json
+    json_post = json.dumps(json_post_raw)
+    json_post_list = json.loads(json_post)
+    stream = json_post_list.get("stream")
+
+    if stream:
+        key = hashlib.md5(
+            (f"{secret}-{int(time.time() * 1000)}").encode()
+        ).hexdigest()
+        mq.set(f"prompt-{key}", json_post)
+
+        return redirect(f"/stream/{key}", code=307)
+    else:
+        prompt = json_post_list.get("prompt")
+        history = json_post_list.get("history")
+        max_length = json_post_list.get("max_length")
+        top_p = json_post_list.get("top_p")
+        temperature = json_post_list.get("temperature")
+
+        answer = chat(prompt, history, max_length, top_p, temperature)
+        torch_gc()
+
+        return answer
+
+
+@app.route("/stream/<key>", methods=["GET", "POST"])
+def sse(key):
+    secret = request.headers.get("Authorization")
+    if secret != GLM_SECRET_KEY:
+        return {"status": 403}
+
+    if key:
+        post_data = mq.get(f"prompt-{key}")
+        if post_data:
+            json_post_list = json.loads(post_data)
+            prompt = json_post_list.get("prompt")
+            history = json_post_list.get("history")
+            max_length = json_post_list.get("max_length")
+            top_p = json_post_list.get("top_p")
+            temperature = json_post_list.get("temperature")
+
+            def generate():
+                for answer in stream_chat(
+                    prompt, history, max_length, top_p, temperature
+                ):
+                    yield "\n" + json.dumps(answer) + "\n"
+
+            return Response(generate(), mimetype="text/event-stream")
+
+    return {"status": 200}
+
+
+if __name__ == "__main__":
+    # pass
+    app.run(threaded=True, host="0.0.0.0", port=8888)
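
For reference, below is a minimal sketch of a standalone streaming client for the API added in this patch. It is illustrative only: the file name stream_client_example.py, the helper name stream_once, and the sample prompt are made up here, while the endpoint, Authorization header, line-delimited JSON framing, and the status/stop fields follow the code above.

# stream_client_example.py: a hypothetical helper, not part of this patch.
# Assumes stream_flask_api.py is running locally with the defaults above.
import json
import os

import requests

API_URL = os.environ.get("ChatGLM_API", "http://127.0.0.1:8888") + "/chat"
SECRET = os.environ.get("ChatGLM_SECRET", "721d95ac31da59fa022ec8c12f72f597")
HEADERS = {"Content-Type": "application/json", "Authorization": SECRET}


def stream_once(prompt, history=None):
    """Send one streaming request and print the reply as it grows."""
    history = history or []
    payload = {"prompt": prompt, "history": history, "stream": True}
    # /chat answers the streaming POST with a 307 redirect to /stream/<key>;
    # requests re-sends the same POST there, so a single call is enough.
    resp = requests.post(API_URL, headers=HEADERS, json=payload, stream=True)
    resp.raise_for_status()
    printed = 0
    final = None
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue
        chunk = json.loads(line)
        if chunk.get("status") != 200 or chunk.get("stop", True):
            break
        text = chunk.get("response", "")
        print(text[printed:], end="", flush=True)  # print only the new suffix
        printed = len(text)
        final = chunk
    print()
    return final.get("history", history) if final else history


if __name__ == "__main__":
    history = stream_once("Hello, who are you?")

The code=307 redirect in stream_flask_api.py is what makes this single-call flow work: unlike a 302, a 307 tells the client to repeat the same POST (method, body, and headers) against /stream/<key>, where the handler reads the payload that /chat stashed in Redis and streams back one JSON object per line until the final {"stop": true} marker.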