update api.py

pull/235/head
littlepanda0716 2023-03-25 18:59:11 +08:00
parent 963d5645ef
commit 023c46a317
1 changed files with 25 additions and 3 deletions

28
api.py
View File

@ -1,6 +1,19 @@
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime import uvicorn, json, datetime
import torch
# Accelerator backend and the index of the card to run on.
DEVICE = "cuda"
DEVICE_ID = "0"
# Qualified device string ("<backend>:<index>"); when DEVICE_ID is empty the
# bare backend name is used instead.
CUDA_DEVICE = DEVICE if not DEVICE_ID else f"{DEVICE}:{DEVICE_ID}"
def torch_gc():
    """Release cached CUDA memory on the configured device.

    Does nothing when CUDA is not available. Otherwise, with ``CUDA_DEVICE``
    selected as the current device, it empties PyTorch's CUDA cache and then
    collects memory that was released through CUDA IPC.
    """
    if not torch.cuda.is_available():
        return
    # Select the configured device explicitly so the cleanup targets it
    # rather than whichever device happens to be current.
    with torch.cuda.device(CUDA_DEVICE):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
app = FastAPI() app = FastAPI()
@ -13,7 +26,15 @@ async def create_item(request: Request):
json_post_list = json.loads(json_post) json_post_list = json.loads(json_post)
prompt = json_post_list.get('prompt') prompt = json_post_list.get('prompt')
history = json_post_list.get('history') history = json_post_list.get('history')
response, history = model.chat(tokenizer, prompt, history=history) max_length = json_post_list.get('max_length')
top_p = json_post_list.get('top_p')
temperature = json_post_list.get('temperature')
response, history = model.chat(tokenizer,
prompt,
history=history,
max_length=max_length if max_length else 2048,
top_p=top_p if top_p else 0.7,
temperature=temperature if temperature else 0.95)
now = datetime.datetime.now() now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S") time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = { answer = {
@ -24,12 +45,13 @@ async def create_item(request: Request):
} }
log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"' log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
print(log) print(log)
torch_gc()
return answer return answer
if __name__ == '__main__': if __name__ == '__main__':
uvicorn.run('api:app', host='0.0.0.0', port=8000, workers=1) uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model.eval() model.eval()