ChatGLM2-6B/openai_api.py

# coding=utf-8
# Implements API for ChatGLM2-6B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py
# Visit http://localhost:8000/docs for the interactive API documentation.
import time
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Any, Dict, List, Literal, Optional, Union

import torch
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel, Field
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
from transformers import AutoTokenizer, AutoModel

# to get a string like this run:
# openssl rand -hex 32
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
'''
User data (a mock database user table), used later for authentication.
Account: johndoe   Password: secret
For security, passwords are stored hashed with PassLib; the recommended hashing algorithm is "bcrypt".
The main calls used are:
    from passlib.context import CryptContext
    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
    pwd_context.hash(password)                     # hash a plaintext password
    pwd_context.verify(password, hashed_password)  # check a plaintext password against a hash
How to obtain an api-key:
    Official docs: https://fastapi.tiangolo.com/zh/tutorial/security/oauth2-jwt/#_5
    Postman: https://blog.csdn.net/Disany/article/details/109365066
'''
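# A minimal sketch of how a hashed_password entry for a new user could be produced
# (run interactively; `pwd_context` is defined further below):
#
#     new_hash = pwd_context.hash("my-new-password")    # store this value as hashed_password
#     assert pwd_context.verify("my-new-password", new_hash)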
fake_users_db = {
    "johndoe": {
        "username": "johndoe",
        "full_name": "John Doe",
        "email": "johndoe@example.com",
        # To obtain this hash, run:
        # pwd_context.hash(password)
        "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW",
        "disabled": False,
    }
}

class Token(BaseModel):
    access_token: str
    token_type: str


class TokenData(BaseModel):
    username: Union[str, None] = None


class User(BaseModel):
    username: str
    email: Union[str, None] = None
    full_name: Union[str, None] = None
    disabled: Union[bool, None] = None


class UserInDB(User):
    hashed_password: str

# CryptContext is passlib's password-hashing context, configured here for bcrypt.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
'''
OAuth2PasswordBearer is a class that takes a URL as its parameter: the client sends its username and
password to that URL and receives a token in return.
OAuth2PasswordBearer does not create the corresponding path operation; it only declares the URL the
client should use to obtain the token.
When a request arrives, FastAPI checks its Authorization header; if the header is missing, or its
content is not "Bearer <token>", a 401 (UNAUTHORIZED) status code is returned.
'''
# oauth2_scheme is the token dependency; token: str = Depends(oauth2_scheme) yields the previously issued token
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
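# For illustration, a request to a protected endpoint is expected to look like this:
#
#     GET /v1/models HTTP/1.1
#     Authorization: Bearer <access_token issued by POST /token>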
@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))

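# For illustration, a request body that matches ChatCompletionRequest (values are examples only):
#
#     {
#         "model": "gpt-3.5-turbo",
#         "messages": [
#             {"role": "system", "content": "You are a helpful assistant."},
#             {"role": "user", "content": "你好"}
#         ],
#         "stream": false
#     }
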
# verify_password: check a password.
# plain_password is the plaintext password, hashed_password is the stored hash.
# Returns True or False.
def verify_password(plain_password, hashed_password):
    return pwd_context.verify(plain_password, hashed_password)


# Hash a password: plaintext in, bcrypt hash out.
def get_password_hash(password):
    return pwd_context.hash(password)

# Simulate reading a user record from the database.
def get_user(db, username: str):
    if username in db:
        user_dict = db[username]
        return UserInDB(**user_dict)

# Authenticate a user: look the user up, then verify the password.
def authenticate_user(fake_db, username: str, password: str):
    user = get_user(fake_db, username)
    if not user:
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user

# Create an access token (JWT).
def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

'''Request validation
Once a user has obtained a token, every subsequent request must carry an
"Authorization: Bearer <token>" header to access the other endpoints.
The function below validates incoming requests: it reads the token, decodes and verifies it,
and only then lets the endpoint return its data.
'''
# Resolve the current user.
# oauth2_scheme extracts the token from the request's Authorization header.
async def get_current_user(token: str = Depends(oauth2_scheme)):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # Decode the JWT.
        # Decoding also validates the token and raises JWTError on problems such as expiry.
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        # Read back the username stored under "sub" when the token was created.
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
        token_data = TokenData(username=username)
    except JWTError:
        raise credentials_exception
    # Look up the user record.
    user = get_user(fake_users_db, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user

# Resolve the current *active* user: even if the user exists, the password is correct and the
# token is valid, the database may mark the account as disabled (e.g. banned or in arrears);
# such inactive users are rejected here.
async def get_current_active_user(current_user: User = Depends(get_current_user)):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user

''' ---- Login endpoint that issues the token --------
The user sends a POST request to obtain a token; the backend checks that the user exists and that
the password is correct. If the check passes, a token is generated and returned to the user.
'''
# username = johndoe  password = secret
@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    # 1. Authenticate the user.
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # 2. Token lifetime.
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # 3. Create the access token.
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}
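# Client-side sketch of obtaining a token (assumes the server runs on http://localhost:8000
# and the `requests` package is installed):
#
#     import requests
#     resp = requests.post("http://localhost:8000/token",
#                          data={"username": "johndoe", "password": "secret"})
#     access_token = resp.json()["access_token"]
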
@app.get("/v1/models", response_model=ModelList)
# If the api-key feature is not needed, remove the current_user parameter to allow unrestricted
# access (no api-key authentication), e.g.:
# async def list_models():
async def list_models(current_user: User = Depends(get_current_active_user)):
    global model_args
    model_card = ModelCard(id="gpt-3.5-turbo")
    return ModelList(data=[model_card])

@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
# 如果不需要apikey功能可将括号中传参的 current_user 变量删除既可允许任意访问无需api-key身份认证。如下示例
# async def create_chat_completion(request: ChatCompletionRequest):
async def create_chat_completion(request: ChatCompletionRequest, current_user: User = Depends(get_current_active_user)):
global model, tokenizer
if request.messages[-1].role != "user":
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == "system":
query = prev_messages.pop(0).content + query
history = []
if len(prev_messages) % 2 == 0:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant":
history.append([prev_messages[i].content, prev_messages[i + 1].content])
if request.stream:
generate = predict(query, history, request.model)
return EventSourceResponse(generate, media_type="text/event-stream")
response, _ = model.chat(tokenizer, query, history=history)
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop"
)
return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
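# Client-side sketch of a non-streaming call (assumes `access_token` was obtained from /token as
# shown above and that the `requests` package is installed):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:8000/v1/chat/completions",
#         headers={"Authorization": f"Bearer {access_token}"},
#         json={"model": "gpt-3.5-turbo",
#               "messages": [{"role": "user", "content": "你好"}]},
#     )
#     print(resp.json()["choices"][0]["message"]["content"])
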
async def predict(query: str, history: List[List[str]], model_id: str):
    global model, tokenizer
    # First chunk: announce the assistant role.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    current_length = 0
    # Emit only the newly generated suffix of each partial response.
    for new_response, _ in model.stream_chat(tokenizer, query, history):
        if len(new_response) == current_length:
            continue
        new_text = new_response[current_length:]
        current_length = len(new_response)
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=new_text),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    # Final chunk: signal completion, followed by the SSE terminator.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    yield '[DONE]'

if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
    # Multi-GPU support: use the two lines below instead of the line above, and set num_gpus
    # to the actual number of GPUs available.
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
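# Alternatively, the server can be called through the legacy openai<1.0 Python SDK (a sketch;
# the JWT returned by /token is passed as the api key, so the SDK puts it in the Bearer header):
#
#     import openai
#     openai.api_base = "http://localhost:8000/v1"
#     openai.api_key = access_token
#     resp = openai.ChatCompletion.create(
#         model="gpt-3.5-turbo",
#         messages=[{"role": "user", "content": "你好"}],
#     )
#     print(resp.choices[0].message.content)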