perf: 优化 asgi 的位置

pull/11072/head
ibuler 2023-07-24 23:20:05 +08:00
parent 9195d4c43d
commit 089d769eb0
4 changed files with 64 additions and 55 deletions

View File

@ -1,21 +1,8 @@
import os
from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter
from django.core.asgi import get_asgi_application
from .middleware import WsSignatureAuthMiddleware
from .routing import urlpatterns
import django
from channels.routing import get_default_application
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "jumpserver.settings")
application = ProtocolTypeRouter({
# Django's ASGI application to handle traditional HTTP requests
"http": get_asgi_application(),
# WebSocket chat handler
"websocket": WsSignatureAuthMiddleware(
AuthMiddlewareStack(
URLRouter(urlpatterns)
)
),
})
django.setup()
application = get_default_application()

View File

@ -6,16 +6,12 @@ import re
import time
import pytz
from channels.db import database_sync_to_async
from django.conf import settings
from django.core.exceptions import MiddlewareNotUsed
from django.core.handlers.asgi import ASGIRequest
from django.http.response import HttpResponseForbidden
from django.shortcuts import HttpResponse
from django.utils import timezone
from authentication.backends.drf import (SignatureAuthentication,
AccessTokenAuthentication)
from .utils import set_current_request
@ -146,35 +142,3 @@ class EndMiddleware:
response = self.get_response(request)
request._e_time_end = time.time()
return response
@database_sync_to_async
def get_signature_user(scope):
headers = dict(scope["headers"])
if not headers.get(b'authorization'):
return
if scope['type'] == 'websocket':
scope['method'] = 'GET'
try:
# 因为 ws 使用的是 scope所以需要转换成 request 对象,用于认证校验
request = ASGIRequest(scope, None)
backends = [SignatureAuthentication(),
AccessTokenAuthentication()]
for backend in backends:
user, _ = backend.authenticate(request)
if user:
return user
except Exception as e:
print(e)
return None
class WsSignatureAuthMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
user = await get_signature_user(scope)
if user:
scope['user'] = user
return await self.app(scope, receive, send)

View File

@ -1,9 +1,66 @@
from channels.auth import AuthMiddlewareStack
from channels.db import database_sync_to_async
from channels.routing import ProtocolTypeRouter, URLRouter
from django.core.asgi import get_asgi_application
from django.core.handlers.asgi import ASGIRequest
from authentication.backends.drf import (
SignatureAuthentication,
AccessTokenAuthentication
)
from notifications.urls.ws_urls import urlpatterns as notifications_urlpatterns
from ops.urls.ws_urls import urlpatterns as ops_urlpatterns
from settings.urls.ws_urls import urlpatterns as setting_urlpatterns
from terminal.urls.ws_urls import urlpatterns as terminal_urlpatterns
__all__ = ['urlpatterns']
urlpatterns = ops_urlpatterns + \
notifications_urlpatterns + \
setting_urlpatterns + \
terminal_urlpatterns
@database_sync_to_async
def get_signature_user(scope):
headers = dict(scope["headers"])
if not headers.get(b'authorization'):
return
if scope['type'] == 'websocket':
scope['method'] = 'GET'
try:
# 因为 ws 使用的是 scope所以需要转换成 request 对象,用于认证校验
request = ASGIRequest(scope, None)
backends = [SignatureAuthentication(),
AccessTokenAuthentication()]
for backend in backends:
user, _ = backend.authenticate(request)
if user:
return user
except Exception as e:
print(e)
return None
class WsSignatureAuthMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
user = await get_signature_user(scope)
if user:
scope['user'] = user
return await self.app(scope, receive, send)
application = ProtocolTypeRouter({
# Django's ASGI application to handle traditional HTTP requests
"http": get_asgi_application(),
# WebSocket chat handler
"websocket": WsSignatureAuthMiddleware(
AuthMiddlewareStack(
URLRouter(urlpatterns)
)
),
})

View File

@ -116,7 +116,8 @@ else:
host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, db=CONFIG.REDIS_DB_WS
)
REDIS_LAYERS_SSL_PARAMS.pop('ssl', None)
REDIS_LAYERS_HOST['address'] = '{}?{}'.format(REDIS_LAYERS_ADDRESS, urlencode(REDIS_LAYERS_SSL_PARAMS))
REDIS_LAYERS_HOST['address'] = '{}?{}'.format(REDIS_LAYERS_ADDRESS,
urlencode(REDIS_LAYERS_SSL_PARAMS))
CHANNEL_LAYERS = {
'default': {
@ -127,7 +128,7 @@ CHANNEL_LAYERS = {
},
}
ASGI_APPLICATION = 'jumpserver.asgi.application'
ASGI_APPLICATION = 'jumpserver.routing.application'
# Dump all celery log to here
CELERY_LOG_DIR = os.path.join(PROJECT_DIR, 'data', 'celery')