Add history in API

pull/220/head
duzx16 2023-03-23 21:42:43 +08:00
parent 4eca73636e
commit b0c2b47f5e
1 changed file with 19 additions and 17 deletions

36
API.py
View File

@@ -1,33 +1,35 @@
from typing import Optional
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel from transformers import AutoTokenizer, AutoModel
import uvicorn, json, time, datetime, os, platform import uvicorn, json, datetime
app = FastAPI() app = FastAPI()
@app.post("/") @app.post("/")
async def create_item(request: Request): async def create_item(request: Request):
global history, model, tokenizer global model, tokenizer
jsonPostRaw = await request.json() json_post_raw = await request.json()
jsonPost = json.dumps(jsonPostRaw) json_post = json.dumps(json_post_raw)
jsonPostList = json.loads(jsonPost) json_post_list = json.loads(json_post)
prompt = jsonPostList.get('prompt') prompt = json_post_list.get('prompt')
history = json_post_list.get('history')
response, history = model.chat(tokenizer, prompt, history=history) response, history = model.chat(tokenizer, prompt, history=history)
now = datetime.datetime.now() now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S") time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = { answer = {
"response":response, "response": response,
"status":200, "history": history,
"time":time "status": 200,
"time": time
} }
log = "["+time+"] "+'device:"'+jsonPostList.get('device')+'", prompt:"'+prompt+'", response:"'+repr(response)+'"' log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
print(log) print(log)
return answer return answer
if __name__ == '__main__':
uvicorn.run('API:app',host='0.0.0.0',port=8000,workers=1)
history = [] if __name__ == '__main__':
uvicorn.run('API:app', host='0.0.0.0', port=8000, workers=1)
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(4).cuda() model = AutoModel.from_pretrained("THUDM/chatglm_6b", trust_remote_code=True).half().cuda()
model = model.eval() model.eval()