diff --git a/api.py b/api.py
index 693c70a..cc5c4b2 100644
--- a/api.py
+++ b/api.py
@@ -2,6 +2,7 @@ from fastapi import FastAPI, Request
 from transformers import AutoTokenizer, AutoModel
 import uvicorn, json, datetime
 import torch
+import asyncio
 
 DEVICE = "cuda"
 DEVICE_ID = "0"
@@ -29,24 +30,28 @@ async def create_item(request: Request):
     max_length = json_post_list.get('max_length')
     top_p = json_post_list.get('top_p')
     temperature = json_post_list.get('temperature')
-    response, history = model.chat(tokenizer,
-                                   prompt,
-                                   history=history,
-                                   max_length=max_length if max_length else 2048,
-                                   top_p=top_p if top_p else 0.7,
-                                   temperature=temperature if temperature else 0.95)
-    now = datetime.datetime.now()
-    time = now.strftime("%Y-%m-%d %H:%M:%S")
-    answer = {
-        "response": response,
-        "history": history,
-        "status": 200,
-        "time": time
-    }
-    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
-    print(log)
-    torch_gc()
-    return answer
+
+    def _sync_chat(history):
+        response, history = model.chat(tokenizer,
+                                       prompt,
+                                       history=history,
+                                       max_length=max_length if max_length else 2048,
+                                       top_p=top_p if top_p else 0.7,
+                                       temperature=temperature if temperature else 0.95)
+        now = datetime.datetime.now()
+        time = now.strftime("%Y-%m-%d %H:%M:%S")
+        answer = {
+            "response": response,
+            "history": history,
+            "status": 200,
+            "time": time
+        }
+        log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
+        print(log)
+        torch_gc()
+        return answer
+
+    return await asyncio.to_thread(_sync_chat, history=history)
 
 
 if __name__ == '__main__':
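
The effect of this change is that the blocking model.chat() call no longer stalls the server's event loop: the whole synchronous body is moved into _sync_chat() and executed on a worker thread via asyncio.to_thread() (Python 3.9+), so the endpoint can keep accepting other requests while generation runs. A minimal sketch of the same pattern follows; the names slow_generate and /demo are hypothetical stand-ins, not part of api.py.

# Sketch of offloading blocking work with asyncio.to_thread (assumed names).
import asyncio
import time

from fastapi import FastAPI

app = FastAPI()


def slow_generate(prompt: str) -> str:
    # Stand-in for the blocking model.chat() call.
    time.sleep(2)
    return f"echo: {prompt}"


@app.post("/demo")
async def demo(prompt: str) -> dict:
    # Run the blocking work on a worker thread (Python 3.9+),
    # so concurrent requests are not serialized behind it.
    result = await asyncio.to_thread(slow_generate, prompt)
    return {"response": result}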