diff --git a/openai_api.py b/openai_api.py
index 7225562..488e72d 100644
--- a/openai_api.py
+++ b/openai_api.py
@@ -12,12 +12,12 @@ from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
 from typing import Any, Dict, List, Literal, Optional, Union
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, BertModel, BertTokenizer
 from sse_starlette.sse import ServerSentEvent, EventSourceResponse
 
 
 @asynccontextmanager
-async def lifespan(app: FastAPI): # collects GPU memory
+async def lifespan(app: FastAPI):  # collects GPU memory
     yield
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -34,6 +34,7 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
 class ModelCard(BaseModel):
     id: str
     object: str = "model"
@@ -109,8 +110,8 @@ async def create_chat_completion(request: ChatCompletionRequest):
     history = []
     if len(prev_messages) % 2 == 0:
         for i in range(0, len(prev_messages), 2):
-            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
-                history.append([prev_messages[i].content, prev_messages[i+1].content])
+            if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant":
+                history.append([prev_messages[i].content, prev_messages[i + 1].content])
 
     if request.stream:
         generate = predict(query, history, request.model)
@@ -154,7 +155,6 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
         yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
-
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
         delta=DeltaMessage(),
@@ -165,6 +165,43 @@ async def predict(query: str, history: List[List[str]], model_id: str):
     yield '[DONE]'
 
 
+class EmbeddingsRequest(BaseModel):
+    model: str
+    input: Union[str, List[str]]
+
+
+class EmbeddingsData(BaseModel):
+    object: str = "embedding"
+    embedding: List[float]
+    index: int = 0
+
+
+class EmbeddingsResponse(BaseModel):
+    object: str = "list"
+    data: List[EmbeddingsData]
+    model: str
+    usage: dict
+
+
+@app.post("/v1/embeddings", response_model=EmbeddingsResponse)
+async def create_embeddings(request: EmbeddingsRequest):
+    global bert_model, bert_tokenizer
+    encoded_input = bert_tokenizer(request.input, padding=True, truncation=True, return_tensors='pt')
+    with torch.no_grad():
+        model_output = bert_model(**encoded_input)
+        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+    return EmbeddingsResponse(
+        data=[EmbeddingsData(embedding=emb.tolist(), index=i) for i, emb in enumerate(sentence_embeddings)],
+        model=request.model,
+        usage={},
+    )
+
+
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
 if __name__ == "__main__":
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
@@ -173,5 +210,8 @@ if __name__ == "__main__":
     # from utils import load_model_on_gpus
     # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
     model.eval()
-
+    bert_name = 'GanymedeNil/text2vec-large-chinese'
+    bert_model = BertModel.from_pretrained(bert_name)
+    bert_tokenizer = BertTokenizer.from_pretrained(bert_name, model_max_length=512)
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
+
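
As a quick sanity check of the new /v1/embeddings route, it can be exercised with a plain HTTP client once the patched server is running. This is a minimal sketch, assuming the service is reachable on localhost:8000 as configured above and that the `requests` package is installed; the model name in the payload is not used for routing and is simply echoed back by the handler.

# Hypothetical smoke test for the /v1/embeddings endpoint added in this patch.
# Assumes the patched openai_api.py server is already running on localhost:8000.
import requests

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "text2vec-large-chinese",  # echoed back in the response
        "input": ["你好，世界", "hello world"],  # a single string is also accepted
    },
    timeout=60,
)
resp.raise_for_status()
body = resp.json()
# Each entry in body["data"] holds one mean-pooled sentence embedding
# (expected to be 1024-dimensional for text2vec-large-chinese).
for item in body["data"]:
    print(item["index"], len(item["embedding"]))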