mirror of https://github.com/jumpserver/jumpserver
168 lines
5.3 KiB
Python
168 lines
5.3 KiB
Python
# ~*~ coding: utf-8 ~*~
|
||
|
||
import json
import os
import re
import time
from urllib.parse import urlsplit

import pytz
from channels.db import database_sync_to_async
from django.conf import settings
from django.core.exceptions import MiddlewareNotUsed
from django.core.handlers.asgi import ASGIRequest
from django.http.response import HttpResponseForbidden
from django.shortcuts import HttpResponse
from django.utils import timezone

from authentication.backends.drf import (SignatureAuthentication,
                                         AccessTokenAuthentication)
from .utils import set_current_request
|
||
|
||
|
||
class TimezoneMiddleware:
    """Activate the time zone named by the ``X-TZ`` request header.

    The header (``request.META['HTTP_X_TZ']``) is resolved with
    ``pytz.timezone``. A missing/empty header, the literal string
    ``'undefined'``, or a name pytz does not know leaves the currently
    active time zone untouched.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        tzname = request.META.get('HTTP_X_TZ')
        # Some clients send the JS value ``undefined`` as a string.
        if not tzname or tzname == 'undefined':
            return self.get_response(request)

        try:
            tz = pytz.timezone(tzname)
        except pytz.UnknownTimeZoneError:
            return self.get_response(request)

        # ``timezone.activate`` stores a thread/context-local value that
        # outlives the request. ``override`` restores the previous zone on
        # exit (also on exceptions), so one client's zone cannot leak into
        # the next request served by the same worker.
        with timezone.override(tz):
            return self.get_response(request)
|
||
|
||
|
||
class DemoMiddleware:
    """Reject unsafe requests while the ``DEMO_MODE`` env var is enabled.

    Demo mode is on when ``DEMO_MODE`` is ``"1"``, ``"ok"`` or ``"True"``
    (read once, at class definition). In that mode a request is passed on
    only if its method is in ``SAFE_METHOD`` or its path matches
    ``SAFE_URL_PATTERN``. Every other request gets a 403.
    """
    DEMO_MODE_ENABLED = os.environ.get("DEMO_MODE", "") in ("1", "ok", "True")
    SAFE_URL_PATTERN = re.compile(
        r'^/users/login|'
        r'^/api/terminal/v1/.*|'
        r'^/api/terminal/.*|'
        r'^/api/users/v1/auth/|'
        r'^/api/users/v1/profile/'
    )
    SAFE_METHOD = ("GET", "HEAD")

    def __init__(self, get_response):
        self.get_response = get_response
        # Drop the middleware from the chain only when demo mode is OFF.
        # Raising MiddlewareNotUsed while demo mode is ON would unload the
        # very guard that mode depends on.
        if not self.DEMO_MODE_ENABLED:
            raise MiddlewareNotUsed
        print("Demo mode enabled, reject unsafe method and url")

    def __call__(self, request):
        is_unsafe = (
            self.DEMO_MODE_ENABLED
            and request.method not in self.SAFE_METHOD
            and not self.SAFE_URL_PATTERN.match(request.path)
        )
        if is_unsafe:
            return HttpResponse("Demo mode, only safe request accepted", status=403)
        return self.get_response(request)
|
||
|
||
|
||
class RequestMiddleware:
    """Register the current request and refresh the session expiry.

    Before the rest of the stack runs, the request is handed to
    ``set_current_request``. Afterwards, for non-``/api`` paths and only
    when ``settings.SESSION_EXPIRE_AT_BROWSER_CLOSE`` is false, the
    session's current expiry age is written back via ``set_expiry``.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        set_current_request(request)
        response = self.get_response(request)

        api_call = request.path.startswith('/api')
        if api_call or settings.SESSION_EXPIRE_AT_BROWSER_CLOSE:
            return response

        # Re-applying the current age presumably pushes the session expiry
        # forward on every page (non-API) request — confirm intent.
        session = request.session
        session.set_expiry(session.get_expiry_age())
        return response
|
||
|
||
|
||
class RefererCheckMiddleware:
    """Reject requests whose ``Referer`` points at a different host.

    Active only when ``settings.REFERER_CHECK_ENABLED`` is truthy.
    - A request without a ``Referer`` (or with an empty one) is allowed.
    - Otherwise the referer's hostname must equal the hostname of
      ``request.get_host()``.
    - When the host value carries an explicit port, the referer port must
      match it too.
    Failures get a 403 ``CSRF CHECK ERROR``.
    """

    # Leading URL scheme, stripped so scheme-less referers parse the same way.
    http_pattern = re.compile(r'^https?://', re.IGNORECASE)

    def __init__(self, get_response):
        if not settings.REFERER_CHECK_ENABLED:
            raise MiddlewareNotUsed
        self.get_response = get_response

    @staticmethod
    def _split_host(value):
        """Return ``(hostname, port)`` parsed from ``host[:port][/...]``.

        The hostname is lower-cased and the port is None when absent.
        Raises ValueError for a malformed netloc (e.g. bad IPv6 brackets or
        a non-numeric / out-of-range port).
        """
        parts = urlsplit('//' + value)
        return parts.hostname or '', parts.port

    def check_referer(self, request):
        """Return True when the referer is absent or names this same host."""
        referer = request.META.get('HTTP_REFERER', '')
        referer = self.http_pattern.sub('', referer)
        if not referer:
            return True
        try:
            ref_host, ref_port = self._split_host(referer)
            own_host, own_port = self._split_host(request.get_host())
        except ValueError:
            return False
        # Exact hostname comparison. A prefix test would accept
        # ``example.com.attacker.net`` as a referer for ``example.com``.
        if not ref_host or ref_host != own_host:
            return False
        # A Host value without a port (e.g. a proxy forwarding only the
        # hostname) accepts any referer port so such deployments keep working.
        return own_port is None or ref_port == own_port

    def __call__(self, request):
        if not self.check_referer(request):
            return HttpResponseForbidden('CSRF CHECK ERROR')
        return self.get_response(request)
|
||
|
||
|
||
class StartMiddleware:
    """Record timing marks around the rest of the middleware stack.

    Unused unless ``settings.DEBUG_DEV`` is truthy. It stores
    ``request._s_time_start`` / ``request._s_time_end`` around the inner
    handler.

    For ``/api/health/`` it combines those marks with ``request._e_time_start``
    / ``request._e_time_end`` (set by ``EndMiddleware``) into three durations.
    The durations are added to the response's ``data`` dict, and the body and
    ``Content-Length`` are rewritten. This presumes this middleware is listed
    before ``EndMiddleware`` in MIDDLEWARE — confirm in settings.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        if not settings.DEBUG_DEV:
            raise MiddlewareNotUsed

    def __call__(self, request):
        request._s_time_start = time.time()
        response = self.get_response(request)
        request._s_time_end = time.time()
        if request.path != '/api/health/':
            return response

        e_start = getattr(request, '_e_time_start', None)
        e_end = getattr(request, '_e_time_end', None)
        data = getattr(response, 'data', None)
        # EndMiddleware does not run when an intermediate middleware
        # short-circuits (e.g. returns a 403), and the response may not carry
        # a ``data`` dict. Return it untouched rather than failing with a 500.
        if e_start is None or e_end is None or not isinstance(data, dict):
            return response

        data['pre_middleware_time'] = e_start - request._s_time_start
        data['api_time'] = e_end - e_start
        data['post_middleware_time'] = request._s_time_end - e_end
        response.content = json.dumps(data)
        # Setting ``content`` does not update the header, so refresh it here.
        response.headers['Content-Length'] = str(len(response.content))
        return response
|
||
|
||
|
||
class EndMiddleware:
    """Record timing marks around the view-side part of the stack.

    Sets ``request._e_time_start`` before and ``request._e_time_end`` after
    calling the next handler. ``StartMiddleware`` reads both for
    ``/api/health/``. Unused unless ``settings.DEBUG_DEV`` is truthy.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        if not settings.DEBUG_DEV:
            raise MiddlewareNotUsed

    def __call__(self, request):
        now = time.time
        request._e_time_start = now()
        inner_response = self.get_response(request)
        request._e_time_end = now()
        return inner_response
|
||
|
||
|
||
@database_sync_to_async
def get_signature_user(scope):
    """Return the user authenticated by the ``Authorization`` header of *scope*.

    The ASGI *scope* is wrapped in an ``ASGIRequest``. ``SignatureAuthentication``
    is tried first, then ``AccessTokenAuthentication``; the first backend that
    yields a user wins.

    Returns None when there is no ``authorization`` header, no backend
    authenticates, or a backend raises (the error is printed). Wrapped with
    ``database_sync_to_async`` so async callers can await it; the backends
    may hit the database.
    """
    headers = dict(scope["headers"])
    if not headers.get(b'authorization'):
        return None

    if scope['type'] == 'websocket':
        # Websocket scopes carry no 'method', which ASGIRequest reads.
        # Build a copy instead of mutating the scope the downstream app gets.
        scope = dict(scope, method='GET')

    try:
        # A websocket only has an ASGI scope, so convert it into a request
        # object that the DRF authentication backends can validate.
        request = ASGIRequest(scope, None)
        backends = [SignatureAuthentication(),
                    AccessTokenAuthentication()]
        for backend in backends:
            result = backend.authenticate(request)
            # DRF backends return None when the header is not theirs.
            # Unpacking that raised TypeError, which aborted the loop before
            # the next backend could be tried.
            if result is None:
                continue
            user, _ = result
            if user:
                return user
    except Exception as e:
        print(e)
    return None
|
||
|
||
|
||
class WsSignatureAuthMiddleware:
    """ASGI middleware that authenticates a scope from its ``Authorization`` header.

    When ``get_signature_user`` resolves a user, it is stored as
    ``scope['user']`` before the wrapped application is awaited. Otherwise the
    scope is passed on unchanged.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        signed_user = await get_signature_user(scope)
        if not signed_user:
            return await self.app(scope, receive, send)
        scope['user'] = signed_user
        return await self.app(scope, receive, send)
|