jumpserver/apps/jumpserver/middleware.py

168 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# ~*~ coding: utf-8 ~*~
import json
import os
import re
import time
import pytz
from channels.db import database_sync_to_async
from django.conf import settings
from django.core.exceptions import MiddlewareNotUsed
from django.core.handlers.asgi import ASGIRequest
from django.http.response import HttpResponseForbidden
from django.shortcuts import HttpResponse
from django.utils import timezone
from authentication.backends.drf import (SignatureAuthentication,
AccessTokenAuthentication)
from .utils import set_current_request
class TimezoneMiddleware:
    """Activate the time zone the client sends in the ``X-TZ`` header.

    ``django.utils.timezone.activate`` sets the *current* time zone, which
    persists for the thread/async context beyond this request. The zone is
    therefore re-selected on every request: a valid header activates that
    zone, and a missing, ``'undefined'`` or unknown one deactivates any
    previously active zone, falling back to the default time zone.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        tzname = request.META.get('HTTP_X_TZ')
        tz = None
        # Browsers may literally send the string 'undefined' when JS has no zone.
        if tzname and tzname != 'undefined':
            try:
                tz = pytz.timezone(tzname)
            except pytz.UnknownTimeZoneError:
                tz = None
        if tz is None:
            # Without this, a zone activated for an earlier request served by
            # the same worker would still be active for this one.
            timezone.deactivate()
        else:
            timezone.activate(tz)
        return self.get_response(request)
class DemoMiddleware:
    """Reject state-changing requests while the server runs in demo mode.

    Demo mode is read once, at import time, from the ``DEMO_MODE``
    environment variable (``"1"``, ``"ok"`` or ``"True"`` enable it). When it
    is off, the middleware removes itself from the chain. When it is on, only
    ``GET``/``HEAD`` requests, or requests whose path matches
    ``SAFE_URL_PATTERN``, reach the view. Anything else gets a 403.
    """

    DEMO_MODE_ENABLED = os.environ.get("DEMO_MODE", "") in ("1", "ok", "True")
    # Paths that stay writable in demo mode (login, terminal API, auth, profile).
    SAFE_URL_PATTERN = re.compile(
        r'^/users/login|'
        r'^/api/terminal/v1/.*|'
        r'^/api/terminal/.*|'
        r'^/api/users/v1/auth/|'
        r'^/api/users/v1/profile/'
    )
    SAFE_METHOD = ("GET", "HEAD")

    def __init__(self, get_response):
        self.get_response = get_response
        # Django drops a middleware whose constructor raises MiddlewareNotUsed,
        # so only opt out when demo mode is OFF; it must stay installed when ON.
        if not self.DEMO_MODE_ENABLED:
            raise MiddlewareNotUsed
        print("Demo mode enabled, reject unsafe method and url")

    def __call__(self, request):
        if self.DEMO_MODE_ENABLED and request.method not in self.SAFE_METHOD \
                and not self.SAFE_URL_PATTERN.match(request.path):
            return HttpResponse("Demo mode, only safe request accepted", status=403)
        return self.get_response(request)
class RequestMiddleware:
    """Register the current request and re-apply session expiry for pages.

    The request is handed to ``set_current_request`` before the view runs.
    After the view returns, the session's current expiry age is written back
    with ``set_expiry``, unless ``SESSION_EXPIRE_AT_BROWSER_CLOSE`` is set or
    the path starts with ``/api``. Presumably this slides the session
    window forward on page views. TODO confirm intent.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        set_current_request(request)
        response = self.get_response(request)
        api_call = request.path.startswith('/api')
        if settings.SESSION_EXPIRE_AT_BROWSER_CLOSE or api_call:
            return response
        session = request.session
        session.set_expiry(session.get_expiry_age())
        return response
class RefererCheckMiddleware:
    """Reject requests whose ``Referer`` names a different host.

    Installed only when ``settings.REFERER_CHECK_ENABLED`` is true. A request
    without a Referer is let through. Otherwise the ``host[:port]`` part of
    the Referer must equal ``request.get_host()`` (case-insensitively), or a
    403 ``'CSRF CHECK ERROR'`` response is returned.
    """

    # Leading URL scheme to strip before extracting the host part.
    http_pattern = re.compile(r'^https?://', re.IGNORECASE)
    # Any of these characters ends the authority (host[:port]) part of a URL.
    _authority_end = re.compile(r'[/?#]')

    def __init__(self, get_response):
        if not settings.REFERER_CHECK_ENABLED:
            raise MiddlewareNotUsed
        self.get_response = get_response

    def check_referer(self, request):
        """Return True if the Referer is absent or points at this host."""
        referer = request.META.get('HTTP_REFERER', '')
        referer = self.http_pattern.sub('', referer)
        if not referer:
            return True
        # Compare the whole authority, not a prefix. A prefix check would
        # accept "host.evil.com" or "host@evil.com" for host "host".
        referer_host = self._authority_end.split(referer, maxsplit=1)[0]
        remote_host = request.get_host()
        return referer_host.lower() == remote_host.lower()

    def __call__(self, request):
        if not self.check_referer(request):
            return HttpResponseForbidden('CSRF CHECK ERROR')
        return self.get_response(request)
class StartMiddleware:
    """Outer half of the ``DEBUG_DEV`` request-timing pair.

    Records ``request._s_time_start``/``_s_time_end`` around the rest of the
    chain. For ``/api/health/`` it combines them with the
    ``_e_time_start``/``_e_time_end`` stamps set by ``EndMiddleware`` and adds
    ``pre_middleware_time``, ``api_time`` and ``post_middleware_time`` to the
    JSON body. Installed only when ``settings.DEBUG_DEV`` is true.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        if not settings.DEBUG_DEV:
            raise MiddlewareNotUsed

    def __call__(self, request):
        request._s_time_start = time.time()
        response = self.get_response(request)
        request._s_time_end = time.time()
        if request.path != '/api/health/':
            return response
        # EndMiddleware's stamps are missing when a middleware between the
        # two short-circuited the request. Only DRF responses carry ``.data``.
        # Skip the report in those cases instead of raising AttributeError.
        end_start = getattr(request, '_e_time_start', None)
        end_end = getattr(request, '_e_time_end', None)
        data = getattr(response, 'data', None)
        if end_start is None or end_end is None or not isinstance(data, dict):
            return response
        data['pre_middleware_time'] = end_start - request._s_time_start
        data['api_time'] = end_end - end_start
        data['post_middleware_time'] = request._s_time_end - end_end
        response.content = json.dumps(data)
        response.headers['Content-Length'] = str(len(response.content))
        return response
class EndMiddleware:
    """Inner half of the ``DEBUG_DEV`` request-timing pair.

    Stamps ``request._e_time_start`` just before, and ``request._e_time_end``
    just after, the rest of the chain runs. ``StartMiddleware`` reads these
    stamps. Installed only when ``settings.DEBUG_DEV`` is true.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        if not settings.DEBUG_DEV:
            raise MiddlewareNotUsed

    def __call__(self, request):
        clock = time.time
        request._e_time_start = clock()
        result = self.get_response(request)
        request._e_time_end = clock()
        return result
@database_sync_to_async
def get_signature_user(scope):
    """Authenticate an ASGI connection from its ``Authorization`` header.

    Builds a Django request from *scope*, then tries the
    ``SignatureAuthentication`` and ``AccessTokenAuthentication`` backends in
    that order.

    Returns:
        The first truthy user a backend yields. Returns ``None`` when the
        header is missing, no backend authenticates, or a backend raises
        (the exception is printed and swallowed).
    """
    headers = dict(scope["headers"])
    if not headers.get(b'authorization'):
        return
    if scope['type'] == 'websocket':
        # Websocket scopes have no HTTP method, and ASGIRequest reads
        # scope['method'], so supply one.
        scope['method'] = 'GET'
    try:
        # Websockets only provide an ASGI scope, so convert it into a request
        # object that the authentication backends can validate.
        request = ASGIRequest(scope, None)
        backends = [SignatureAuthentication(),
                    AccessTokenAuthentication()]
        for backend in backends:
            result = backend.authenticate(request)
            # DRF backends return None when they don't apply to the request.
            # Unpacking that raised TypeError and aborted the loop before
            # AccessTokenAuthentication could run.
            if result is None:
                continue
            user, _ = result
            if user:
                return user
    except Exception as e:
        print(e)
    return None
class WsSignatureAuthMiddleware:
    """ASGI middleware that puts a header-authenticated user into the scope.

    When ``get_signature_user`` resolves a user for the connection, it is
    stored as ``scope['user']``. Any existing ``scope['user']`` is left alone
    otherwise. The wrapped app is always invoked.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        authenticated = await get_signature_user(scope)
        if not authenticated:
            return await self.app(scope, receive, send)
        scope['user'] = authenticated
        return await self.app(scope, receive, send)