pull/679/merge
littlestone0806 2023-04-18 08:27:07 +00:00 committed by GitHub
commit ee0b36f527
1 changed file with 37 additions and 11 deletions

api.py

@@ -2,6 +2,7 @@ from fastapi import FastAPI, Request
 from transformers import AutoTokenizer, AutoModel
 import uvicorn, json, datetime
 import torch
+from sse_starlette.sse import EventSourceResponse
 
 DEVICE = "cuda"
 DEVICE_ID = "0"
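
Note (not part of the diff): EventSourceResponse comes from the third-party sse-starlette package, so running this version of api.py additionally requires installing it, e.g. pip install sse-starlette.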
@@ -18,17 +19,23 @@ def torch_gc():
 app = FastAPI()
 
 
-@app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
+def predict_stream(tokenizer, prompt, history, max_length, top_p, temperature):
+    for response, history in model.stream_chat(tokenizer, prompt, history, max_length=max_length, top_p=top_p,
+                                                temperature=temperature):
+        now = datetime.datetime.now()
+        time = now.strftime("%Y-%m-%d %H:%M:%S")
+        yield json.dumps({
+            'response': response,
+            'history': history,
+            'status': 200,
+            'time': time
+        })
+        log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
+        print(log)
+    return torch_gc()
+
+
+def predict(tokenizer, prompt, history, max_length, top_p, temperature):
     response, history = model.chat(tokenizer,
                                    prompt,
                                    history=history,
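
For reference, and not part of the diff itself: each JSON string yielded by predict_stream is sent to the client by EventSourceResponse as one Server-Sent Events frame, roughly

data: {"response": "...", "history": [["Hello", "..."]], "status": 200, "time": "2023-04-18 12:00:00"}

with a blank line separating consecutive frames; the field values above are placeholders.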
@@ -48,6 +55,25 @@ async def create_item(request: Request):
     torch_gc()
     return answer
+
+@app.post("/")
+async def create_item(request: Request):
+    global model, tokenizer
+    json_post_raw = await request.json()
+    json_post = json.dumps(json_post_raw)
+    json_post_list = json.loads(json_post)
+    prompt = json_post_list.get('prompt')
+    history = json_post_list.get('history')
+    max_length = json_post_list.get('max_length')
+    top_p = json_post_list.get('top_p')
+    temperature = json_post_list.get('temperature')
+    stream = json_post_list.get('stream')
+    if stream:
+        res = predict_stream(tokenizer, prompt, history, max_length, top_p, temperature)
+        return EventSourceResponse(res)
+    else:
+        answer = predict(tokenizer, prompt, history, max_length, top_p, temperature)
+        return answer
 
 
 if __name__ == '__main__':
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
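
To exercise the new streaming path end to end, a minimal client sketch (not part of the PR) could look like the following; it uses the requests library, assumes the server from this api.py is listening locally on port 8000, and the generation parameters are illustrative:

import json
import requests

# "stream": True selects the predict_stream / EventSourceResponse branch added above.
payload = {"prompt": "Hello", "history": [], "max_length": 2048,
           "top_p": 0.7, "temperature": 0.95, "stream": True}

with requests.post("http://127.0.0.1:8000/", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # Each SSE frame arrives as a "data: {...}" line followed by a blank line.
        if line and line.startswith("data:"):
            event = json.loads(line[len("data:"):].strip())
            print(event["response"])

Omitting "stream" (or sending a falsy value) falls through to predict() and returns a single JSON answer, as the endpoint did before this change.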