mirror of https://github.com/THUDM/ChatGLM2-6B
Add embeddings api
parent
41f5c436fb
commit
da66869750
|
@ -12,12 +12,12 @@ from fastapi import FastAPI, HTTPException
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
from transformers import AutoTokenizer, AutoModel, BertModel, BertTokenizer
|
||||
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI): # collects GPU memory
|
||||
async def lifespan(app: FastAPI): # collects GPU memory
|
||||
yield
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -34,6 +34,7 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
|
||||
id: str
|
||||
object: str = "model"
|
||||
|
@ -109,8 +110,8 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|||
history = []
|
||||
if len(prev_messages) % 2 == 0:
|
||||
for i in range(0, len(prev_messages), 2):
|
||||
if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
|
||||
history.append([prev_messages[i].content, prev_messages[i+1].content])
|
||||
if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant":
|
||||
history.append([prev_messages[i].content, prev_messages[i + 1].content])
|
||||
|
||||
if request.stream:
|
||||
generate = predict(query, history, request.model)
|
||||
|
@ -154,7 +155,6 @@ async def predict(query: str, history: List[List[str]], model_id: str):
|
|||
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
|
||||
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(),
|
||||
|
@ -165,6 +165,43 @@ async def predict(query: str, history: List[List[str]], model_id: str):
|
|||
yield '[DONE]'
|
||||
|
||||
|
||||
class EmbeddingsRequest(BaseModel):
|
||||
model: str
|
||||
input: Union[str, List[str]]
|
||||
|
||||
|
||||
class EmbeddingsData(BaseModel):
|
||||
object: str = "embedding"
|
||||
embedding: List[float]
|
||||
index: int = 0
|
||||
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
|
||||
object: str = "list"
|
||||
data: List[EmbeddingsData]
|
||||
model: str
|
||||
usage: dict
|
||||
|
||||
|
||||
@app.post("/v1/embeddings", response_model=EmbeddingsResponse)
|
||||
async def create_chat_completion(request: EmbeddingsRequest):
|
||||
global bert_model, bert_tokenizer
|
||||
encoded_input = bert_tokenizer(request.input, padding=True, truncation=True, return_tensors='pt')
|
||||
with torch.no_grad():
|
||||
model_output = bert_model(**encoded_input)
|
||||
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
||||
return EmbeddingsResponse(
|
||||
data=[EmbeddingsData(embedding=_.tolist(), index=i) for i, _ in enumerate(sentence_embeddings)],
|
||||
model=request.model,
|
||||
usage={},
|
||||
)
|
||||
|
||||
|
||||
def mean_pooling(model_output, attention_mask):
|
||||
token_embeddings = model_output[0]
|
||||
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
||||
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
|
||||
|
@ -173,5 +210,8 @@ if __name__ == "__main__":
|
|||
# from utils import load_model_on_gpus
|
||||
# model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
|
||||
model.eval()
|
||||
|
||||
bert_name = 'GanymedeNil/text2vec-large-chinese'
|
||||
bert_model = BertModel.from_pretrained(bert_name)
|
||||
bert_tokenizer = BertTokenizer.from_pretrained(bert_name, model_max_length=512)
|
||||
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
||||
|
||||
|
|
Loading…
Reference in New Issue