pull/547/merge
Whitroom 2024-09-30 04:17:07 -04:00 committed by GitHub
commit 5b50a47a51
2 changed files with 72 additions and 34 deletions

api.py (78 changed lines)

@@ -1,12 +1,31 @@
-from fastapi import FastAPI, Request
-from transformers import AutoTokenizer, AutoModel
-import uvicorn, json, datetime
+import datetime
+from contextlib import asynccontextmanager
+
 import torch
+import uvicorn
+from fastapi import FastAPI
+from pydantic import BaseModel
+from transformers import AutoModel, AutoTokenizer
 
 DEVICE = "cuda"
 DEVICE_ID = "0"
 CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
 
+models = {}
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    models['chat'] = AutoModel.from_pretrained(
+        "THUDM/chatglm-6b",
+        trust_remote_code=True).half().cuda()
+    models['chat'].eval()
+    models['tokenizer'] = AutoTokenizer.from_pretrained(
+        "THUDM/chatglm-6b",
+        trust_remote_code=True)
+    yield
+    for model in models.values():
+        del model
+    torch_gc()
+
 
 def torch_gc():
     if torch.cuda.is_available():
@@ -14,43 +33,34 @@ def torch_gc():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
 
 
-app = FastAPI()
+app = FastAPI(lifespan=lifespan)
+
+
+class Item(BaseModel):
+    prompt: str = "你好"
+    history: list[tuple[str, str]] = []
+    max_length: int = 2048
+    top_p: float = 0.7
+    temperature: float = 0.95
+
+
+class Answer(BaseModel):
+    response: str
+    history: list[tuple[str, str]]
+    status: int
+    time: str
 
 
 @app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
-    response, history = model.chat(tokenizer,
-                                   prompt,
-                                   history=history,
-                                   max_length=max_length if max_length else 2048,
-                                   top_p=top_p if top_p else 0.7,
-                                   temperature=temperature if temperature else 0.95)
+async def create_item(item: Item):
+    response, history = models['chat'].chat(
+        models['tokenizer'],
+        item.prompt,
+        history=item.history,
+        max_length=item.max_length,
+        top_p=item.top_p,
+        temperature=item.temperature)
     now = datetime.datetime.now()
     time = now.strftime("%Y-%m-%d %H:%M:%S")
-    answer = {
-        "response": response,
-        "history": history,
-        "status": 200,
-        "time": time
-    }
-    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
-    print(log)
-    torch_gc()
-    return answer
+    print(f"[{time}] prompt: '{item.prompt}', response: '{response}'")
+    return Answer(response=response, history=history, status=200, time=time)
 
 
 if __name__ == '__main__':
-    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
-    model.eval()
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
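
Since the refactor replaces hand-rolled `request.json()` parsing with the `Item` model, FastAPI now validates the body and every field carries a default. A minimal smoke test against a locally running server might then look like the sketch below (host and port taken from the `uvicorn.run` call above; only `prompt` is sent, letting the other fields fall back to their `Item` defaults):

```python
# Minimal sketch: call the refactored endpoint with only "prompt" set.
# Assumes api.py is running locally on its default port 8000.
import asyncio

from httpx import AsyncClient


async def main():
    # Generation can take a while on a single GPU, so raise the timeout.
    async with AsyncClient(timeout=300.0) as ac:
        r = await ac.post("http://localhost:8000/", json={"prompt": "你好"})
        # The body follows the Answer model: response, history, status, time.
        print(r.json())

asyncio.run(main())
```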

api_test.py (new file, 28 added lines)

@@ -0,0 +1,28 @@
+import unittest
+from httpx import AsyncClient
+
+class TestGenerateChat(unittest.IsolatedAsyncioTestCase):
+    """
+    Test chat generation.
+    1. Start the server first:
+    ```bash
+    python api.py
+    ```
+    2. Run the test:
+    ```bash
+    python -m unittest api_test.py
+    ```
+    """
+    async def test_generate_chat(self):
+        async with AsyncClient() as ac:
+            response = await ac.post(
+                "http://localhost:8000/",
+                json={
+                    "prompt": "你好",
+                    "history": [],
+                    "max_length": 2048,
+                    "top_p": 0.7,
+                    "temperature": 0.95
+                })
+            self.assertEqual(response.status_code, 200)
+            print(response.json())
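
A natural extension, not part of this PR, would be to also check the payload against the `Answer` schema from api.py. A sketch of such an extra method on `TestGenerateChat` (the method name is hypothetical):

```python
    async def test_answer_schema(self):
        # Hypothetical follow-up test: verify the body carries the
        # fields declared on the Answer model in api.py.
        async with AsyncClient(timeout=300.0) as ac:
            response = await ac.post(
                "http://localhost:8000/",
                json={"prompt": "你好"})
            data = response.json()
            self.assertEqual(data["status"], 200)
            self.assertIsInstance(data["response"], str)
            self.assertIsInstance(data["history"], list)
            self.assertIsInstance(data["time"], str)
```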