From aac7f7bd1cdf41c0e32e8065874af6a99be28443 Mon Sep 17 00:00:00 2001 From: Zyx-A Date: Tue, 12 Sep 2023 16:33:31 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20api-key=20=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=EF=BC=8C=E5=B9=B6=E4=BD=BF=E7=94=A8=E4=BC=AA=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E8=BF=9B=E8=A1=8C=E7=A4=BA=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- openai_api.py | 202 +++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 191 insertions(+), 11 deletions(-) diff --git a/openai_api.py b/openai_api.py index 7225562..3a5f588 100644 --- a/openai_api.py +++ b/openai_api.py @@ -7,17 +7,89 @@ import time import torch import uvicorn -from pydantic import BaseModel, Field +from contextlib import asynccontextmanager +from datetime import datetime, timedelta +from fastapi import Depends, FastAPI, HTTPException, status from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware -from contextlib import asynccontextmanager -from typing import Any, Dict, List, Literal, Optional, Union -from transformers import AutoTokenizer, AutoModel +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from jose import JWTError, jwt +from passlib.context import CryptContext +from pydantic import BaseModel +from pydantic import BaseModel, Field from sse_starlette.sse import ServerSentEvent, EventSourceResponse +from transformers import AutoTokenizer, AutoModel +from typing import Any, Dict, List, Literal, Optional, Union +from typing import Union + +# NOTE(review): example signing key committed in source - replace before deploying; to get a string like this run: +# openssl rand -hex 32 +SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 + +''' +用户数据(模拟数据库用户表);用于我们稍后验证。 +账号:johndoe 密码:secret + +为了数据安全,我们利用PassLib对入库的用户密码进行加密处理,推荐的加密算法是"Bcrypt" +其中,我们主要使用下面方法: +from passlib.context import CryptContext +pwd_context = CryptContext(schemes=["bcrypt"], 
deprecated="auto") +pwd_context.hash(password) # 获取对密码进行加密的密文 +pwd_context.verify(password, hashed_password) # 对密码进行校验 + +获取api-key 方法见: +官方文档: https://fastapi.tiangolo.com/zh/tutorial/security/oauth2-jwt/#_5 +Postman: https://blog.csdn.net/Disany/article/details/109365066 +''' +fake_users_db = { + "johndoe": { + "username": "johndoe", + "full_name": "John Doe", + "email": "johndoe@example.com", + # To produce this hashed value, run: + # pwd_context.hash(password) + "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", + "disabled": False, + } +} + + +class Token(BaseModel): + access_token: str + token_type: str + + +class TokenData(BaseModel): + username: Union[str, None] = None + + +class User(BaseModel): + username: str + email: Union[str, None] = None + full_name: Union[str, None] = None + disabled: Union[bool, None] = None + + +class UserInDB(User): + hashed_password: str + + +# CryptContext is the password-hashing context (configured here with the bcrypt scheme) +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +''' +OAuth2PasswordBearer是接收URL作为参数的一个类:客户端会向该URL发送username和password参数,然后得到一个token值。 +OAuth2PasswordBearer并不会创建相应的URL路径操作,只是指明了客户端用来获取token的目标URL。 +当请求到来的时候,FastAPI会检查请求的Authorization头信息,如果没有找到Authorization头信息,或者头信息的内容不是Bearer token,它会返回401状态码(UNAUTHORIZED)。 +''' +# oauth2_scheme is the token dependency; with token: str = Depends(oauth2_scheme) the parameter receives the previously issued encoded token +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @asynccontextmanager -async def lifespan(app: FastAPI): # collects GPU memory +async def lifespan(app: FastAPI): # collects GPU memory yield if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -34,6 +106,7 @@ app.add_middleware( allow_headers=["*"], ) + class ModelCard(BaseModel): id: str object: str = "model" @@ -87,15 +160,124 @@ class ChatCompletionResponse(BaseModel): created: Optional[int] = Field(default_factory=lambda: int(time.time())) +# verify_password checks a password +# plain_password: plain-text password, hashed_password: hashed password +# Returns True or False +def verify_password(plain_password, hashed_password): + return 
pwd_context.verify(plain_password, hashed_password) + + +# Get the password hash: plain password in, corresponding hash out. +def get_password_hash(password): + return pwd_context.hash(password) + + +# Simulate reading user information from a database +def get_user(db, username: str): + if username in db: + user_dict = db[username] + return UserInDB(**user_dict) + + +# Authenticate the user +def authenticate_user(fake_db, username: str, password: str): + user = get_user(fake_db, username) + if not user: + return False + if not verify_password(password, user.hashed_password): + return False + return user + + +# Create the access token +def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None): + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=15) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + +'''数据请求验证 +用户拿到token信息后,必须在后续请求中,头信息的Authorization带有Bearer token,才能访问其他数据接口。 +下面添加一个校验函数,对请求的合法性进行校验,读取token内容解析并进行验证,验证token通过后,获取接口响应数据 +''' + + +# Get the current user +# The token is taken from the request headers via oauth2_scheme +async def get_current_user(token: str = Depends(oauth2_scheme)): + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + # Decode the JWT + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + # Decoding normally validates the token and raises various exceptions (e.g. for an expired token); + # Read the username that was stored in the token when it was created + username: str = payload.get("sub") + if username is None: + raise credentials_exception + token_data = TokenData(username=username) + except JWTError: + raise credentials_exception + # Look up that user's record + user = get_user(fake_users_db, username=token_data.username) + if user is None: + raise credentials_exception + return user + + +# Get the current active user: filter on user validity using the stored record; if the user exists and the token is valid but the record marks the user as disabled (inactive), raise here and end the request. +async def get_current_active_user(current_user: User = 
Depends(get_current_user)): + if current_user.disabled: + raise HTTPException(status_code=400, detail="Inactive user") + return current_user + + +''' ---- 登录验证,获取token的接口 -------- + 用户发送post请求获取token,后端验证该用户是否存在,密码是否正确。如果验证通过,会生成‘token’给到用户。 +''' + + +# name = johndoe password = secret +@app.post("/token", response_model=Token) +async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()): + # 1. Authenticate the user + user = authenticate_user(fake_users_db, form_data.username, form_data.password) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + # 2. access_token_expires: lifetime of the access token + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + # 3. create_access_token builds the access token + access_token = create_access_token( + data={"sub": user.username}, expires_delta=access_token_expires + ) + return {"access_token": access_token, "token_type": "bearer"} + + @app.get("/v1/models", response_model=ModelList) -async def list_models(): +# If the api-key feature is not needed, remove the current_user parameter from the signature to allow access without api-key authentication. Example: +# async def list_models(): +async def list_models(current_user: User = Depends(get_current_active_user)): global model_args model_card = ModelCard(id="gpt-3.5-turbo") return ModelList(data=[model_card]) @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) -async def create_chat_completion(request: ChatCompletionRequest): +# If the api-key feature is not needed, remove the current_user parameter from the signature to allow access without api-key authentication. Example: +# async def create_chat_completion(request: ChatCompletionRequest): +async def create_chat_completion(request: ChatCompletionRequest, current_user: User = Depends(get_current_active_user)): global model, tokenizer if request.messages[-1].role != "user": @@ -109,8 +291,8 @@ async def create_chat_completion(request: ChatCompletionRequest): history = [] if len(prev_messages) % 2 == 0: for i in range(0, len(prev_messages), 2): - if 
prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant": - history.append([prev_messages[i].content, prev_messages[i+1].content]) + if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant": + history.append([prev_messages[i].content, prev_messages[i + 1].content]) if request.stream: generate = predict(query, history, request.model) @@ -154,7 +336,6 @@ async def predict(query: str, history: List[List[str]], model_id: str): chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) - choice_data = ChatCompletionResponseStreamChoice( index=0, delta=DeltaMessage(), @@ -165,7 +346,6 @@ async def predict(query: str, history: List[List[str]], model_id: str): yield '[DONE]' - if __name__ == "__main__": tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()