Merge pull request #235 from imClumsyPanda/main

update api.py
dev_api
Zhengxiao Du 2023-03-28 19:28:56 +08:00 committed by GitHub
commit 3e9e02fb1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 25 additions and 3 deletions

28
api.py
View File

@ -1,6 +1,19 @@
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime import uvicorn, json, datetime
import torch
# Device type and ordinal used to build the device string that torch_gc()
# selects before releasing cached GPU memory.
DEVICE = "cuda"
DEVICE_ID = "0"
# "<DEVICE>:<DEVICE_ID>" when an ordinal is configured, otherwise DEVICE alone.
CUDA_DEVICE = DEVICE if not DEVICE_ID else ":".join((DEVICE, DEVICE_ID))
def torch_gc():
    """Free cached CUDA memory on the configured device.

    Does nothing when CUDA is not available. Otherwise CUDA_DEVICE is made
    the current device for the duration of the cleanup calls.
    """
    if not torch.cuda.is_available():
        return
    with torch.cuda.device(CUDA_DEVICE):
        # Both cleanup calls run with CUDA_DEVICE selected.
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
app = FastAPI() app = FastAPI()
@ -13,7 +26,15 @@ async def create_item(request: Request):
json_post_list = json.loads(json_post) json_post_list = json.loads(json_post)
prompt = json_post_list.get('prompt') prompt = json_post_list.get('prompt')
history = json_post_list.get('history') history = json_post_list.get('history')
response, history = model.chat(tokenizer, prompt, history=history) max_length = json_post_list.get('max_length')
top_p = json_post_list.get('top_p')
temperature = json_post_list.get('temperature')
response, history = model.chat(tokenizer,
prompt,
history=history,
max_length=max_length if max_length else 2048,
top_p=top_p if top_p else 0.7,
temperature=temperature if temperature else 0.95)
now = datetime.datetime.now() now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S") time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = { answer = {
@ -24,12 +45,13 @@ async def create_item(request: Request):
} }
log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"' log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
print(log) print(log)
torch_gc()
return answer return answer
if __name__ == '__main__': if __name__ == '__main__':
uvicorn.run('api:app', host='0.0.0.0', port=8000, workers=1) uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model.eval() model.eval()