pull/386/merge
chenshimeng 2024-07-29 19:14:57 +08:00 committed by GitHub
commit eaf230d002
1 changed file with 46 additions and 6 deletions

@@ -12,12 +12,12 @@ from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
 from typing import Any, Dict, List, Literal, Optional, Union
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, BertModel, BertTokenizer
 from sse_starlette.sse import ServerSentEvent, EventSourceResponse


 @asynccontextmanager
-async def lifespan(app: FastAPI): # collects GPU memory
+async def lifespan(app: FastAPI):  # collects GPU memory
     yield
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -34,6 +34,7 @@ app.add_middleware(
     allow_headers=["*"],
 )

+
 class ModelCard(BaseModel):
     id: str
     object: str = "model"
@@ -109,8 +110,8 @@ async def create_chat_completion(request: ChatCompletionRequest):
     history = []
     if len(prev_messages) % 2 == 0:
         for i in range(0, len(prev_messages), 2):
-            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
-                history.append([prev_messages[i].content, prev_messages[i+1].content])
+            if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant":
+                history.append([prev_messages[i].content, prev_messages[i + 1].content])

     if request.stream:
         generate = predict(query, history, request.model)
@@ -154,7 +155,6 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
         yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
-
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
         delta=DeltaMessage(),
@@ -165,6 +165,43 @@ async def predict(query: str, history: List[List[str]], model_id: str):
     yield '[DONE]'


+class EmbeddingsRequest(BaseModel):
+    model: str
+    input: Union[str, List[str]]
+
+
+class EmbeddingsData(BaseModel):
+    object: str = "embedding"
+    embedding: List[float]
+    index: int = 0
+
+
+class EmbeddingsResponse(BaseModel):
+    object: str = "list"
+    data: List[EmbeddingsData]
+    model: str
+    usage: dict
+
+
+@app.post("/v1/embeddings", response_model=EmbeddingsResponse)
+async def create_embeddings(request: EmbeddingsRequest):
+    global bert_model, bert_tokenizer
+    encoded_input = bert_tokenizer(request.input, padding=True, truncation=True, return_tensors='pt')
+    with torch.no_grad():
+        model_output = bert_model(**encoded_input)
+        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+    return EmbeddingsResponse(
+        data=[EmbeddingsData(embedding=emb.tolist(), index=i) for i, emb in enumerate(sentence_embeddings)],
+        model=request.model,
+        usage={},
+    )
+
+
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
 if __name__ == "__main__":
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
@@ -173,5 +210,8 @@ if __name__ == "__main__":
     # from utils import load_model_on_gpus
     # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
     model.eval()

+    bert_name = 'GanymedeNil/text2vec-large-chinese'
+    bert_model = BertModel.from_pretrained(bert_name)
+    bert_tokenizer = BertTokenizer.from_pretrained(bert_name, model_max_length=512)
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
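
For reference, a minimal sketch of calling the new /v1/embeddings route from a client, assuming the server above is running on localhost:8000 (the host and port from the uvicorn.run line); the requests library and the model string here are illustrative choices, not part of the diff:

import requests

# POST to the /v1/embeddings route added in this diff; the payload fields
# follow the EmbeddingsRequest model (model: str, input: str or list of str).
resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"model": "text2vec-large-chinese", "input": ["你好,世界", "hello world"]},
)
resp.raise_for_status()
for item in resp.json()["data"]:
    # Each item mirrors EmbeddingsData: an index plus the embedding vector.
    print(item["index"], len(item["embedding"]))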
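
Separately, a self-contained toy check of the mean_pooling helper, using hand-written tensors in place of real BERT output, to show that padding positions (attention mask 0) are excluded from the average:

import torch

# Shape (batch=1, seq_len=3, hidden=2); the third token stands in for padding.
token_embeddings = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])
attention_mask = torch.tensor([[1, 1, 0]])

# Same computation as mean_pooling above: mask, sum, divide by token count.
mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
pooled = torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
print(pooled)  # tensor([[2., 3.]]) -- the average of the two unmasked tokens only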