pull/679/merge
littlestone0806 2023-04-18 08:27:07 +00:00 committed by GitHub
commit ee0b36f527
1 changed file with 37 additions and 11 deletions

api.py

@@ -2,6 +2,7 @@ from fastapi import FastAPI, Request
 from transformers import AutoTokenizer, AutoModel
 import uvicorn, json, datetime
 import torch
+from sse_starlette.sse import EventSourceResponse
 
 DEVICE = "cuda"
 DEVICE_ID = "0"
@@ -18,17 +19,23 @@ def torch_gc():
 app = FastAPI()
 
 
-@app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
+def predict_stream(tokenizer, prompt, history, max_length, top_p, temperature):
+    for response, history in model.stream_chat(tokenizer, prompt, history, max_length=max_length, top_p=top_p,
+                                               temperature=temperature):
+        now = datetime.datetime.now()
+        time = now.strftime("%Y-%m-%d %H:%M:%S")
+        yield json.dumps({
+            'response': response,
+            'history': history,
+            'status': 200,
+            'time': time
+        })
+        log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
+        print(log)
+    return torch_gc()
+
+
+def predict(tokenizer, prompt, history, max_length, top_p, temperature):
     response, history = model.chat(tokenizer,
                                    prompt,
                                    history=history,
@@ -48,6 +55,25 @@ async def create_item(request: Request):
     torch_gc()
     return answer
+
+@app.post("/")
+async def create_item(request: Request):
+    global model, tokenizer
+    json_post_raw = await request.json()
+    json_post = json.dumps(json_post_raw)
+    json_post_list = json.loads(json_post)
+    prompt = json_post_list.get('prompt')
+    history = json_post_list.get('history')
+    max_length = json_post_list.get('max_length')
+    top_p = json_post_list.get('top_p')
+    temperature = json_post_list.get('temperature')
+    stream = json_post_list.get('stream')
+    if stream:
+        res = predict_stream(tokenizer, prompt, history, max_length, top_p, temperature)
+        return EventSourceResponse(res)
+    else:
+        answer = predict(tokenizer, prompt, history, max_length, top_p, temperature)
+        return answer
 
 
 if __name__ == '__main__':
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
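Not part of the diff: a minimal client-side sketch of how the new streaming mode could be consumed. It assumes the server is started with the repo's default uvicorn settings (port 8000), that the requests package is available on the client, and that model.stream_chat yields the accumulated response so far, so each event carries the full text generated up to that point. The payload keys mirror the JSON produced by predict_stream above.

# Client-side sketch (hypothetical, not part of this PR): consume the SSE stream.
# Assumes the API is reachable at http://127.0.0.1:8000 and `requests` is installed.
import json
import requests

payload = {
    "prompt": "Hello",   # hypothetical prompt
    "history": [],
    "stream": True,      # ask the new endpoint for a streaming (SSE) response
}

with requests.post("http://127.0.0.1:8000", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # sse-starlette frames each yielded string as "data: <json>";
        # skip blank keep-alive lines and comment/ping lines.
        if not line or not line.startswith("data:"):
            continue
        event = json.loads(line[len("data:"):].strip())
        print(event["response"])   # cumulative response text so far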