mirror of https://github.com/THUDM/ChatGLM2-6B
Add embeddings api
parent
41f5c436fb
commit
da66869750
|
@ -12,7 +12,7 @@ from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Any, Dict, List, Literal, Optional, Union
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
from transformers import AutoTokenizer, AutoModel
|
from transformers import AutoTokenizer, AutoModel, BertModel, BertTokenizer
|
||||||
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
|
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ app.add_middleware(
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelCard(BaseModel):
|
class ModelCard(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
object: str = "model"
|
object: str = "model"
|
||||||
|
@ -154,7 +155,6 @@ async def predict(query: str, history: List[List[str]], model_id: str):
|
||||||
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||||
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
|
yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
|
||||||
|
|
||||||
|
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
index=0,
|
index=0,
|
||||||
delta=DeltaMessage(),
|
delta=DeltaMessage(),
|
||||||
|
@ -165,6 +165,43 @@ async def predict(query: str, history: List[List[str]], model_id: str):
|
||||||
yield '[DONE]'
|
yield '[DONE]'
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
input: Union[str, List[str]]
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsData(BaseModel):
|
||||||
|
object: str = "embedding"
|
||||||
|
embedding: List[float]
|
||||||
|
index: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsResponse(BaseModel):
|
||||||
|
object: str = "list"
|
||||||
|
data: List[EmbeddingsData]
|
||||||
|
model: str
|
||||||
|
usage: dict
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/embeddings", response_model=EmbeddingsResponse)
|
||||||
|
async def create_chat_completion(request: EmbeddingsRequest):
|
||||||
|
global bert_model, bert_tokenizer
|
||||||
|
encoded_input = bert_tokenizer(request.input, padding=True, truncation=True, return_tensors='pt')
|
||||||
|
with torch.no_grad():
|
||||||
|
model_output = bert_model(**encoded_input)
|
||||||
|
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
||||||
|
return EmbeddingsResponse(
|
||||||
|
data=[EmbeddingsData(embedding=_.tolist(), index=i) for i, _ in enumerate(sentence_embeddings)],
|
||||||
|
model=request.model,
|
||||||
|
usage={},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mean_pooling(model_output, attention_mask):
|
||||||
|
token_embeddings = model_output[0]
|
||||||
|
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
||||||
|
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
|
||||||
|
@ -173,5 +210,8 @@ if __name__ == "__main__":
|
||||||
# from utils import load_model_on_gpus
|
# from utils import load_model_on_gpus
|
||||||
# model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
|
# model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
bert_name = 'GanymedeNil/text2vec-large-chinese'
|
||||||
|
bert_model = BertModel.from_pretrained(bert_name)
|
||||||
|
bert_tokenizer = BertTokenizer.from_pretrained(bert_name, model_max_length=512)
|
||||||
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue