mirror of https://github.com/THUDM/ChatGLM2-6B
Merge da66869750 into cb8e8b43c0
commit eaf230d002
@@ -12,12 +12,12 @@ from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
 from typing import Any, Dict, List, Literal, Optional, Union
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, BertModel, BertTokenizer
 from sse_starlette.sse import ServerSentEvent, EventSourceResponse
 
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):  # collects GPU memory
     yield
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -34,6 +34,7 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
 
+
 class ModelCard(BaseModel):
     id: str
     object: str = "model"
@@ -109,8 +110,8 @@ async def create_chat_completion(request: ChatCompletionRequest):
     history = []
     if len(prev_messages) % 2 == 0:
         for i in range(0, len(prev_messages), 2):
-            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
-                history.append([prev_messages[i].content, prev_messages[i+1].content])
+            if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant":
+                history.append([prev_messages[i].content, prev_messages[i + 1].content])
 
     if request.stream:
         generate = predict(query, history, request.model)
@@ -154,7 +155,6 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
         yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
-
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
         delta=DeltaMessage(),
@@ -165,6 +165,43 @@ async def predict(query: str, history: List[List[str]], model_id: str):
     yield '[DONE]'
 
 
+class EmbeddingsRequest(BaseModel):
+    model: str
+    input: Union[str, List[str]]
+
+
+class EmbeddingsData(BaseModel):
+    object: str = "embedding"
+    embedding: List[float]
+    index: int = 0
+
+
+class EmbeddingsResponse(BaseModel):
+    object: str = "list"
+    data: List[EmbeddingsData]
+    model: str
+    usage: dict
+
+
+@app.post("/v1/embeddings", response_model=EmbeddingsResponse)
+async def create_embeddings(request: EmbeddingsRequest):
+    global bert_model, bert_tokenizer
+    encoded_input = bert_tokenizer(request.input, padding=True, truncation=True, return_tensors='pt')
+    with torch.no_grad():
+        model_output = bert_model(**encoded_input)
+    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+    return EmbeddingsResponse(
+        data=[EmbeddingsData(embedding=emb.tolist(), index=i) for i, emb in enumerate(sentence_embeddings)],
+        model=request.model,
+        usage={},
+    )
+
+
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
 
 if __name__ == "__main__":
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
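The mean_pooling helper added in this hunk averages the token embeddings of the BERT output, weighting by the attention mask so that padding tokens do not dilute the sentence vector. A minimal standalone sketch of the same computation on a toy batch (shapes chosen for illustration, not taken from the PR):

import torch

# Toy batch: 2 sequences, 4 token positions, hidden size 3;
# the second sequence carries 2 padding tokens.
token_embeddings = torch.randn(2, 4, 3)
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
summed = torch.sum(token_embeddings * mask, 1)  # sums only over real (unmasked) tokens
counts = torch.clamp(mask.sum(1), min=1e-9)     # per-sequence token counts, clamped to avoid division by zero
sentence_embeddings = summed / counts           # shape (2, 3): one pooled vector per sequence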
|
@@ -173,5 +210,8 @@ if __name__ == "__main__":
     # from utils import load_model_on_gpus
     # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
     model.eval()
 
+    bert_name = 'GanymedeNil/text2vec-large-chinese'
+    bert_model = BertModel.from_pretrained(bert_name)
+    bert_tokenizer = BertTokenizer.from_pretrained(bert_name, model_max_length=512)
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
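With the server started, the new endpoint can be exercised like the OpenAI embeddings API. A hedged client sketch (the host and port follow the uvicorn.run call above; the "model" field is echoed back in the response rather than used for routing, and inputs beyond the model_max_length of 512 set at startup are truncated):

import requests

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"model": "text2vec-large-chinese", "input": ["你好", "ChatGLM2-6B"]},
)
for item in resp.json()["data"]:
    print(item["index"], len(item["embedding"]))  # one mean-pooled vector per input string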