From 24dada9d5aeaf7db100168300454f363e7208cf1 Mon Sep 17 00:00:00 2001 From: Whitroom <1062015905@qq.com> Date: Wed, 12 Apr 2023 14:45:52 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=AF=B7=E6=B1=82=E4=B8=8E?= =?UTF-8?q?=E5=93=8D=E5=BA=94=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/api.py b/api.py index 693c70a..efc28bc 100644 --- a/api.py +++ b/api.py @@ -1,7 +1,11 @@ -from fastapi import FastAPI, Request -from transformers import AutoTokenizer, AutoModel -import uvicorn, json, datetime +import datetime +import json + import torch +import uvicorn +from fastapi import FastAPI, Request +from pydantic import BaseModel +from transformers import AutoModel, AutoTokenizer DEVICE = "cuda" DEVICE_ID = "0" @@ -17,6 +21,18 @@ def torch_gc(): app = FastAPI() +class Item(BaseModel): + prompt: str + history: list[tuple[str, str]] = [[]] + max_length: int = 2048 + top_p: float = 0.7 + temperature: float = 0.95 + +class Answer(BaseModel): + response: str + history: list[tuple[str, str]] + status: int + time: str @app.post("/") async def create_item(request: Request):