pull/547/merge
Whitroom 2024-09-30 04:17:07 -04:00 committed by GitHub
commit 5b50a47a51
2 changed files with 72 additions and 34 deletions

api.py (44 additions, 34 deletions)

@@ -1,12 +1,31 @@
-from fastapi import FastAPI, Request
-from transformers import AutoTokenizer, AutoModel
-import uvicorn, json, datetime
+import datetime
+from contextlib import asynccontextmanager
+
 import torch
+import uvicorn
+from fastapi import FastAPI
+from pydantic import BaseModel
+from transformers import AutoModel, AutoTokenizer
 
 DEVICE = "cuda"
 DEVICE_ID = "0"
 CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
 
+models = {}
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    models['chat'] = AutoModel.from_pretrained(
+        "THUDM/chatglm-6b",
+        trust_remote_code=True).half().cuda()
+    models['chat'].eval()
+    models['tokenizer'] = AutoTokenizer.from_pretrained(
+        "THUDM/chatglm-6b",
+        trust_remote_code=True)
+    yield
+    for model in models.values():
+        del model
+    torch_gc()
+
 
 def torch_gc():
     if torch.cuda.is_available():
@@ -14,43 +33,34 @@ def torch_gc():
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
 
 
-app = FastAPI()
+app = FastAPI(lifespan=lifespan)
+
+
+class Item(BaseModel):
+    prompt: str = "你好"
+    history: list[tuple[str, str]] = []
+    max_length: int = 2048
+    top_p: float = 0.7
+    temperature: float = 0.95
+
+
+class Answer(BaseModel):
+    response: str
+    history: list[tuple[str, str]]
+    status: int
+    time: str
 
 
 @app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
-    response, history = model.chat(tokenizer,
-                                   prompt,
-                                   history=history,
-                                   max_length=max_length if max_length else 2048,
-                                   top_p=top_p if top_p else 0.7,
-                                   temperature=temperature if temperature else 0.95)
+async def create_item(item: Item):
+    response, history = models['chat'].chat(
+        models['tokenizer'],
+        item.prompt,
+        history=item.history,
+        max_length=item.max_length,
+        top_p=item.top_p,
+        temperature=item.temperature)
     now = datetime.datetime.now()
     time = now.strftime("%Y-%m-%d %H:%M:%S")
-    answer = {
-        "response": response,
-        "history": history,
-        "status": 200,
-        "time": time
-    }
-    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
-    print(log)
-    torch_gc()
-    return answer
+    print(f"[{time}] prompt: '{item.prompt}', response: '{response}'")
+    return Answer(response=response, history=history, status=200, time=time)
 
 
 if __name__ == '__main__':
-    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
-    model.eval()
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
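
Since the handler is now typed against pydantic models, request bodies are validated against `Item` and omitted fields fall back to its declared defaults. A minimal client sketch, assuming the server from `api.py` is already running on `localhost:8000` (the 60-second timeout is an arbitrary choice, since generation can easily outrun httpx's 5-second default):

```python
# Hypothetical client call against the reworked endpoint; not part of this PR.
import httpx

resp = httpx.post(
    "http://localhost:8000/",
    json={"prompt": "你好"},  # history, max_length, top_p, temperature use Item defaults
    timeout=60.0,             # assumption: generation may exceed httpx's 5 s default
)
resp.raise_for_status()
data = resp.json()            # shaped like Answer: response, history, status, time
print(data["response"])
```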

api_test.py (new file, 28 additions)

@@ -0,0 +1,28 @@
+import unittest
+
+from httpx import AsyncClient
+
+
+class TestGenerateChat(unittest.IsolatedAsyncioTestCase):
+    """
+    Test chat generation.
+    1. Start the server first:
+    ```bash
+    python api.py
+    ```
+    2. Run the test:
+    ```bash
+    python -m unittest api_test.py
+    ```
+    """
+
+    async def test_generate_chat(self):
+        async with AsyncClient() as ac:
+            response = await ac.post(
+                "http://localhost:8000/",
+                json={
+                    "prompt": "你好",
+                    "history": [],
+                    "max_length": 2048,
+                    "top_p": 0.7,
+                    "temperature": 0.95
+                })
+        self.assertEqual(response.status_code, 200)
+        print(response.json())
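
Because the model now loads in the `lifespan` hook, the same check could also run in-process, with no separate server. A sketch, assuming the third-party `asgi-lifespan` package (not part of this PR) to drive startup and shutdown, since httpx's `ASGITransport` does not trigger lifespan events on its own:

```python
# In-process variant of the test above; asgi-lifespan is an assumed extra dependency.
import unittest

from asgi_lifespan import LifespanManager
from httpx import ASGITransport, AsyncClient

from api import app


class TestGenerateChatInProcess(unittest.IsolatedAsyncioTestCase):
    async def test_generate_chat(self):
        # LifespanManager runs the lifespan hook, so the model actually loads here;
        # the generous startup_timeout allows for the ChatGLM-6B weights to load.
        async with LifespanManager(app, startup_timeout=300):
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as ac:
                response = await ac.post("/", json={"prompt": "你好"})
        self.assertEqual(response.status_code, 200)
```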