diff --git a/api.py b/api.py index 693c70a..efc28bc 100644 --- a/api.py +++ b/api.py @@ -1,7 +1,11 @@ -from fastapi import FastAPI, Request -from transformers import AutoTokenizer, AutoModel -import uvicorn, json, datetime +import datetime +import json + import torch +import uvicorn +from fastapi import FastAPI, Request +from pydantic import BaseModel +from transformers import AutoModel, AutoTokenizer DEVICE = "cuda" DEVICE_ID = "0" @@ -17,6 +21,18 @@ def torch_gc(): app = FastAPI() +class Item(BaseModel): + prompt: str + history: list[tuple[str, str]] = [[]] + max_length: int = 2048 + top_p: float = 0.7 + temperature: float = 0.95 + +class Answer(BaseModel): + response: str + history: list[tuple[str, str]] + status: int + time: str @app.post("/") async def create_item(request: Request):