Add embeddings API

pull/386/head
陈诗萌 2023-07-26 16:24:30 +08:00
parent 41f5c436fb
commit da66869750
1 changed file with 46 additions and 6 deletions

@@ -12,12 +12,12 @@ from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
 from typing import Any, Dict, List, Literal, Optional, Union
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, BertModel, BertTokenizer
 from sse_starlette.sse import ServerSentEvent, EventSourceResponse
 
 
 @asynccontextmanager
-async def lifespan(app: FastAPI): # collects GPU memory
+async def lifespan(app: FastAPI):  # collects GPU memory
     yield
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -34,6 +34,7 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
 class ModelCard(BaseModel):
     id: str
     object: str = "model"
@@ -109,8 +110,8 @@ async def create_chat_completion(request: ChatCompletionRequest):
     history = []
     if len(prev_messages) % 2 == 0:
         for i in range(0, len(prev_messages), 2):
-            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
-                history.append([prev_messages[i].content, prev_messages[i+1].content])
+            if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant":
+                history.append([prev_messages[i].content, prev_messages[i + 1].content])
 
     if request.stream:
         generate = predict(query, history, request.model)
@@ -154,7 +155,6 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
         yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
-
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
         delta=DeltaMessage(),
@@ -165,6 +165,43 @@ async def predict(query: str, history: List[List[str]], model_id: str):
     yield '[DONE]'
 
 
+class EmbeddingsRequest(BaseModel):
+    model: str
+    input: Union[str, List[str]]
+
+
+class EmbeddingsData(BaseModel):
+    object: str = "embedding"
+    embedding: List[float]
+    index: int = 0
+
+
+class EmbeddingsResponse(BaseModel):
+    object: str = "list"
+    data: List[EmbeddingsData]
+    model: str
+    usage: dict
+
+
+@app.post("/v1/embeddings", response_model=EmbeddingsResponse)
+async def create_embeddings(request: EmbeddingsRequest):
+    global bert_model, bert_tokenizer
+    encoded_input = bert_tokenizer(request.input, padding=True, truncation=True, return_tensors='pt')
+    with torch.no_grad():
+        model_output = bert_model(**encoded_input)
+    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+    return EmbeddingsResponse(
+        data=[EmbeddingsData(embedding=emb.tolist(), index=i) for i, emb in enumerate(sentence_embeddings)],
+        model=request.model,
+        usage={},
+    )
+
+
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
 if __name__ == "__main__":
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
     model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
@@ -173,5 +210,8 @@ if __name__ == "__main__":
     # from utils import load_model_on_gpus
     # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
     model.eval()
+    bert_name = 'GanymedeNil/text2vec-large-chinese'
+    bert_model = BertModel.from_pretrained(bert_name)
+    bert_tokenizer = BertTokenizer.from_pretrained(bert_name, model_max_length=512)
 
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
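
Usage note: the new endpoint mirrors the OpenAI embeddings response shape (object/data/model/usage, with usage returned empty here). mean_pooling averages the token embeddings across the sequence, weighting by the attention mask so padding tokens do not contribute to the sentence vector. Below is a minimal client-side sketch, assuming the server above is running on localhost:8000; note that the model field is echoed back in the response and does not select a model, since the handler always embeds with text2vec-large-chinese:

    import requests

    # POST a batch of inputs to the new endpoint; "input" may be a single
    # string or a list of strings, mirroring the OpenAI embeddings API.
    resp = requests.post(
        "http://localhost:8000/v1/embeddings",
        json={"model": "text2vec-large-chinese", "input": ["你好，世界", "hello world"]},
    )
    resp.raise_for_status()

    for item in resp.json()["data"]:
        # Each item carries object/embedding/index; the vector length
        # equals the embedding model's hidden size.
        print(item["index"], len(item["embedding"]))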